From 91439209f1ec18f703f195f18900c503496f57e0 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 20:40:32 +0000 Subject: [PATCH 001/102] Add agent swarm for parallel behavior investigation Implements a SLURM-based system for launching parallel Claude Code agents that investigate behaviors in SPD model decompositions. Key components: - spd-swarm CLI: Submits SLURM array job for N agents - Each agent starts isolated app backend (unique port, separate database) - Detailed system prompt guides agents through investigation methodology - Findings written to append-only JSONL files (events.jsonl, explanations.jsonl) New files: - spd/agent_swarm/schemas.py: BehaviorExplanation, SwarmEvent schemas - spd/agent_swarm/agent_prompt.py: Detailed API and methodology instructions - spd/agent_swarm/scripts/run_slurm_cli.py: CLI entry point - spd/agent_swarm/scripts/run_slurm.py: SLURM submission logic - spd/agent_swarm/scripts/run_agent.py: Worker script for each job Also adds SPD_APP_DB_PATH env var support for database isolation. 
https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- CLAUDE.md | 24 ++ pyproject.toml | 1 + spd/agent_swarm/CLAUDE.md | 124 +++++++++ spd/agent_swarm/__init__.py | 22 ++ spd/agent_swarm/agent_prompt.py | 330 +++++++++++++++++++++++ spd/agent_swarm/schemas.py | 120 +++++++++ spd/agent_swarm/scripts/__init__.py | 1 + spd/agent_swarm/scripts/run_agent.py | 284 +++++++++++++++++++ spd/agent_swarm/scripts/run_slurm.py | 119 ++++++++ spd/agent_swarm/scripts/run_slurm_cli.py | 62 +++++ spd/app/backend/database.py | 18 +- 11 files changed, 1103 insertions(+), 2 deletions(-) create mode 100644 spd/agent_swarm/CLAUDE.md create mode 100644 spd/agent_swarm/__init__.py create mode 100644 spd/agent_swarm/agent_prompt.py create mode 100644 spd/agent_swarm/schemas.py create mode 100644 spd/agent_swarm/scripts/__init__.py create mode 100644 spd/agent_swarm/scripts/run_agent.py create mode 100644 spd/agent_swarm/scripts/run_slurm.py create mode 100644 spd/agent_swarm/scripts/run_slurm_cli.py diff --git a/CLAUDE.md b/CLAUDE.md index 30da03eb6..2e636885c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -156,6 +156,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── agent_swarm/ # Parallel agent investigation (see agent_swarm/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) @@ -195,6 +196,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | +| `spd-swarm` | `spd/agent_swarm/scripts/run_slurm_cli.py` | Launch parallel 
agent swarm | ### Files to Skip When Searching @@ -231,6 +233,9 @@ Use `spd/` as the search root (not repo root) to avoid noise. **Clustering Pipeline:** - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Agent Swarm Pipeline:** +- `spd-swarm` → `spd/agent_swarm/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/agent_swarm/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -277,6 +282,25 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Swarm for Parallel Investigation (`spd-swarm`) + +Launch a swarm of Claude Code agents to investigate behaviors in an SPD model: + +```bash +spd-swarm --n_agents 10 # Launch 10 parallel agents +spd-swarm --n_agents 5 --time 4:00:00 # Custom time limit +``` + +Each agent: +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates behaviors using the SPD app API +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl` + +See `spd/agent_swarm/CLAUDE.md` for details. 
+ ### Running on SLURM Cluster (`spd-run`) For the core team, `spd-run` provides full-featured SLURM orchestration: diff --git a/pyproject.toml b/pyproject.toml index 76a539454..24f47fe64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" +spd-swarm = "spd.agent_swarm.scripts.run_slurm_cli:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md new file mode 100644 index 000000000..ee2e89be2 --- /dev/null +++ b/spd/agent_swarm/CLAUDE.md @@ -0,0 +1,124 @@ +# Agent Swarm Module + +This module provides infrastructure for launching parallel SLURM-based Claude Code agents +that investigate behaviors in SPD model decompositions. + +## Overview + +The agent swarm system allows you to: +1. Launch many parallel agents (each as a SLURM job with 1 GPU) +2. Each agent runs an isolated app backend instance +3. Agents investigate behaviors using the SPD app API +4. 
Findings are written to append-only JSONL files + +## Usage + +```bash +# Launch 10 agents to investigate a decomposition +spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 + +# With custom settings +spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 +``` + +## Architecture + +``` +spd/agent_swarm/ +├── __init__.py # Public exports +├── CLAUDE.md # This file +├── schemas.py # Pydantic models for outputs +├── agent_prompt.py # System prompt for agents +└── scripts/ + ├── __init__.py + ├── run_slurm_cli.py # CLI entry point (spd-swarm) + ├── run_slurm.py # SLURM submission logic + └── run_agent.py # Worker script (runs in each SLURM job) +``` + +## Output Structure + +``` +SPD_OUT_DIR/agent_swarm// +├── metadata.json # Swarm configuration +├── task_1/ +│ ├── events.jsonl # Progress and observations +│ ├── explanations.jsonl # Complete behavior explanations +│ ├── app.db # Isolated SQLite database +│ ├── agent_prompt.md # The prompt given to the agent +│ └── claude_output.txt # Raw Claude Code output +├── task_2/ +│ └── ... +└── task_N/ + └── ... +``` + +## Key Files + +| File | Purpose | +|------|---------| +| `schemas.py` | Defines `BehaviorExplanation`, `SwarmEvent`, `Evidence` schemas | +| `agent_prompt.py` | Contains detailed instructions for agents on using the API | +| `run_slurm.py` | Creates git snapshot, generates commands, submits SLURM array | +| `run_agent.py` | Starts backend, loads run, launches Claude Code | + +## Schemas + +### BehaviorExplanation +The primary output - documents a discovered behavior: +- `subject_prompt`: Prompt demonstrating the behavior +- `behavior_description`: What the model does +- `components_involved`: List of components and their roles +- `explanation`: How components work together +- `supporting_evidence`: Ablations, attributions, etc. 
+- `confidence`: high/medium/low +- `alternative_hypotheses`: Other considered explanations +- `limitations`: Known caveats + +### SwarmEvent +General logging: +- `event_type`: start, progress, observation, hypothesis, test_result, error, complete +- `timestamp`: When it occurred +- `message`: Human-readable description +- `details`: Structured data + +## Database Isolation + +Each agent gets its own SQLite database via the `SPD_APP_DB_PATH` environment variable. +This prevents conflicts when multiple agents run on the same machine. + +## Monitoring + +```bash +# Watch events from all agents +tail -f SPD_OUT_DIR/agent_swarm//task_*/events.jsonl + +# View all explanations +cat SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl | jq . + +# Check SLURM job status +squeue --me + +# View specific job logs +tail -f ~/slurm_logs/slurm-_.out +``` + +## Configuration + +CLI arguments: +- `wandb_path`: Required - WandB run path for the SPD decomposition +- `--n_agents`: Required - Number of parallel agents to launch +- `--context_length`: Token context length (default: 128) +- `--partition`: SLURM partition (default: h200-reserved) +- `--time`: Time limit per agent (default: 8:00:00) +- `--job_suffix`: Optional suffix for job names + +## Extending + +To modify agent behavior: +1. Edit `agent_prompt.py` to change investigation instructions +2. Update `schemas.py` to add new output fields +3. Modify `run_agent.py` to change the worker flow + +The agent prompt is the primary way to guide agent behavior - it contains +detailed API documentation and scientific methodology guidance. diff --git a/spd/agent_swarm/__init__.py b/spd/agent_swarm/__init__.py new file mode 100644 index 000000000..cac91de2d --- /dev/null +++ b/spd/agent_swarm/__init__.py @@ -0,0 +1,22 @@ +"""Agent Swarm: Parallel SLURM-based agent investigation of model behaviors. 
+ +This module provides infrastructure for launching many parallel Claude Code agents, +each investigating behaviors in an SPD model decomposition. Each agent: +1. Starts an isolated app backend instance (separate database, unique port) +2. Receives detailed instructions on using the SPD app API +3. Investigates behaviors and writes findings to append-only JSONL files +""" + +from spd.agent_swarm.schemas import ( + BehaviorExplanation, + ComponentInfo, + Evidence, + SwarmEvent, +) + +__all__ = [ + "BehaviorExplanation", + "ComponentInfo", + "Evidence", + "SwarmEvent", +] diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py new file mode 100644 index 000000000..e3cf4fd21 --- /dev/null +++ b/spd/agent_swarm/agent_prompt.py @@ -0,0 +1,330 @@ +"""System prompt for SPD investigation agents. + +This module contains the detailed instructions given to each agent in the swarm. +The prompt explains how to use the SPD app API and the scientific methodology +for investigating model behaviors. +""" + +AGENT_SYSTEM_PROMPT = """ +# SPD Behavior Investigation Agent + +You are a research agent investigating behaviors in a neural network model decomposition. +Your goal is to find interesting behaviors, understand how components interact to produce +them, and document your findings as explanations. + +## Your Mission + +You are part of a swarm of agents, each independently investigating behaviors in the same +model. Your task is to: + +1. **Find a behavior**: Discover a prompt where the model does something interesting + (e.g., predicts the correct gendered pronoun, completes a pattern, etc.) + +2. **Understand the mechanism**: Figure out which components are involved and how they + work together to produce the behavior + +3. 
**Document your findings**: Write a clear explanation with supporting evidence + +## The SPD App Backend + +You have access to an SPD (Stochastic Parameter Decomposition) app backend running at: +`http://localhost:{port}` + +This app provides APIs for: +- Loading decomposed models +- Computing attribution graphs showing how components interact +- Optimizing sparse circuits for specific behaviors +- Running interventions (ablations) to test hypotheses +- Viewing component interpretations and correlations +- Searching the training dataset + +## API Reference + +### Health Check +```bash +curl http://localhost:{port}/api/health +# Returns: {{"status": "ok"}} +``` + +### Load a Run (ALREADY DONE FOR YOU) +The run is pre-loaded. Check status with: +```bash +curl http://localhost:{port}/api/status +``` + +### Create a Custom Prompt +To analyze a specific prompt: +```bash +curl -X POST "http://localhost:{port}/api/prompts/custom?text=The%20boy%20ate%20his" +# Returns: {{"id": , "token_ids": [...], "tokens": [...], "preview": "...", "next_token_probs": [...]}} +``` + +### Compute Optimized Attribution Graph (MOST IMPORTANT) +This optimizes a sparse circuit that achieves a behavior: +```bash +curl -X POST "http://localhost:{port}/api/graphs/optimized/stream?prompt_id=&loss_type=ce&loss_position=&label_token=&steps=100&imp_min_coeff=0.1&pnorm=0.5&mask_type=hard&loss_coeff=1.0&ci_threshold=0.01&normalize=target" +# Streams SSE events, final event has type="complete" with graph data +``` + +Parameters: +- `prompt_id`: ID from creating custom prompt +- `loss_type`: "ce" for cross-entropy (predicting specific token) or "kl" (matching full distribution) +- `loss_position`: Token position to optimize (0-indexed, usually last position) +- `label_token`: Token ID to predict (for CE loss) +- `steps`: Optimization steps (50-200 typical) +- `imp_min_coeff`: Importance minimization coefficient (0.05-0.3) +- `pnorm`: P-norm for sparsity (0.3-1.0, lower = sparser) +- `mask_type`: "hard" 
for binary masks, "soft" for continuous +- `ci_threshold`: Threshold for including nodes in graph (0.01-0.1) +- `normalize`: "target" normalizes by target layer, "none" for raw values + +### Get Component Interpretations +```bash +curl "http://localhost:{port}/api/correlations/interpretations" +# Returns: {{"h.0.mlp.c_fc:5": {{"label": "...", "confidence": "high"}}, ...}} +``` + +Get full interpretation details: +```bash +curl "http://localhost:{port}/api/correlations/interpretations/h.0.mlp.c_fc/5" +# Returns: {{"reasoning": "...", "prompt": "..."}} +``` + +### Get Component Token Statistics +```bash +curl "http://localhost:{port}/api/correlations/token_stats/h.0.mlp.c_fc/5?top_k=20" +# Returns input/output token associations +``` + +### Get Component Correlations +```bash +curl "http://localhost:{port}/api/correlations/components/h.0.mlp.c_fc/5?top_k=20" +# Returns components that frequently co-activate +``` + +### Run Intervention (Ablation) +Test a hypothesis by running the model with only selected components active: +```bash +curl -X POST "http://localhost:{port}/api/intervention/run" \\ + -H "Content-Type: application/json" \\ + -d '{{"graph_id": , "text": "The boy ate his", "selected_nodes": ["h.0.mlp.c_fc:3:5", "h.1.attn.o_proj:3:10"], "top_k": 10}}' +# Returns predictions with only selected components active vs full model +``` + +Node format: "layer:seq_pos:component_idx" +- `layer`: e.g., "h.0.mlp.c_fc", "h.1.attn.o_proj" +- `seq_pos`: Position in sequence (0-indexed) +- `component_idx`: Component index within layer + +### Search Dataset +Find prompts with specific patterns: +```bash +curl -X POST "http://localhost:{port}/api/dataset/search?query=she%20said&split=train" +curl "http://localhost:{port}/api/dataset/results?page=1&page_size=20" +``` + +### Get Random Samples with Loss +Find high/low loss examples: +```bash +curl "http://localhost:{port}/api/dataset/random_with_loss?n_samples=20&seed=42" +``` + +### Probe Component Activation +See how a 
component responds to arbitrary text: +```bash +curl -X POST "http://localhost:{port}/api/activation_contexts/probe" \\ + -H "Content-Type: application/json" \\ + -d '{{"text": "The boy ate his", "layer": "h.0.mlp.c_fc", "component_idx": 5}}' +# Returns CI values and activations at each position +``` + +### Get Dataset Attributions +See which components influence each other across the training data: +```bash +curl "http://localhost:{port}/api/dataset_attributions/h.0.mlp.c_fc/5?k=10" +# Returns positive/negative sources and targets +``` + +## Investigation Methodology + +### Step 1: Find an Interesting Behavior + +Start by exploring the model's behavior: + +1. **Search for patterns**: Use `/api/dataset/search` to find prompts with specific + linguistic patterns (pronouns, verb conjugations, completions, etc.) + +2. **Look at high-loss examples**: Use `/api/dataset/random_with_loss` to find where + the model struggles or succeeds + +3. **Create test prompts**: Use `/api/prompts/custom` to create prompts that test + specific capabilities + +Good behaviors to investigate: +- Gendered pronoun prediction ("The doctor said she" vs "The doctor said he") +- Subject-verb agreement ("The cats are" vs "The cat is") +- Pattern completion ("1, 2, 3," → "4") +- Semantic associations ("The capital of France is" → "Paris") +- Grammatical structure (completing sentences correctly) + +### Step 2: Optimize a Sparse Circuit + +Once you have a behavior: + +1. **Create the prompt** via `/api/prompts/custom` + +2. **Identify the target token**: What token should be predicted? Get its ID from + the tokenizer or from the prompt creation response. + +3. **Run optimization** via `/api/graphs/optimized/stream`: + - Use `loss_type=ce` with the target token + - Set `loss_position` to the position where prediction matters + - Start with `imp_min_coeff=0.1`, `pnorm=0.5`, `steps=100` + - Use `ci_threshold=0.01` to see active components + +4. 
**Examine the graph**: The response shows: + - `nodeCiVals`: Which components are active (high CI = important) + - `edges`: How components connect (gradient flow) + - `outputProbs`: Model predictions + +### Step 3: Understand Component Roles + +For each important component in the graph: + +1. **Check the interpretation**: Use `/api/correlations/interpretations//` + to see if we already have an idea what this component does + +2. **Look at token stats**: Use `/api/correlations/token_stats//` to see + what tokens activate this component (input) and what it predicts (output) + +3. **Check correlations**: Use `/api/correlations/components//` to see + what other components co-activate + +4. **Probe on variations**: Use `/api/activation_contexts/probe` to see how the + component responds to related prompts + +### Step 4: Test with Ablations + +Form hypotheses and test them: + +1. **Hypothesis**: "Component X stores information about gender" + +2. **Test**: Run intervention with and without component X + - If prediction changes as expected → supports hypothesis + - If no change → component may not be necessary for this + - If unexpected change → revise hypothesis + +3. 
**Control**: Try ablating other components to ensure specificity + +### Step 5: Document Your Findings + +Write a `BehaviorExplanation` with: +- Clear subject prompt +- Description of the behavior +- Components and their roles +- How they work together +- Supporting evidence from ablations/attributions +- Confidence level +- Alternative hypotheses you considered +- Limitations + +## Scientific Principles + +### Be Epistemologically Humble +- Your first hypothesis is probably wrong or incomplete +- Always consider alternative explanations +- A single confirming example doesn't prove a theory +- Look for disconfirming evidence + +### Be Bayesian +- Start with priors from component interpretations +- Update beliefs based on evidence +- Consider the probability of the evidence under different hypotheses +- Don't anchor too strongly on initial observations + +### Triangulate Evidence +- Don't rely on a single type of evidence +- Ablation results + attribution patterns + token stats together are stronger +- Look for convergent evidence from multiple sources + +### Document Uncertainty +- Be explicit about what you're confident in vs. 
uncertain about +- Note when evidence is weak or ambiguous +- Identify what additional tests would strengthen the explanation + +## Output Format + +Write your findings by appending to the output files: + +### events.jsonl +Log progress and observations: +```json +{{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}} +``` + +### explanations.jsonl +When you have a complete explanation: +```json +{{ + "subject_prompt": "The boy ate his lunch", + "behavior_description": "Correctly predicts gendered pronoun 'his' after male subject", + "components_involved": [ + {{"component_key": "h.0.mlp.c_fc:5", "role": "Encodes subject gender as male", "interpretation": "male names/subjects"}}, + {{"component_key": "h.1.attn.o_proj:10", "role": "Transmits gender information to output", "interpretation": null}} + ], + "explanation": "Component h.0.mlp.c_fc:5 activates on male subjects and stores gender information...", + "supporting_evidence": [ + {{"evidence_type": "ablation", "description": "Removing component causes prediction to change from 'his' to 'her'", "details": {{"without_component": {{"his": 0.1, "her": 0.6}}, "with_component": {{"his": 0.8, "her": 0.1}}}}}} + ], + "confidence": "medium", + "alternative_hypotheses": ["Component might encode broader concept of masculine entities, not just humans"], + "limitations": ["Only tested on simple subject-pronoun sentences"] +}} +``` + +## Getting Started + +1. Check the current status: `curl http://localhost:{port}/api/status` +2. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` +3. Search for interesting prompts or create your own +4. Optimize a sparse circuit for a behavior you find +5. Investigate the components involved +6. Test hypotheses with ablations +7. Document your findings + +Remember: You are exploring! Not every investigation will lead to a clear explanation. 
+Document what you learn, even if it's "this was more complicated than expected." + +Good luck, and happy investigating! +""" + + +def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) -> str: + """Generate the full agent prompt with runtime parameters filled in. + + Args: + port: The port the backend is running on. + wandb_path: The WandB path of the loaded run. + task_id: The SLURM task ID for this agent. + output_dir: Path to the agent's output directory. + + Returns: + The complete agent prompt with parameters substituted. + """ + prompt = AGENT_SYSTEM_PROMPT.format(port=port) + + runtime_context = f""" +## Runtime Context + +- **Backend URL**: http://localhost:{port} +- **Loaded Run**: {wandb_path} +- **Task ID**: {task_id} +- **Output Directory**: {output_dir} + +Your output files: +- `{output_dir}/events.jsonl` - Log events and observations here +- `{output_dir}/explanations.jsonl` - Write complete explanations here + +To append to these files, use the Write tool or shell redirection. +""" + return prompt + runtime_context diff --git a/spd/agent_swarm/schemas.py b/spd/agent_swarm/schemas.py new file mode 100644 index 000000000..d554db855 --- /dev/null +++ b/spd/agent_swarm/schemas.py @@ -0,0 +1,120 @@ +"""Schemas for agent swarm outputs. + +All agent outputs are append-only JSONL files. Each line is a JSON object +conforming to one of the schemas defined here. 
+""" + +from datetime import UTC, datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ComponentInfo(BaseModel): + """Information about a component involved in a behavior.""" + + component_key: str = Field( + ..., + description="Component key in format 'layer:component_idx' (e.g., 'h.0.mlp.c_fc:5')", + ) + role: str = Field( + ..., + description="The role this component plays in the behavior (e.g., 'stores subject gender')", + ) + interpretation: str | None = Field( + default=None, + description="Auto-interp label for this component if available", + ) + + +class Evidence(BaseModel): + """A piece of supporting evidence for an explanation.""" + + evidence_type: Literal["ablation", "attribution", "activation_pattern", "correlation", "other"] + description: str = Field( + ..., + description="Description of the evidence", + ) + details: dict[str, Any] = Field( + default_factory=dict, + description="Additional structured details (e.g., ablation results, attribution values)", + ) + + +class BehaviorExplanation(BaseModel): + """A candidate explanation for a behavior discovered by an agent. + + This is the primary output schema for agent investigations. Each explanation + describes a behavior (demonstrated by a subject prompt), the components involved, + and supporting evidence. 
+ """ + + subject_prompt: str = Field( + ..., + description="A prompt that demonstrates the behavior being explained", + ) + behavior_description: str = Field( + ..., + description="Clear description of the behavior (e.g., 'correctly predicts gendered pronoun')", + ) + components_involved: list[ComponentInfo] = Field( + ..., + description="List of components involved in this behavior and their roles", + ) + explanation: str = Field( + ..., + description="Explanation of how the components work together to produce the behavior", + ) + supporting_evidence: list[Evidence] = Field( + default_factory=list, + description="Evidence supporting this explanation (ablations, attributions, etc.)", + ) + confidence: Literal["high", "medium", "low"] = Field( + ..., + description="Agent's confidence in this explanation", + ) + alternative_hypotheses: list[str] = Field( + default_factory=list, + description="Alternative hypotheses that were considered but not fully supported", + ) + limitations: list[str] = Field( + default_factory=list, + description="Known limitations of this explanation", + ) + + +class SwarmEvent(BaseModel): + """A generic event logged by an agent during investigation. + + Used for logging progress, observations, and other non-explanation events. + """ + + event_type: Literal[ + "start", + "progress", + "observation", + "hypothesis", + "test_result", + "explanation", + "error", + "complete", + ] + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + message: str + details: dict[str, Any] = Field(default_factory=dict) + + +class AgentOutput(BaseModel): + """Container for all outputs from a single agent run. + + Written to the agent's output directory as output.json upon completion. 
+ """ + + task_id: int + wandb_path: str + started_at: datetime + completed_at: datetime | None = None + explanations: list[BehaviorExplanation] = Field(default_factory=list) + events: list[SwarmEvent] = Field(default_factory=list) + status: Literal["running", "completed", "failed"] = "running" + error: str | None = None diff --git a/spd/agent_swarm/scripts/__init__.py b/spd/agent_swarm/scripts/__init__.py new file mode 100644 index 000000000..9d0e8ed1b --- /dev/null +++ b/spd/agent_swarm/scripts/__init__.py @@ -0,0 +1 @@ +"""Agent swarm SLURM scripts.""" diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py new file mode 100644 index 000000000..83a32752c --- /dev/null +++ b/spd/agent_swarm/scripts/run_agent.py @@ -0,0 +1,284 @@ +"""Worker script that runs inside each SLURM job. + +This script: +1. Creates an isolated output directory for this agent +2. Starts the app backend with an isolated database +3. Loads the SPD run +4. Launches Claude Code with investigation instructions +5. 
"""Worker script that runs inside each SLURM job.

This script:
1. Creates an isolated output directory for this agent
2. Starts the app backend with an isolated database
3. Loads the SPD run
4. Launches Claude Code with investigation instructions
5. Handles cleanup on exit
"""

import os
import signal
import socket
import subprocess
import sys
import time
from pathlib import Path
from types import FrameType

import fire
import requests

from spd.agent_swarm.agent_prompt import get_agent_prompt
from spd.agent_swarm.schemas import SwarmEvent
from spd.agent_swarm.scripts.run_slurm import get_swarm_output_dir
from spd.log import logger


def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int:
    """Find an available port starting from start_port.

    NOTE: there is an inherent race between probing here and the backend
    binding the port later; callers offset start_port per task to make
    collisions unlikely.

    Args:
        start_port: First port to try.
        max_attempts: How many consecutive ports to probe.

    Returns:
        The first port that could be bound.

    Raises:
        RuntimeError: If no port in the probed range was free.
    """
    for offset in range(max_attempts):
        port = start_port + offset
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            try:
                s.bind(("localhost", port))
                return port
            except OSError:
                continue
    raise RuntimeError(
        f"Could not find available port in range {start_port}-{start_port + max_attempts}"
    )


def wait_for_backend(port: int, timeout: float = 120.0) -> bool:
    """Wait for the backend to become healthy.

    Polls /api/health roughly once per second until it returns 200 or the
    timeout expires.

    Args:
        port: Port the backend is expected on.
        timeout: Maximum seconds to wait.

    Returns:
        True if the backend responded with 200 within the timeout.
    """
    url = f"http://localhost:{port}/api/health"
    start = time.time()
    while time.time() - start < timeout:
        try:
            resp = requests.get(url, timeout=5)
            if resp.status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Fix: catch the full requests exception hierarchy, not just
            # ConnectionError — a ReadTimeout while the backend boots would
            # otherwise propagate and kill the agent.
            pass
        time.sleep(1)
    return False


def load_run(port: int, wandb_path: str, context_length: int) -> bool:
    """Load the SPD run into the backend.

    Args:
        port: Port the backend is running on.
        wandb_path: WandB path of the run to load.
        context_length: Token context length for prompts.

    Returns:
        True if the backend accepted the load request (HTTP 200).
    """
    url = f"http://localhost:{port}/api/runs/load"
    params = {"wandb_path": wandb_path, "context_length": context_length}
    try:
        # Loading a run downloads weights and builds caches; allow 5 minutes.
        resp = requests.post(url, params=params, timeout=300)
        return resp.status_code == 200
    except Exception as e:
        logger.error(f"Failed to load run: {e}")
        return False


def log_event(events_path: Path, event: SwarmEvent) -> None:
    """Append an event to the events log (one JSON object per line)."""
    with open(events_path, "a") as f:
        f.write(event.model_dump_json() + "\n")


def run_agent(
    wandb_path: str,
    task_id: int,
    swarm_id: str,
    context_length: int = 128,
) -> None:
    """Run a single investigation agent.

    Starts an isolated backend (own port, own SQLite database), loads the SPD
    run into it, then drives a Claude Code session with the investigation
    prompt. All progress is logged to `events.jsonl` in the task directory.

    Args:
        wandb_path: WandB path of the SPD run.
        task_id: SLURM task ID (1-indexed).
        swarm_id: Unique identifier for this swarm.
        context_length: Context length for prompts.
    """
    # Setup output directory
    swarm_dir = get_swarm_output_dir(swarm_id)
    task_dir = swarm_dir / f"task_{task_id}"
    task_dir.mkdir(parents=True, exist_ok=True)

    events_path = task_dir / "events.jsonl"
    explanations_path = task_dir / "explanations.jsonl"
    db_path = task_dir / "app.db"

    # Initialize empty output files so monitors can tail them immediately.
    explanations_path.touch()

    log_event(
        events_path,
        SwarmEvent(
            event_type="start",
            message=f"Agent {task_id} starting",
            details={"wandb_path": wandb_path, "swarm_id": swarm_id},
        ),
    )

    # Find available port (offset by task_id to reduce collisions)
    port = find_available_port(start_port=8000 + (task_id - 1) * 10)
    logger.info(f"[Task {task_id}] Using port {port}")

    log_event(
        events_path,
        SwarmEvent(
            event_type="progress",
            message=f"Starting backend on port {port}",
            details={"port": port, "db_path": str(db_path)},
        ),
    )

    # Start backend with isolated database
    env = os.environ.copy()
    env["SPD_APP_DB_PATH"] = str(db_path)

    backend_cmd = [
        sys.executable,
        "-m",
        "spd.app.backend.server",
        "--port",
        str(port),
    ]

    # Fix: the backend's output was previously sent to subprocess.PIPE and
    # never read — once the OS pipe buffer filled, the backend would block on
    # write and hang the whole agent. Stream it to a log file instead (which
    # also makes backend failures debuggable after the fact).
    backend_log = open(task_dir / "backend.log", "w")
    backend_proc = subprocess.Popen(
        backend_cmd,
        env=env,
        stdout=backend_log,
        stderr=subprocess.STDOUT,
    )

    # Setup cleanup handler
    def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None:
        _ = frame  # Unused but required by signal handler signature
        logger.info(f"[Task {task_id}] Cleaning up...")
        if backend_proc.poll() is None:
            backend_proc.terminate()
            try:
                backend_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                backend_proc.kill()
        if not backend_log.closed:
            backend_log.close()
        if signum is not None:
            sys.exit(1)

    signal.signal(signal.SIGTERM, cleanup)
    signal.signal(signal.SIGINT, cleanup)

    try:
        # Wait for backend to be ready
        logger.info(f"[Task {task_id}] Waiting for backend...")
        if not wait_for_backend(port):
            log_event(
                events_path,
                SwarmEvent(
                    event_type="error",
                    message="Backend failed to start",
                ),
            )
            raise RuntimeError("Backend failed to start")

        logger.info(f"[Task {task_id}] Backend ready, loading run...")
        log_event(
            events_path,
            SwarmEvent(
                event_type="progress",
                message="Backend ready, loading run",
            ),
        )

        # Load the SPD run
        if not load_run(port, wandb_path, context_length):
            log_event(
                events_path,
                SwarmEvent(
                    event_type="error",
                    message="Failed to load run",
                    details={"wandb_path": wandb_path},
                ),
            )
            raise RuntimeError(f"Failed to load run: {wandb_path}")

        logger.info(f"[Task {task_id}] Run loaded, launching Claude Code...")
        log_event(
            events_path,
            SwarmEvent(
                event_type="progress",
                message="Run loaded, launching Claude Code agent",
            ),
        )

        # Generate agent prompt
        agent_prompt = get_agent_prompt(
            port=port,
            wandb_path=wandb_path,
            task_id=task_id,
            output_dir=str(task_dir),
        )

        # Write prompt to file for reference
        prompt_path = task_dir / "agent_prompt.md"
        prompt_path.write_text(agent_prompt)

        # Launch Claude Code
        # The agent will investigate behaviors and write to the output files
        claude_cmd = [
            "claude",
            "--print",  # Print output to stdout
            "--dangerously-skip-permissions",  # Allow file writes
        ]

        logger.info(f"[Task {task_id}] Starting Claude Code session...")

        claude_proc = subprocess.Popen(
            claude_cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            cwd=str(task_dir),
        )

        # Send the investigation prompt
        investigation_request = f"""
{agent_prompt}

---

Please begin your investigation. Start by checking the backend status and exploring
available component interpretations. Then find an interesting behavior and investigate it.

Remember to log your progress to events.jsonl and write complete explanations to
explanations.jsonl when you discover something.
"""

        # communicate() reads stdout fully, so this pipe cannot deadlock.
        stdout, _ = claude_proc.communicate(input=investigation_request)

        # Save Claude's output
        output_path = task_dir / "claude_output.txt"
        output_path.write_text(stdout or "")

        log_event(
            events_path,
            SwarmEvent(
                event_type="complete",
                message="Investigation complete",
                details={"exit_code": claude_proc.returncode},
            ),
        )

        logger.info(f"[Task {task_id}] Investigation complete")

    except Exception as e:
        log_event(
            events_path,
            SwarmEvent(
                event_type="error",
                message=f"Agent failed: {e}",
                details={"error_type": type(e).__name__},
            ),
        )
        logger.error(f"[Task {task_id}] Failed: {e}")
        raise
    finally:
        cleanup()


def cli() -> None:
    fire.Fire(run_agent)


if __name__ == "__main__":
    cli()
Writes findings to append-only JSONL files +""" + +import secrets +from pathlib import Path + +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import ( + SlurmArrayConfig, + generate_array_script, + submit_slurm_job, +) + + +def get_swarm_output_dir(swarm_id: str) -> Path: + """Get the output directory for a swarm run.""" + return SPD_OUT_DIR / "agent_swarm" / swarm_id + + +def launch_agent_swarm( + wandb_path: str, + n_agents: int, + context_length: int = 128, + partition: str = "h200-reserved", + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a swarm of agents to investigate behaviors. + + Args: + wandb_path: WandB run path for the SPD decomposition. + n_agents: Number of agents to launch. + context_length: Context length for prompts. + partition: SLURM partition. + time: Time limit per agent. + job_suffix: Optional suffix for job names. + """ + swarm_id = f"swarm-{secrets.token_hex(4)}" + output_dir = get_swarm_output_dir(swarm_id) + output_dir.mkdir(parents=True, exist_ok=True) + + snapshot_branch, commit_hash = create_git_snapshot(swarm_id) + logger.info(f"Created git snapshot: {snapshot_branch} ({commit_hash[:8]})") + + suffix = f"-{job_suffix}" if job_suffix else "" + job_name = f"spd-swarm{suffix}" + + # Write swarm metadata + metadata_path = output_dir / "metadata.json" + import json + + metadata = { + "swarm_id": swarm_id, + "wandb_path": wandb_path, + "n_agents": n_agents, + "context_length": context_length, + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + } + metadata_path.write_text(json.dumps(metadata, indent=2)) + + # Build worker commands (SLURM arrays are 1-indexed) + worker_commands = [] + for task_id in range(1, n_agents + 1): + cmd = ( + f"python -m spd.agent_swarm.scripts.run_agent " + f'"{wandb_path}" ' + f"--task_id {task_id} " + f"--swarm_id {swarm_id} " + f"--context_length {context_length}" + ) + 
worker_commands.append(cmd) + + array_config = SlurmArrayConfig( + job_name=job_name, + partition=partition, + n_gpus=1, + time=time, + snapshot_branch=snapshot_branch, + max_concurrent_tasks=min(n_agents, 8), # Respect cluster limits + ) + array_script = generate_array_script(array_config, worker_commands) + array_result = submit_slurm_job( + array_script, + "agent_swarm", + is_array=True, + n_array_tasks=n_agents, + ) + + logger.section("Agent swarm jobs submitted!") + logger.values( + { + "Swarm ID": swarm_id, + "WandB path": wandb_path, + "N agents": n_agents, + "Context length": context_length, + "Output directory": str(output_dir), + "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", + "Job ID": array_result.job_id, + "Logs": array_result.log_pattern, + "Script": str(array_result.script_path), + } + ) + logger.info("") + logger.info("Monitor progress:") + logger.info(f" tail -f {output_dir}/task_*/events.jsonl") + logger.info("") + logger.info("View explanations:") + logger.info(f" cat {output_dir}/task_*/explanations.jsonl | jq .") diff --git a/spd/agent_swarm/scripts/run_slurm_cli.py b/spd/agent_swarm/scripts/run_slurm_cli.py new file mode 100644 index 000000000..20a6d8457 --- /dev/null +++ b/spd/agent_swarm/scripts/run_slurm_cli.py @@ -0,0 +1,62 @@ +"""CLI entry point for agent swarm SLURM launcher. + +Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. 
+ +Usage: + spd-swarm --n_agents 10 + spd-swarm --n_agents 5 --context_length 128 + +Examples: + # Launch 10 agents to investigate a decomposition + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 + + # Launch 5 agents with custom context length and time limit + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 +""" + +import fire + +from spd.settings import DEFAULT_PARTITION_NAME + + +def main( + wandb_path: str, + n_agents: int, + context_length: int = 128, + partition: str = DEFAULT_PARTITION_NAME, + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a swarm of agents to investigate behaviors in an SPD model. + + Each agent runs in its own SLURM job with an isolated app backend instance. + Agents use Claude Code to investigate behaviors and write findings to + append-only JSONL files. + + Args: + wandb_path: WandB run path for the SPD decomposition to investigate. + Format: "entity/project/runs/run_id" or "wandb:entity/project/run_id" + n_agents: Number of agents to launch (each gets 1 GPU). + context_length: Context length for prompts (default 128). + partition: SLURM partition name. + time: Job time limit per agent (default 8 hours). + job_suffix: Optional suffix for SLURM job names. 
+ """ + from spd.agent_swarm.scripts.run_slurm import launch_agent_swarm + + launch_agent_swarm( + wandb_path=wandb_path, + n_agents=n_agents, + context_length=context_length, + partition=partition, + time=time, + job_suffix=job_suffix, + ) + + +def cli() -> None: + fire.Fire(main) + + +if __name__ == "__main__": + cli() diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index e5ee4db59..1ee06ce5a 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -8,6 +8,7 @@ import hashlib import json +import os import sqlite3 from dataclasses import asdict from pathlib import Path @@ -23,8 +24,21 @@ GraphType = Literal["standard", "optimized", "manual"] # Persistent data directories +# Can be overridden via SPD_APP_DB_PATH environment variable for isolation _APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path, respecting SPD_APP_DB_PATH env var.""" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH + + +# For backwards compatibility +DEFAULT_DB_PATH = _DEFAULT_DB_PATH class Run(BaseModel): @@ -107,7 +121,7 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None From 498d459e89f1360464dbdfce4a8681e9ab75f093 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 21:04:25 +0000 Subject: [PATCH 002/102] Stream Claude Code output to file in real-time Previously used communicate() which buffers all output until process completes. 
Now streams directly to claude_output.txt so you can monitor agent activity with: tail -f /claude_output.txt https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 39 +++++++++++++++------------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 83a32752c..072d8a8c6 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -212,8 +212,8 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code - # The agent will investigate behaviors and write to the output files + # Launch Claude Code with output streaming to file + claude_output_path = task_dir / "claude_output.txt" claude_cmd = [ "claude", "--print", # Print output to stdout @@ -221,18 +221,21 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: ] logger.info(f"[Task {task_id}] Starting Claude Code session...") + logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") + + # Open output file for streaming writes + with open(claude_output_path, "w") as output_file: + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=output_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(task_dir), + ) - claude_proc = subprocess.Popen( - claude_cmd, - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT, - text=True, - cwd=str(task_dir), - ) - - # Send the investigation prompt - investigation_request = f""" + # Send the investigation prompt and close stdin + investigation_request = f""" {agent_prompt} --- @@ -243,12 +246,12 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: Remember to log your progress to events.jsonl and write complete explanations to explanations.jsonl when you discover 
something. """ + assert claude_proc.stdin is not None + claude_proc.stdin.write(investigation_request) + claude_proc.stdin.close() - stdout, _ = claude_proc.communicate(input=investigation_request) - - # Save Claude's output - output_path = task_dir / "claude_output.txt" - output_path.write_text(stdout or "") + # Wait for Claude to finish (output streams to file in real-time) + claude_proc.wait() log_event( events_path, From efe5928ebafca30d30d9dcc1b9e40a1412b1fa19 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 21:31:05 +0000 Subject: [PATCH 003/102] Use stream-json output format and add max_turns limit - Switch to --output-format stream-json for structured JSONL output - Add --max-turns parameter (default 50) to prevent runaway agents - Output file changed from claude_output.txt to claude_output.jsonl - Updated monitoring commands in logs to use jq for parsing Monitor with: tail -f task_*/claude_output.jsonl | jq -r '.result // empty' https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 15 ++++++++++----- spd/agent_swarm/scripts/run_slurm.py | 10 +++++++++- spd/agent_swarm/scripts/run_slurm_cli.py | 9 ++++++--- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 072d8a8c6..f048335e3 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -79,6 +79,7 @@ def run_agent( task_id: int, swarm_id: str, context_length: int = 128, + max_turns: int = 50, ) -> None: """Run a single investigation agent. @@ -87,6 +88,7 @@ def run_agent( task_id: SLURM task ID (1-indexed). swarm_id: Unique identifier for this swarm. context_length: Context length for prompts. + max_turns: Maximum agentic turns before stopping (prevents runaway agents). 
""" # Setup output directory swarm_dir = get_swarm_output_dir(swarm_id) @@ -212,16 +214,19 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code with output streaming to file - claude_output_path = task_dir / "claude_output.txt" + # Launch Claude Code with streaming JSON output + claude_output_path = task_dir / "claude_output.jsonl" claude_cmd = [ "claude", - "--print", # Print output to stdout - "--dangerously-skip-permissions", # Allow file writes + "--print", + "--output-format", "stream-json", # Structured JSONL for parsing + "--max-turns", str(max_turns), # Prevent runaway agents + "--dangerously-skip-permissions", ] - logger.info(f"[Task {task_id}] Starting Claude Code session...") + logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") + logger.info(f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'") # Open output file for streaming writes with open(claude_output_path, "w") as output_file: diff --git a/spd/agent_swarm/scripts/run_slurm.py b/spd/agent_swarm/scripts/run_slurm.py index 8b99e253d..f596e1ed9 100644 --- a/spd/agent_swarm/scripts/run_slurm.py +++ b/spd/agent_swarm/scripts/run_slurm.py @@ -31,6 +31,7 @@ def launch_agent_swarm( wandb_path: str, n_agents: int, context_length: int = 128, + max_turns: int = 50, partition: str = "h200-reserved", time: str = "8:00:00", job_suffix: str | None = None, @@ -41,6 +42,7 @@ def launch_agent_swarm( wandb_path: WandB run path for the SPD decomposition. n_agents: Number of agents to launch. context_length: Context length for prompts. + max_turns: Maximum agentic turns per agent (prevents runaway). partition: SLURM partition. time: Time limit per agent. job_suffix: Optional suffix for job names. 
@@ -64,6 +66,7 @@ def launch_agent_swarm( "wandb_path": wandb_path, "n_agents": n_agents, "context_length": context_length, + "max_turns": max_turns, "snapshot_branch": snapshot_branch, "commit_hash": commit_hash, } @@ -77,7 +80,8 @@ def launch_agent_swarm( f'"{wandb_path}" ' f"--task_id {task_id} " f"--swarm_id {swarm_id} " - f"--context_length {context_length}" + f"--context_length {context_length} " + f"--max_turns {max_turns}" ) worker_commands.append(cmd) @@ -104,6 +108,7 @@ def launch_agent_swarm( "WandB path": wandb_path, "N agents": n_agents, "Context length": context_length, + "Max turns": max_turns, "Output directory": str(output_dir), "Snapshot": f"{snapshot_branch} ({commit_hash[:8]})", "Job ID": array_result.job_id, @@ -115,5 +120,8 @@ def launch_agent_swarm( logger.info("Monitor progress:") logger.info(f" tail -f {output_dir}/task_*/events.jsonl") logger.info("") + logger.info("Monitor Claude output (stream-json):") + logger.info(f" tail -f {output_dir}/task_*/claude_output.jsonl | jq -r '.result // empty'") + logger.info("") logger.info("View explanations:") logger.info(f" cat {output_dir}/task_*/explanations.jsonl | jq .") diff --git a/spd/agent_swarm/scripts/run_slurm_cli.py b/spd/agent_swarm/scripts/run_slurm_cli.py index 20a6d8457..9b75ce95f 100644 --- a/spd/agent_swarm/scripts/run_slurm_cli.py +++ b/spd/agent_swarm/scripts/run_slurm_cli.py @@ -4,14 +4,14 @@ Usage: spd-swarm --n_agents 10 - spd-swarm --n_agents 5 --context_length 128 + spd-swarm --n_agents 5 --max_turns 30 Examples: # Launch 10 agents to investigate a decomposition spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 10 - # Launch 5 agents with custom context length and time limit - spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --context_length 64 --time 4:00:00 + # Launch 5 agents with custom settings + spd-swarm goodfire-ai/spd/runs/abc123 --n_agents 5 --max_turns 30 --time 4:00:00 """ import fire @@ -23,6 +23,7 @@ def main( wandb_path: str, n_agents: int, context_length: int 
= 128, + max_turns: int = 50, partition: str = DEFAULT_PARTITION_NAME, time: str = "8:00:00", job_suffix: str | None = None, @@ -38,6 +39,7 @@ def main( Format: "entity/project/runs/run_id" or "wandb:entity/project/run_id" n_agents: Number of agents to launch (each gets 1 GPU). context_length: Context length for prompts (default 128). + max_turns: Maximum agentic turns per agent (default 50, prevents runaway). partition: SLURM partition name. time: Job time limit per agent (default 8 hours). job_suffix: Optional suffix for SLURM job names. @@ -48,6 +50,7 @@ def main( wandb_path=wandb_path, n_agents=n_agents, context_length=context_length, + max_turns=max_turns, partition=partition, time=time, job_suffix=job_suffix, From ef5b0fd80ee91eb9ed5c392ef00581693de964c8 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 30 Jan 2026 22:08:49 +0000 Subject: [PATCH 004/102] Fix stream-json output requiring --verbose flag Claude Code requires --verbose when using --output-format=stream-json with --print mode. 
https://claude.ai/code/session_01UMpYFZ3A98vsPkqoq6zvT6 --- spd/agent_swarm/scripts/run_agent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index f048335e3..b37f4e451 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -219,8 +219,9 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: claude_cmd = [ "claude", "--print", - "--output-format", "stream-json", # Structured JSONL for parsing - "--max-turns", str(max_turns), # Prevent runaway agents + "--verbose", # Required for stream-json output + "--output-format", "stream-json", + "--max-turns", str(max_turns), "--dangerously-skip-permissions", ] From f40f02e443bdd7099fd11d1bdf56915e797ea09f Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:24:50 +0000 Subject: [PATCH 005/102] Add GPU lock to prevent concurrent GPU operations When multiple GPU-intensive requests are made concurrently (graph computation, optimization, intervention), the backend would hang. This adds a lock that returns HTTP 503 immediately if a GPU operation is already in progress, allowing clients to retry later. 
Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/scripts/run_agent.py | 10 ++- spd/app/backend/routers/graphs.py | 61 +++++++++------ spd/app/backend/routers/intervention.py | 100 +++++++++++++----------- spd/app/backend/state.py | 23 ++++++ 4 files changed, 121 insertions(+), 73 deletions(-) diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index b37f4e451..c41c8f30c 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -220,14 +220,18 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: "claude", "--print", "--verbose", # Required for stream-json output - "--output-format", "stream-json", - "--max-turns", str(max_turns), + "--output-format", + "stream-json", + "--max-turns", + str(max_turns), "--dangerously-skip-permissions", ] logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") logger.info(f"[Task {task_id}] Monitor with: tail -f {claude_output_path}") - logger.info(f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'") + logger.info( + f"[Task {task_id}] Parse with: tail -f {claude_output_path} | jq -r '.result // empty'" + ) # Open output file for streaming writes with open(claude_output_path, "w") as output_file: diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index c0f478744..1ea02e5c6 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -240,8 +240,20 @@ def build_out_probs( def stream_computation( work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. 
+ Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: @@ -256,28 +268,31 @@ def compute_thread() -> None: progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] == "progress": + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -456,7 +471,7 @@ def work(on_progress: ProgressCallback) -> GraphData: l0_total=len(filtered_node_ci_vals), ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _edge_to_edge_data(edge: Edge) -> EdgeData: @@ -660,7 +675,7 @@ def 
work(on_progress: ProgressCallback) -> GraphDataWithOptimization: ), ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _add_pseudo_layer_nodes( diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 4c46e136c..a8fcb3fbc 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -148,45 +148,48 @@ def _run_intervention_forward( @router.post("") @log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: +def run_intervention( + request: InterventionRequest, loaded: DepLoadedRun, manager: DepStateManager +) -> InterventionResponse: """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [(n.layer, n.seq_pos, n.component_idx) for n in request.nodes] - - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) + with manager.gpu_lock(): + token_ids = loaded.tokenizer.encode(request.text, add_special_tokens=False) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [(n.layer, n.seq_pos, n.component_idx) for n in request.nodes] + + seq_len = tokens.shape[1] + for _, seq_pos, _ in active_nodes: + if seq_pos >= seq_len: + raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") + + result = compute_intervention_forward( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + top_k=request.top_k, + tokenizer=loaded.tokenizer, + ) - 
predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + predictions_per_position = [ + [ + TokenPrediction( + token=token, + token_id=token_id, + spd_prob=spd_prob, + target_prob=target_prob, + logit=logit, + target_logit=target_logit, + ) + for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions + ] + for pos_predictions in result.predictions_per_position ] - for pos_predictions in result.predictions_per_position - ] - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return InterventionResponse( + input_tokens=result.input_tokens, + predictions_per_position=predictions_per_position, + ) @router.post("/run") @@ -195,14 +198,16 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=request.text, + selected_nodes=request.selected_nodes, + top_k=request.top_k, + loaded=loaded, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, @@ -310,12 +315,13 @@ def fork_intervention_run( modified_text = loaded.tokenizer.decode(modified_token_ids) # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + response = _run_intervention_forward( + text=modified_text, + 
selected_nodes=parent_run.selected_nodes, + top_k=request.top_k, + loaded=loaded, + ) # Save the forked run fork_id = db.save_forked_intervention_run( diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 47dacfe51..7364ff1d1 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -5,9 +5,13 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException from transformers.tokenization_utils_base import PreTrainedTokenizerBase from spd.app.backend.database import PromptAttrDB, Run @@ -147,6 +151,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -189,3 +194,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() From 567fb198c938513c2a348c61dacc97f9435e45b8 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:38:11 +0000 Subject: [PATCH 006/102] Add research_log.md for human-readable agent progress Agents now create and update a research_log.md file with readable progress updates. This makes it easy to follow what the agent is doing and discovering without parsing JSONL files. 
Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/CLAUDE.md | 11 ++++- spd/agent_swarm/agent_prompt.py | 70 +++++++++++++++++++++++----- spd/agent_swarm/scripts/run_agent.py | 14 ++++-- 3 files changed, 77 insertions(+), 18 deletions(-) diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md index ee2e89be2..ee4a57db4 100644 --- a/spd/agent_swarm/CLAUDE.md +++ b/spd/agent_swarm/CLAUDE.md @@ -42,11 +42,12 @@ spd/agent_swarm/ SPD_OUT_DIR/agent_swarm// ├── metadata.json # Swarm configuration ├── task_1/ -│ ├── events.jsonl # Progress and observations +│ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) +│ ├── events.jsonl # Structured progress and observations │ ├── explanations.jsonl # Complete behavior explanations │ ├── app.db # Isolated SQLite database │ ├── agent_prompt.md # The prompt given to the agent -│ └── claude_output.txt # Raw Claude Code output +│ └── claude_output.jsonl # Raw Claude Code output (stream-json format) ├── task_2/ │ └── ... └── task_N/ @@ -90,6 +91,12 @@ This prevents conflicts when multiple agents run on the same machine. ## Monitoring ```bash +# Watch research logs (best way to follow agent progress) +tail -f SPD_OUT_DIR/agent_swarm//task_*/research_log.md + +# Watch a specific agent's research log +cat SPD_OUT_DIR/agent_swarm//task_1/research_log.md + # Watch events from all agents tail -f SPD_OUT_DIR/agent_swarm//task_*/events.jsonl diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index e3cf4fd21..8ee02d6dd 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -254,10 +254,50 @@ ## Output Format -Write your findings by appending to the output files: +Write your findings to the output files. **The research log is your primary output for humans to read.** + +### research_log.md (MOST IMPORTANT - Write here frequently!) +This is a human-readable log of your investigation. Write here often so someone can follow your progress. 
+Use clear markdown formatting: + +```markdown +## [HH:MM] Starting Investigation + +Looking at component interpretations to find interesting patterns... + +## [HH:MM] Hypothesis: Gendered Pronoun Circuit + +Found components that seem related to pronouns: +- h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" +- h.0.mlp.c_fc:89 - "she/her pronouns after female subjects" + +Testing with prompt: "The boy said that he" + +## [HH:MM] Optimization Results + +Ran optimization for "he" prediction at position 4: +- Found 15 active components +- Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) + +## [HH:MM] Ablation Test + +Ablating h.0.mlp.c_fc:42: +- Before: P(he)=0.82, P(she)=0.11 +- After: P(he)=0.23, P(she)=0.45 + +This confirms the component is important for masculine pronoun prediction! + +## [HH:MM] Conclusion + +Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and +h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. +``` + +**IMPORTANT**: Update the research log every few minutes with your current progress, +findings, and next steps. This is how humans monitor your work! ### events.jsonl -Log progress and observations: +Log structured progress and observations: ```json {{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}} ``` @@ -284,15 +324,20 @@ ## Getting Started -1. Check the current status: `curl http://localhost:{port}/api/status` -2. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` -3. Search for interesting prompts or create your own -4. Optimize a sparse circuit for a behavior you find -5. Investigate the components involved -6. Test hypotheses with ablations -7. Document your findings +1. **Create your research log**: Start by creating `research_log.md` with a header +2. 
Check the current status: `curl http://localhost:{port}/api/status` +3. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` +4. Search for interesting prompts or create your own +5. **Update research_log.md** with what you're investigating +6. Optimize a sparse circuit for a behavior you find +7. Investigate the components involved +8. Test hypotheses with ablations +9. **Update research_log.md** with findings +10. Document complete explanations in `explanations.jsonl` + +**Remember to update research_log.md frequently** - this is how humans follow your progress! -Remember: You are exploring! Not every investigation will lead to a clear explanation. +You are exploring! Not every investigation will lead to a clear explanation. Document what you learn, even if it's "this was more complicated than expected." Good luck, and happy investigating! @@ -322,9 +367,10 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) - **Output Directory**: {output_dir} Your output files: -- `{output_dir}/events.jsonl` - Log events and observations here +- `{output_dir}/research_log.md` - **PRIMARY OUTPUT** - Write readable progress updates here frequently! +- `{output_dir}/events.jsonl` - Log structured events and observations here - `{output_dir}/explanations.jsonl` - Write complete explanations here -To append to these files, use the Write tool or shell redirection. +**Start by creating research_log.md with a header, then update it every few minutes!** """ return prompt + runtime_context diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index c41c8f30c..3c5a78449 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -250,11 +250,17 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: --- -Please begin your investigation. 
Start by checking the backend status and exploring -available component interpretations. Then find an interesting behavior and investigate it. +Please begin your investigation: -Remember to log your progress to events.jsonl and write complete explanations to -explanations.jsonl when you discover something. +1. **FIRST**: Create `{task_dir}/research_log.md` with a header like "# Research Log - Task {task_id}" +2. Check the backend status and explore component interpretations +3. Find an interesting behavior to investigate +4. **Update research_log.md frequently** with your progress, findings, and next steps + +Remember: +- research_log.md is your PRIMARY output - humans will read this to follow your work +- Update it every few minutes with what you're doing and discovering +- Write complete explanations to explanations.jsonl when you finish investigating a behavior """ assert claude_proc.stdin is not None claude_proc.stdin.write(investigation_request) From 4c4a843b2bb6a9b6e7fb67508e6a6c63960494fc Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Fri, 30 Jan 2026 23:56:57 +0000 Subject: [PATCH 007/102] Add full timestamps to research log examples Show YYYY-MM-DD HH:MM:SS format and provide tip for getting timestamps. Co-Authored-By: Claude Opus 4.5 --- spd/agent_swarm/agent_prompt.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index 8ee02d6dd..b469b47b5 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -261,11 +261,11 @@ Use clear markdown formatting: ```markdown -## [HH:MM] Starting Investigation +## [2026-01-30 14:23:15] Starting Investigation Looking at component interpretations to find interesting patterns... 
-## [HH:MM] Hypothesis: Gendered Pronoun Circuit +## [2026-01-30 14:25:42] Hypothesis: Gendered Pronoun Circuit Found components that seem related to pronouns: - h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" @@ -273,13 +273,13 @@ Testing with prompt: "The boy said that he" -## [HH:MM] Optimization Results +## [2026-01-30 14:28:03] Optimization Results Ran optimization for "he" prediction at position 4: - Found 15 active components - Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) -## [HH:MM] Ablation Test +## [2026-01-30 14:31:17] Ablation Test Ablating h.0.mlp.c_fc:42: - Before: P(he)=0.82, P(she)=0.11 @@ -287,12 +287,14 @@ This confirms the component is important for masculine pronoun prediction! -## [HH:MM] Conclusion +## [2026-01-30 14:35:44] Conclusion Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. ``` +**TIP**: Get the current timestamp with `date '+%Y-%m-%d %H:%M:%S'` for your log entries. + **IMPORTANT**: Update the research log every few minutes with your current progress, findings, and next steps. This is how humans monitor your work! 
From cb6e6f063808af3f46f976f41531dd9ff92b4351 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Sat, 31 Jan 2026 19:08:07 +0000 Subject: [PATCH 008/102] wip: Integrate agent swarm with MCP for Claude Code tool access --- spd/agent_swarm/CLAUDE.md | 69 +- spd/agent_swarm/agent_prompt.py | 369 ++---- spd/agent_swarm/scripts/run_agent.py | 52 +- spd/app/CLAUDE.md | 3 +- spd/app/backend/routers/__init__.py | 4 + spd/app/backend/routers/investigations.py | 262 ++++ spd/app/backend/routers/mcp.py | 1171 +++++++++++++++++ spd/app/backend/server.py | 29 + .../src/components/InvestigationsTab.svelte | 497 +++++++ .../frontend/src/components/RunView.svelte | 15 +- spd/app/frontend/src/lib/api/index.ts | 1 + .../frontend/src/lib/api/investigations.ts | 55 + 12 files changed, 2211 insertions(+), 316 deletions(-) create mode 100644 spd/app/backend/routers/investigations.py create mode 100644 spd/app/backend/routers/mcp.py create mode 100644 spd/app/frontend/src/components/InvestigationsTab.svelte create mode 100644 spd/app/frontend/src/lib/api/investigations.ts diff --git a/spd/agent_swarm/CLAUDE.md b/spd/agent_swarm/CLAUDE.md index ee4a57db4..48fabc504 100644 --- a/spd/agent_swarm/CLAUDE.md +++ b/spd/agent_swarm/CLAUDE.md @@ -7,9 +7,10 @@ that investigate behaviors in SPD model decompositions. The agent swarm system allows you to: 1. Launch many parallel agents (each as a SLURM job with 1 GPU) -2. Each agent runs an isolated app backend instance -3. Agents investigate behaviors using the SPD app API -4. Findings are written to append-only JSONL files +2. Each agent runs an isolated app backend instance with MCP support +3. Agents investigate behaviors using SPD tools via MCP (Model Context Protocol) +4. Progress is streamed in real-time via MCP SSE events +5. 
Findings are written to append-only JSONL files ## Usage @@ -36,22 +37,55 @@ spd/agent_swarm/ └── run_agent.py # Worker script (runs in each SLURM job) ``` +## MCP Tools + +Agents access ALL SPD functionality via MCP (Model Context Protocol). The backend exposes +these tools at `/mcp`. Agents don't need file system access - everything is done through MCP. + +**Analysis Tools:** + +| Tool | Description | +|------|-------------| +| `optimize_graph` | Find minimal circuit for a behavior (streams progress) | +| `get_component_info` | Get component interpretation, token stats, correlations | +| `run_ablation` | Test circuit by running with selected components only | +| `search_dataset` | Search SimpleStories training data for patterns | +| `create_prompt` | Tokenize text and get next-token probabilities | + +**Output Tools:** + +| Tool | Description | +|------|-------------| +| `update_research_log` | Append content to the agent's research log (PRIMARY OUTPUT) | +| `save_explanation` | Save a complete, validated behavior explanation | +| `set_investigation_summary` | Set title and summary shown in the investigations UI | +| `submit_suggestion` | Submit ideas for improving the tools or system | + +The `optimize_graph` tool streams progress events via SSE, giving real-time visibility +into long-running optimization operations. + +Suggestions from all agents are collected in `SPD_OUT_DIR/agent_swarm/suggestions.jsonl` (global file). + ## Output Structure ``` -SPD_OUT_DIR/agent_swarm// -├── metadata.json # Swarm configuration -├── task_1/ -│ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) -│ ├── events.jsonl # Structured progress and observations -│ ├── explanations.jsonl # Complete behavior explanations -│ ├── app.db # Isolated SQLite database -│ ├── agent_prompt.md # The prompt given to the agent -│ └── claude_output.jsonl # Raw Claude Code output (stream-json format) -├── task_2/ -│ └── ... -└── task_N/ - └── ... 
+SPD_OUT_DIR/agent_swarm/ +├── suggestions.jsonl # System improvement suggestions from ALL agents (global) +└── / + ├── metadata.json # Swarm configuration + ├── task_1/ + │ ├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) + │ ├── events.jsonl # Structured progress and observations + │ ├── explanations.jsonl # Complete behavior explanations + │ ├── summary.json # Agent-provided title and summary for UI + │ ├── app.db # Isolated SQLite database + │ ├── agent_prompt.md # The prompt given to the agent + │ ├── mcp_config.json # MCP server configuration for Claude Code + │ └── claude_output.jsonl # Raw Claude Code output (stream-json format) + ├── task_2/ + │ └── ... + └── task_N/ + └── ... ``` ## Key Files @@ -103,6 +137,9 @@ tail -f SPD_OUT_DIR/agent_swarm//task_*/events.jsonl # View all explanations cat SPD_OUT_DIR/agent_swarm//task_*/explanations.jsonl | jq . +# View agent suggestions for system improvement (global file) +cat SPD_OUT_DIR/agent_swarm/suggestions.jsonl | jq . + # Check SLURM job status squeue --me diff --git a/spd/agent_swarm/agent_prompt.py b/spd/agent_swarm/agent_prompt.py index b469b47b5..44424c190 100644 --- a/spd/agent_swarm/agent_prompt.py +++ b/spd/agent_swarm/agent_prompt.py @@ -1,8 +1,7 @@ """System prompt for SPD investigation agents. This module contains the detailed instructions given to each agent in the swarm. -The prompt explains how to use the SPD app API and the scientific methodology -for investigating model behaviors. +The agent has access to SPD tools via MCP - tools are self-documenting. """ AGENT_SYSTEM_PROMPT = """ @@ -10,7 +9,7 @@ You are a research agent investigating behaviors in a neural network model decomposition. Your goal is to find interesting behaviors, understand how components interact to produce -them, and document your findings as explanations. +them, and document your findings. ## Your Mission @@ -23,326 +22,126 @@ 2. 
**Understand the mechanism**: Figure out which components are involved and how they work together to produce the behavior -3. **Document your findings**: Write a clear explanation with supporting evidence +3. **Document your findings**: Write clear explanations with supporting evidence -## The SPD App Backend +## Available Tools (via MCP) -You have access to an SPD (Stochastic Parameter Decomposition) app backend running at: -`http://localhost:{port}` +You have access to SPD analysis tools. Use them directly - they have full documentation. -This app provides APIs for: -- Loading decomposed models -- Computing attribution graphs showing how components interact -- Optimizing sparse circuits for specific behaviors -- Running interventions (ablations) to test hypotheses -- Viewing component interpretations and correlations -- Searching the training dataset +**Analysis Tools:** +- **optimize_graph**: Find the minimal circuit for a behavior (e.g., "boy" → "he") +- **get_component_info**: Get interpretation and token stats for a component +- **run_ablation**: Test a circuit by running with only selected components +- **search_dataset**: Find examples in the training data +- **create_prompt**: Tokenize text for analysis -## API Reference - -### Health Check -```bash -curl http://localhost:{port}/api/health -# Returns: {{"status": "ok"}} -``` - -### Load a Run (ALREADY DONE FOR YOU) -The run is pre-loaded. 
Check status with: -```bash -curl http://localhost:{port}/api/status -``` - -### Create a Custom Prompt -To analyze a specific prompt: -```bash -curl -X POST "http://localhost:{port}/api/prompts/custom?text=The%20boy%20ate%20his" -# Returns: {{"id": , "token_ids": [...], "tokens": [...], "preview": "...", "next_token_probs": [...]}} -``` - -### Compute Optimized Attribution Graph (MOST IMPORTANT) -This optimizes a sparse circuit that achieves a behavior: -```bash -curl -X POST "http://localhost:{port}/api/graphs/optimized/stream?prompt_id=&loss_type=ce&loss_position=&label_token=&steps=100&imp_min_coeff=0.1&pnorm=0.5&mask_type=hard&loss_coeff=1.0&ci_threshold=0.01&normalize=target" -# Streams SSE events, final event has type="complete" with graph data -``` - -Parameters: -- `prompt_id`: ID from creating custom prompt -- `loss_type`: "ce" for cross-entropy (predicting specific token) or "kl" (matching full distribution) -- `loss_position`: Token position to optimize (0-indexed, usually last position) -- `label_token`: Token ID to predict (for CE loss) -- `steps`: Optimization steps (50-200 typical) -- `imp_min_coeff`: Importance minimization coefficient (0.05-0.3) -- `pnorm`: P-norm for sparsity (0.3-1.0, lower = sparser) -- `mask_type`: "hard" for binary masks, "soft" for continuous -- `ci_threshold`: Threshold for including nodes in graph (0.01-0.1) -- `normalize`: "target" normalizes by target layer, "none" for raw values - -### Get Component Interpretations -```bash -curl "http://localhost:{port}/api/correlations/interpretations" -# Returns: {{"h.0.mlp.c_fc:5": {{"label": "...", "confidence": "high"}}, ...}} -``` - -Get full interpretation details: -```bash -curl "http://localhost:{port}/api/correlations/interpretations/h.0.mlp.c_fc/5" -# Returns: {{"reasoning": "...", "prompt": "..."}} -``` - -### Get Component Token Statistics -```bash -curl "http://localhost:{port}/api/correlations/token_stats/h.0.mlp.c_fc/5?top_k=20" -# Returns input/output token 
associations -``` - -### Get Component Correlations -```bash -curl "http://localhost:{port}/api/correlations/components/h.0.mlp.c_fc/5?top_k=20" -# Returns components that frequently co-activate -``` - -### Run Intervention (Ablation) -Test a hypothesis by running the model with only selected components active: -```bash -curl -X POST "http://localhost:{port}/api/intervention/run" \\ - -H "Content-Type: application/json" \\ - -d '{{"graph_id": , "text": "The boy ate his", "selected_nodes": ["h.0.mlp.c_fc:3:5", "h.1.attn.o_proj:3:10"], "top_k": 10}}' -# Returns predictions with only selected components active vs full model -``` - -Node format: "layer:seq_pos:component_idx" -- `layer`: e.g., "h.0.mlp.c_fc", "h.1.attn.o_proj" -- `seq_pos`: Position in sequence (0-indexed) -- `component_idx`: Component index within layer - -### Search Dataset -Find prompts with specific patterns: -```bash -curl -X POST "http://localhost:{port}/api/dataset/search?query=she%20said&split=train" -curl "http://localhost:{port}/api/dataset/results?page=1&page_size=20" -``` - -### Get Random Samples with Loss -Find high/low loss examples: -```bash -curl "http://localhost:{port}/api/dataset/random_with_loss?n_samples=20&seed=42" -``` - -### Probe Component Activation -See how a component responds to arbitrary text: -```bash -curl -X POST "http://localhost:{port}/api/activation_contexts/probe" \\ - -H "Content-Type: application/json" \\ - -d '{{"text": "The boy ate his", "layer": "h.0.mlp.c_fc", "component_idx": 5}}' -# Returns CI values and activations at each position -``` - -### Get Dataset Attributions -See which components influence each other across the training data: -```bash -curl "http://localhost:{port}/api/dataset_attributions/h.0.mlp.c_fc/5?k=10" -# Returns positive/negative sources and targets -``` +**Output Tools:** +- **update_research_log**: Append to your research log (PRIMARY OUTPUT - use frequently!) 
+- **save_explanation**: Save a complete, validated behavior explanation +- **set_investigation_summary**: Set a title and summary for your investigation (shown in UI) +- **submit_suggestion**: Submit ideas for improving the tools or system ## Investigation Methodology ### Step 1: Find an Interesting Behavior -Start by exploring the model's behavior: - -1. **Search for patterns**: Use `/api/dataset/search` to find prompts with specific - linguistic patterns (pronouns, verb conjugations, completions, etc.) - -2. **Look at high-loss examples**: Use `/api/dataset/random_with_loss` to find where - the model struggles or succeeds - -3. **Create test prompts**: Use `/api/prompts/custom` to create prompts that test - specific capabilities - -Good behaviors to investigate: -- Gendered pronoun prediction ("The doctor said she" vs "The doctor said he") -- Subject-verb agreement ("The cats are" vs "The cat is") -- Pattern completion ("1, 2, 3," → "4") -- Semantic associations ("The capital of France is" → "Paris") -- Grammatical structure (completing sentences correctly) +Start by exploring: +- Search for linguistic patterns: pronouns, verb agreement, completions +- Create test prompts that show clear model behavior +- Good targets: gendered pronouns, subject-verb agreement, semantic associations ### Step 2: Optimize a Sparse Circuit Once you have a behavior: - -1. **Create the prompt** via `/api/prompts/custom` - -2. **Identify the target token**: What token should be predicted? Get its ID from - the tokenizer or from the prompt creation response. - -3. **Run optimization** via `/api/graphs/optimized/stream`: - - Use `loss_type=ce` with the target token - - Set `loss_position` to the position where prediction matters - - Start with `imp_min_coeff=0.1`, `pnorm=0.5`, `steps=100` - - Use `ci_threshold=0.01` to see active components - -4. 
**Examine the graph**: The response shows: - - `nodeCiVals`: Which components are active (high CI = important) - - `edges`: How components connect (gradient flow) - - `outputProbs`: Model predictions +1. Use `optimize_graph` with your prompt and target token +2. Examine which components have high CI values +3. Note the circuit size (fewer = cleaner mechanism) ### Step 3: Understand Component Roles -For each important component in the graph: - -1. **Check the interpretation**: Use `/api/correlations/interpretations//` - to see if we already have an idea what this component does - -2. **Look at token stats**: Use `/api/correlations/token_stats//` to see - what tokens activate this component (input) and what it predicts (output) - -3. **Check correlations**: Use `/api/correlations/components//` to see - what other components co-activate - -4. **Probe on variations**: Use `/api/activation_contexts/probe` to see how the - component responds to related prompts +For each important component: +1. Use `get_component_info` to see its interpretation and token stats +2. Look at what tokens activate it (input) and what it predicts (output) +3. Check correlated components ### Step 4: Test with Ablations Form hypotheses and test them: - -1. **Hypothesis**: "Component X stores information about gender" - -2. **Test**: Run intervention with and without component X - - If prediction changes as expected → supports hypothesis - - If no change → component may not be necessary for this - - If unexpected change → revise hypothesis - -3. **Control**: Try ablating other components to ensure specificity +1. Use `run_ablation` with the circuit's components +2. Verify predictions match expectations +3. 
Try removing individual components to find critical ones ### Step 5: Document Your Findings -Write a `BehaviorExplanation` with: -- Clear subject prompt -- Description of the behavior -- Components and their roles -- How they work together -- Supporting evidence from ablations/attributions -- Confidence level -- Alternative hypotheses you considered -- Limitations +Use `update_research_log` frequently - this is how humans monitor your work! +When you complete an investigation, use `save_explanation` to create a structured record. ## Scientific Principles -### Be Epistemologically Humble -- Your first hypothesis is probably wrong or incomplete -- Always consider alternative explanations -- A single confirming example doesn't prove a theory -- Look for disconfirming evidence - -### Be Bayesian -- Start with priors from component interpretations -- Update beliefs based on evidence -- Consider the probability of the evidence under different hypotheses -- Don't anchor too strongly on initial observations - -### Triangulate Evidence -- Don't rely on a single type of evidence -- Ablation results + attribution patterns + token stats together are stronger -- Look for convergent evidence from multiple sources - -### Document Uncertainty -- Be explicit about what you're confident in vs. uncertain about -- Note when evidence is weak or ambiguous -- Identify what additional tests would strengthen the explanation +- **Be skeptical**: Your first hypothesis is probably incomplete +- **Triangulate**: Don't rely on a single type of evidence +- **Document uncertainty**: Note what you're confident in vs. uncertain about +- **Consider alternatives**: What else could explain the behavior? ## Output Format -Write your findings to the output files. **The research log is your primary output for humans to read.** +### Research Log (PRIMARY OUTPUT - Update frequently!) -### research_log.md (MOST IMPORTANT - Write here frequently!) -This is a human-readable log of your investigation. 
Write here often so someone can follow your progress. -Use clear markdown formatting: +Use `update_research_log` with markdown content. Call it every few minutes to show progress: -```markdown -## [2026-01-30 14:23:15] Starting Investigation - -Looking at component interpretations to find interesting patterns... - -## [2026-01-30 14:25:42] Hypothesis: Gendered Pronoun Circuit - -Found components that seem related to pronouns: -- h.0.mlp.c_fc:42 - "he/his pronouns after male subjects" -- h.0.mlp.c_fc:89 - "she/her pronouns after female subjects" - -Testing with prompt: "The boy said that he" - -## [2026-01-30 14:28:03] Optimization Results - -Ran optimization for "he" prediction at position 4: -- Found 15 active components -- Key components: h.0.mlp.c_fc:42 (CI=0.92), h.1.attn.o_proj:156 (CI=0.78) +Example calls: +``` +update_research_log("# Research Log - Task 1\n\nStarting investigation...\n\n") -## [2026-01-30 14:31:17] Ablation Test +update_research_log("## [14:25:42] Hypothesis: Gendered Pronoun Circuit\n\nTesting prompt: 'The boy said that' → expecting ' he'\n\nUsed optimize_graph - found 15 active components:\n- h.0.mlp.c_fc:407 (CI=0.95) - 'male subjects'\n- h.3.attn.o_proj:262 (CI=0.92) - 'masculine pronouns'\n\n") -Ablating h.0.mlp.c_fc:42: -- Before: P(he)=0.82, P(she)=0.11 -- After: P(he)=0.23, P(she)=0.45 +update_research_log("## [14:28:03] Ablation Test\n\nResult: P(he) = 0.89 (vs 0.22 baseline)\n\nThis confirms the circuit is sufficient!\n\n") +``` -This confirms the component is important for masculine pronoun prediction! +### Saving Explanations -## [2026-01-30 14:35:44] Conclusion +When you have a complete explanation, use `save_explanation`: -Found a circuit for gendered pronoun prediction. Components h.0.mlp.c_fc:42 and -h.1.attn.o_proj:156 work together to predict masculine pronouns after male subjects. 
+``` +save_explanation( + subject_prompt="The boy said that", + behavior_description="Predicts masculine pronoun 'he' after male subject", + components_involved=[ + {{"component_key": "h.0.mlp.c_fc:407", "role": "Male subject detector"}}, + {{"component_key": "h.3.attn.o_proj:262", "role": "Masculine pronoun promoter"}} + ], + explanation="Component h.0.mlp.c_fc:407 activates on male subjects...", + confidence="medium", + limitations=["Only tested on simple sentences"] +) ``` -**TIP**: Get the current timestamp with `date '+%Y-%m-%d %H:%M:%S'` for your log entries. +### Submitting Suggestions -**IMPORTANT**: Update the research log every few minutes with your current progress, -findings, and next steps. This is how humans monitor your work! +If you have ideas for improving the system, use `submit_suggestion`: -### events.jsonl -Log structured progress and observations: -```json -{{"event_type": "observation", "message": "Component h.0.mlp.c_fc:5 has high CI when subject is male", "details": {{"ci_value": 0.85}}, "timestamp": "..."}} ``` - -### explanations.jsonl -When you have a complete explanation: -```json -{{ - "subject_prompt": "The boy ate his lunch", - "behavior_description": "Correctly predicts gendered pronoun 'his' after male subject", - "components_involved": [ - {{"component_key": "h.0.mlp.c_fc:5", "role": "Encodes subject gender as male", "interpretation": "male names/subjects"}}, - {{"component_key": "h.1.attn.o_proj:10", "role": "Transmits gender information to output", "interpretation": null}} - ], - "explanation": "Component h.0.mlp.c_fc:5 activates on male subjects and stores gender information...", - "supporting_evidence": [ - {{"evidence_type": "ablation", "description": "Removing component causes prediction to change from 'his' to 'her'", "details": {{"without_component": {{"his": 0.1, "her": 0.6}}, "with_component": {{"his": 0.8, "her": 0.1}}}}}} - ], - "confidence": "medium", - "alternative_hypotheses": ["Component might encode broader 
concept of masculine entities, not just humans"], - "limitations": ["Only tested on simple subject-pronoun sentences"] -}} +submit_suggestion( + category="tool_improvement", + title="Add batch ablation support", + description="It would be faster to test multiple ablations at once...", + context="I was testing 10 different component subsets one at a time" +) ``` ## Getting Started -1. **Create your research log**: Start by creating `research_log.md` with a header -2. Check the current status: `curl http://localhost:{port}/api/status` -3. Explore available interpretations: `curl http://localhost:{port}/api/correlations/interpretations` -4. Search for interesting prompts or create your own -5. **Update research_log.md** with what you're investigating -6. Optimize a sparse circuit for a behavior you find -7. Investigate the components involved -8. Test hypotheses with ablations -9. **Update research_log.md** with findings -10. Document complete explanations in `explanations.jsonl` - -**Remember to update research_log.md frequently** - this is how humans follow your progress! +1. **Create your research log** with `update_research_log("# Research Log - Task N\n\n...")` +2. Use analysis tools to explore the model +3. Find an interesting behavior to investigate +4. **Call `update_research_log` frequently** - humans are watching! +5. Use `save_explanation` for complete findings +6. **Call `set_investigation_summary`** with a title and summary when done (or periodically for updates) You are exploring! Not every investigation will lead to a clear explanation. Document what you learn, even if it's "this was more complicated than expected." -Good luck, and happy investigating! +Good luck! """ @@ -350,7 +149,7 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) """Generate the full agent prompt with runtime parameters filled in. Args: - port: The port the backend is running on. 
+ port: The port the backend is running on (for reference, tools use MCP). wandb_path: The WandB path of the loaded run. task_id: The SLURM task ID for this agent. output_dir: Path to the agent's output directory. @@ -358,21 +157,19 @@ def get_agent_prompt(port: int, wandb_path: str, task_id: int, output_dir: str) Returns: The complete agent prompt with parameters substituted. """ - prompt = AGENT_SYSTEM_PROMPT.format(port=port) - runtime_context = f""" ## Runtime Context -- **Backend URL**: http://localhost:{port} -- **Loaded Run**: {wandb_path} +- **Model Run**: {wandb_path} - **Task ID**: {task_id} -- **Output Directory**: {output_dir} -Your output files: -- `{output_dir}/research_log.md` - **PRIMARY OUTPUT** - Write readable progress updates here frequently! -- `{output_dir}/events.jsonl` - Log structured events and observations here -- `{output_dir}/explanations.jsonl` - Write complete explanations here +Use the MCP tools for ALL output: +- `update_research_log` → **PRIMARY OUTPUT** - Update frequently with your progress! +- `save_explanation` → Save complete, validated behavior explanations +- `submit_suggestion` → Share ideas for improving the system -**Start by creating research_log.md with a header, then update it every few minutes!** +**Start by calling update_research_log to create your log, then investigate!** """ - return prompt + runtime_context + # Note: output_dir and port are available but agents shouldn't need them + _ = output_dir, port + return AGENT_SYSTEM_PROMPT + runtime_context diff --git a/spd/agent_swarm/scripts/run_agent.py b/spd/agent_swarm/scripts/run_agent.py index 3c5a78449..627b6d473 100644 --- a/spd/agent_swarm/scripts/run_agent.py +++ b/spd/agent_swarm/scripts/run_agent.py @@ -4,10 +4,12 @@ 1. Creates an isolated output directory for this agent 2. Starts the app backend with an isolated database 3. Loads the SPD run -4. Launches Claude Code with investigation instructions -5. Handles cleanup on exit +4. 
Configures MCP server for Claude Code +5. Launches Claude Code with investigation instructions +6. Handles cleanup on exit """ +import json import os import signal import socket @@ -26,6 +28,21 @@ from spd.log import logger +def write_mcp_config(task_dir: Path, port: int) -> Path: + """Write MCP configuration file for Claude Code.""" + mcp_config = { + "mcpServers": { + "spd": { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + } + } + config_path = task_dir / "mcp_config.json" + config_path.write_text(json.dumps(mcp_config, indent=2)) + return config_path + + def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: """Find an available port starting from start_port.""" for offset in range(max_attempts): @@ -124,9 +141,13 @@ def run_agent( ), ) - # Start backend with isolated database + # Start backend with isolated database and swarm configuration env = os.environ.copy() env["SPD_APP_DB_PATH"] = str(db_path) + env["SPD_MCP_EVENTS_PATH"] = str(events_path) + env["SPD_MCP_TASK_DIR"] = str(task_dir) + # Suggestions go to a global file (one level above swarm dirs) + env["SPD_MCP_SUGGESTIONS_PATH"] = str(swarm_dir.parent / "suggestions.jsonl") backend_cmd = [ sys.executable, @@ -214,7 +235,12 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: prompt_path = task_dir / "agent_prompt.md" prompt_path.write_text(agent_prompt) - # Launch Claude Code with streaming JSON output + # Write MCP config for Claude Code + mcp_config_path = write_mcp_config(task_dir, port) + logger.info(f"[Task {task_id}] MCP config written to {mcp_config_path}") + + # Launch Claude Code with streaming JSON output and MCP + # No --dangerously-skip-permissions needed - agents use MCP tools for all I/O claude_output_path = task_dir / "claude_output.jsonl" claude_cmd = [ "claude", @@ -224,7 +250,8 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: "stream-json", "--max-turns", str(max_turns), - 
"--dangerously-skip-permissions", + "--mcp-config", + str(mcp_config_path), ] logger.info(f"[Task {task_id}] Starting Claude Code (max_turns={max_turns})...") @@ -252,15 +279,16 @@ def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: Please begin your investigation: -1. **FIRST**: Create `{task_dir}/research_log.md` with a header like "# Research Log - Task {task_id}" -2. Check the backend status and explore component interpretations -3. Find an interesting behavior to investigate -4. **Update research_log.md frequently** with your progress, findings, and next steps +1. **FIRST**: Use the `update_research_log` tool to create your research log with a header like: + "# Research Log - Task {task_id}\\n\\nStarting investigation of {wandb_path}\\n\\n" +2. Explore component interpretations using `get_component_info` +3. Find an interesting behavior to investigate with `optimize_graph` +4. **Use `update_research_log` frequently** to document your progress, findings, and next steps Remember: -- research_log.md is your PRIMARY output - humans will read this to follow your work -- Update it every few minutes with what you're doing and discovering -- Write complete explanations to explanations.jsonl when you finish investigating a behavior +- The research log is your PRIMARY output - use `update_research_log` every few minutes +- Use `save_explanation` to record complete, validated explanations +- Use `submit_suggestion` if you have ideas for improving the tools or system """ assert claude_proc.stdin is not None claude_proc.stdin.write(investigation_request) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 95d6bc1b3..c1eed4f49 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -48,7 +48,8 @@ backend/ ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering ├── dataset_search.py # SimpleStories dataset search - └── agents.py # Various useful endpoints that AI agents should look 
at when helping + ├── agents.py # Various useful endpoints that AI agents should look at when helping + └── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ``` Note: Activation contexts, correlations, and token stats are now loaded from pre-harvested data (see `spd/harvest/`). The app no longer computes these on-the-fly. diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index 79cea1087..83f6e8ac2 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -9,6 +9,8 @@ from spd.app.backend.routers.dataset_search import router as dataset_search_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.prompts import router as prompts_router from spd.app.backend.routers.runs import router as runs_router @@ -22,6 +24,8 @@ "dataset_search_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "prompts_router", "runs_router", ] diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..3ea6244c3 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,262 @@ +"""Investigations endpoint for viewing agent swarm results. + +Lists and serves investigation data from SPD_OUT_DIR/agent_swarm/. +Each task is treated as an independent investigation (flattened across swarms). 
+""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.settings import SPD_OUT_DIR + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +SWARM_DIR = SPD_OUT_DIR / "agent_swarm" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation (task).""" + + id: str # swarm_id/task_id + swarm_id: str + task_id: int + wandb_path: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + # Agent-provided summary + title: str | None + summary: str | None + status: str | None # in_progress, completed, inconclusive + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + swarm_id: str + task_id: int + wandb_path: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + # Agent-provided summary + title: str | None + summary: str | None + status: str | None + + +def _parse_swarm_metadata(swarm_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from a swarm directory.""" + metadata_path = swarm_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except Exception: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + try: + with open(events_path) as f: + for line in 
f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + except Exception: + pass + + return last_time, last_msg, count + + +def _parse_task_summary(task_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from a task directory. Returns (title, summary, status).""" + summary_path = task_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except Exception: + return None, None, None + + +def _get_task_created_at(task_path: Path, swarm_metadata: dict[str, Any] | None) -> str: + """Get creation time for a task.""" + # Try to get from first event + events_path = task_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except Exception: + pass + + # Fall back to swarm metadata + if swarm_metadata and "created_at" in swarm_metadata: + return swarm_metadata["created_at"] + + # Fall back to directory mtime + return datetime.fromtimestamp(task_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations() -> list[InvestigationSummary]: + """List all investigations (tasks) flattened across swarms.""" + if not SWARM_DIR.exists(): + return [] + + results = [] + + for swarm_path in SWARM_DIR.iterdir(): + if not swarm_path.is_dir() or not swarm_path.name.startswith("swarm-"): + continue + + swarm_id = swarm_path.name + metadata = _parse_swarm_metadata(swarm_path) + wandb_path = metadata.get("wandb_path") if metadata else None + + for task_path in swarm_path.iterdir(): + if not task_path.is_dir() or not task_path.name.startswith("task_"): + 
continue + + try: + task_id = int(task_path.name.split("_")[1]) + except (ValueError, IndexError): + continue + + events_path = task_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(task_path) + + results.append( + InvestigationSummary( + id=f"{swarm_id}/{task_id}", + swarm_id=swarm_id, + task_id=task_id, + wandb_path=wandb_path, + created_at=_get_task_created_at(task_path, metadata), + has_research_log=(task_path / "research_log.md").exists(), + has_explanations=(task_path / "explanations.jsonl").exists() + and (task_path / "explanations.jsonl").stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + # Sort by creation time, newest first + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{swarm_id}/{task_id}") +def get_investigation(swarm_id: str, task_id: int) -> InvestigationDetail: + """Get full details of an investigation.""" + swarm_path = SWARM_DIR / swarm_id + task_path = swarm_path / f"task_{task_id}" + + if not task_path.exists() or not task_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {swarm_id}/{task_id} not found") + + metadata = _parse_swarm_metadata(swarm_path) + + # Read research log + research_log = None + research_log_path = task_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + # Read events + events = [] + events_path = task_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except 
json.JSONDecodeError: + continue + + # Read explanations + explanations: list[dict[str, Any]] = [] + explanations_path = task_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(task_path) + + return InvestigationDetail( + id=f"{swarm_id}/{task_id}", + swarm_id=swarm_id, + task_id=task_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + created_at=_get_task_created_at(task_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + title=title, + summary=summary, + status=status, + ) diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..c986109a0 --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1171 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. 
+ +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Generator +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_intervention_forward, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import build_out_probs +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + +# Optional paths for swarm integration (set via environment at runtime) +_events_log_path: Path | None = None +_task_dir: Path | None = None +_suggestions_path: Path | None = None + + +def set_events_log_path(path: Path | None) -> None: + """Set the path for logging MCP tool events (for swarm monitoring).""" + global _events_log_path + _events_log_path = path + + +def set_task_dir(path: Path | None) -> None: + """Set the task directory for research log and explanations output.""" + global _task_dir + _task_dir = path + + +def set_suggestions_path(path: Path | None) -> None: + """Set the path for the central suggestions file.""" + global _suggestions_path + _suggestions_path = path + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if configured.""" + if _events_log_path is None: + return + event = { + 
"event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response.""" + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. + +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). 
If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. + +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Layer name (e.g., 'h.0.mlp.c_fc', 'h.2.attn.o_proj')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. 
+ +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. 
+ +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. + +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., 'h.0.mlp.c_fc:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + 
"confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="submit_suggestion", + description="""Submit a suggestion for improving the SPD system. + +Use this when you encounter limitations, have ideas for new tools, or think +of ways the system could better support investigation work. + +Suggestions are collected centrally and reviewed by humans to improve the system.""", + inputSchema={ + "type": "object", + "properties": { + "category": { + "type": "string", + "enum": ["tool_improvement", "new_tool", "documentation", "bug", "other"], + "description": "Category of suggestion", + }, + "title": { + "type": "string", + "description": "Brief title for the suggestion", + }, + "description": { + "type": "string", + "description": "Detailed description of the suggestion", + }, + "context": { + "type": "string", + "description": "What you were trying to do when you had this idea", + }, + }, + "required": ["category", "title", "description"], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and summary for your investigation. + +Call this when you've completed your investigation (or periodically as you make progress) +to provide a human-readable title and summary that will be shown in the investigations UI. + +The title should be short and descriptive. 
The summary should be 1-3 sentences +explaining what you investigated and what you found.""", + inputSchema={ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')", + }, + "summary": { + "type": "string", + "description": "Brief summary of findings (1-3 sentences)", + }, + "status": { + "type": "string", + "enum": ["in_progress", "completed", "inconclusive"], + "description": "Current status of the investigation", + "default": "in_progress", + }, + }, + "required": ["title", "summary"], + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. Yields progress events.""" + manager, loaded = _get_state() + + prompt_text = params["prompt_text"] + target_token = params["target_token"] + steps = params.get("steps", 100) + ci_threshold = params.get("ci_threshold", 0.5) + + # Tokenize prompt + token_ids = loaded.tokenizer.encode(prompt_text, add_special_tokens=False) + if not token_ids: + raise ValueError("Prompt text produced no tokens") + + # Find target token ID + target_token_ids = loaded.tokenizer.encode(target_token, add_special_tokens=False) + if len(target_token_ids) != 1: + raise ValueError( + f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. 
" + f"Token IDs: {target_token_ids}" + ) + label_token = target_token_ids[0] + + # Determine loss position + loss_position = params.get("loss_position") + if loss_position is None: + loss_position = len(token_ids) - 1 + + if loss_position >= len(token_ids): + raise ValueError( + f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + _log_event( + "tool_start", + f"optimize_graph: '{prompt_text}' → '{target_token}'", + {"steps": steps, "loss_position": loss_position}, + ) + + yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"} + + # Create prompt in DB + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Build optimization config + loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token) + + optim_config = OptimCIConfig( + seed=0, + lr=1e-2, + steps=steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, steps // 10), + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type="ci", + ) + + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except 
Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was produced") + + result = result_holder[0] + + # Build output + out_probs = build_out_probs( + ci_masked_out_probs=result.ci_masked_out_probs.cpu(), + ci_masked_out_logits=result.ci_masked_out_logits.cpu(), + target_out_probs=result.target_out_probs.cpu(), + target_out_logits=result.target_out_logits.cpu(), + output_prob_threshold=0.01, + token_strings=loaded.token_strings, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + out_probs=out_probs, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= 
ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.token_strings[t] for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: {len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + component_key = f"{layer}:{component_idx}" + + _log_event( + "tool_call", f"get_component_info: {component_key}", {"layer": layer, "idx": component_idx} + ) + + result: dict[str, Any] = {"component_key": component_key} + + # Get interpretation + interpretations = loaded.harvest.interpretations + if component_key in interpretations: + interp = interpretations[component_key] + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + + # Get token stats + token_stats = loaded.harvest.token_stats + input_stats = analysis.get_input_token_stats( + token_stats, component_key, loaded.tokenizer, top_k + ) + output_stats = 
analysis.get_output_token_stats( + token_stats, component_key, loaded.tokenizer, top_k + ) + + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + result["token_stats"] = None + + # Get correlations + correlations = loaded.harvest.correlations + if analysis.has_component(correlations, component_key): + result["correlated_components"] = { + "precision": [ + {"key": c.component_key, "score": c.score} + for c in analysis.get_correlated_components( + correlations, component_key, "precision", top_k + ) + ], + "pmi": [ + {"key": c.component_key, "score": c.score} + for c in analysis.get_correlated_components( + correlations, component_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' 
with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + # Parse node keys + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention_forward( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + top_k=top_k, + tokenizer=loaded.tokenizer, + ) + + predictions = [] + for pos_predictions in result.predictions_per_position: + pos_result = [] + for token, token_id, spd_prob, _logit, target_prob, _target_logit in pos_predictions: + pos_result.append( + { + "token": token, + "token_id": token_id, + "circuit_prob": round(spd_prob, 6), + "full_model_prob": round(target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, item in enumerate(filtered): 
+ if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text, add_special_tokens=False) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.token_strings[t] for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + if _task_dir is None: + raise ValueError("Research log not available - not running in swarm mode") + + content = params["content"] + research_log_path = _task_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", {"preview": content[:100]} + ) + + # 
Append content with a newline separator + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.agent_swarm.schemas import BehaviorExplanation, ComponentInfo, Evidence + + if _task_dir is None: + raise ValueError("Explanations file not available - not running in swarm mode") + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + # Build components + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + # Build evidence + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = _task_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def _tool_submit_suggestion(params: dict[str, Any]) -> dict[str, Any]: + """Submit a suggestion for system improvement.""" + if 
_suggestions_path is None: + raise ValueError("Suggestions not available - not running in swarm mode") + + suggestion = { + "timestamp": datetime.now(UTC).isoformat(), + "category": params["category"], + "title": params["title"], + "description": params["description"], + "context": params.get("context"), + } + + _log_event( + "tool_call", + f"submit_suggestion: [{params['category']}] {params['title']}", + suggestion, + ) + + # Ensure parent directory exists + _suggestions_path.parent.mkdir(parents=True, exist_ok=True) + + with open(_suggestions_path, "a") as f: + f.write(json.dumps(suggestion) + "\n") + + return {"status": "ok", "message": "Suggestion recorded. Thank you!"} + + +def _tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + if _task_dir is None: + raise ValueError("Summary not available - not running in swarm mode") + + summary = { + "title": params["title"], + "summary": params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = _task_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + """Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> 
Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name == "optimize_graph": + # This tool streams progress + return _tool_optimize_graph(arguments) + elif name == "get_component_info": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_get_component_info(arguments), indent=2)} + ] + } + elif name == "run_ablation": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_run_ablation(arguments), indent=2)} + ] + } + elif name == "search_dataset": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_search_dataset(arguments), indent=2)} + ] + } + elif name == "create_prompt": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_create_prompt(arguments), indent=2)} + ] + } + elif name == "update_research_log": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_update_research_log(arguments), indent=2)} + ] + } + elif name == "save_explanation": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_save_explanation(arguments), indent=2)} + ] + } + elif name == "submit_suggestion": + return { + "content": [ + {"type": "text", "text": json.dumps(_tool_submit_suggestion(arguments), indent=2)} + ] + } + elif name == "set_investigation_summary": + return { + "content": [ + { + "type": "text", + "text": json.dumps(_tool_set_investigation_summary(arguments), indent=2), + } + ] + } + else: + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. 
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse(content=MCPResponse(id=mcp_request.id, result=result).model_dump()) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump())}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, 
"message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump())}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump() + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump() + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump() + ) diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 45f5d9afb..68316bb69 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -34,6 +34,8 @@ dataset_search_router, graphs_router, intervention_router, + investigations_router, + mcp_router, prompts_router, runs_router, ) @@ -47,6 +49,15 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. 
Model loaded on-demand via /api/runs/load.""" + import os + from pathlib import Path + + from spd.app.backend.routers.mcp import ( + set_events_log_path, + set_suggestions_path, + set_task_dir, + ) + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -57,6 +68,22 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for agent swarm mode + mcp_events_path = os.environ.get("SPD_MCP_EVENTS_PATH") + if mcp_events_path: + set_events_log_path(Path(mcp_events_path)) + logger.info(f"[STARTUP] MCP events logging to: {mcp_events_path}") + + mcp_task_dir = os.environ.get("SPD_MCP_TASK_DIR") + if mcp_task_dir: + set_task_dir(Path(mcp_task_dir)) + logger.info(f"[STARTUP] MCP task dir: {mcp_task_dir}") + + mcp_suggestions_path = os.environ.get("SPD_MCP_SUGGESTIONS_PATH") + if mcp_suggestions_path: + set_suggestions_path(Path(mcp_suggestions_path)) + logger.info(f"[STARTUP] MCP suggestions file: {mcp_suggestions_path}") + yield manager.close() @@ -157,6 +184,8 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_attributions_router) app.include_router(agents_router) app.include_router(component_data_router) +app.include_router(investigations_router) +app.include_router(mcp_router) def cli(port: int = 8000) -> None: diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..e9b4a7cb8 --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,497 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} +
+ {#if selected.data.research_log} +
{selected.data.research_log}
+ {:else} +

No research log available

+ {/if} +
+ {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

No investigations found. Run spd-swarm to create one.

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/RunView.svelte b/spd/app/frontend/src/components/RunView.svelte index 734c9657d..06fd2ebbe 100644 --- a/spd/app/frontend/src/components/RunView.svelte +++ b/spd/app/frontend/src/components/RunView.svelte @@ -3,13 +3,14 @@ import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; import ClusterPathInput from "./ClusterPathInput.svelte"; import DatasetExplorerTab from "./DatasetExplorerTab.svelte"; + import InvestigationsTab from "./InvestigationsTab.svelte"; import PromptAttributionsTab from "./PromptAttributionsTab.svelte"; import DisplaySettingsDropdown from "./ui/DisplaySettingsDropdown.svelte"; import ActivationContextsTab from "./ActivationContextsTab.svelte"; const runState = getContext(RUN_KEY); - let activeTab = $state<"prompts" | "components" | "dataset-search" | null>(null); + let activeTab = $state<"prompts" | "components" | "dataset-search" | "investigations" | null>(null); let showRunMenu = $state(false); @@ -32,6 +33,14 @@ {/if} @@ -150,6 +167,11 @@
+ {#if runState.clusterMapping} +
+ +
+ {/if} {:else if runState.run.status === "loading" || runState.prompts.status === "loading"}

Loading run...

diff --git a/spd/app/frontend/src/lib/registry.ts b/spd/app/frontend/src/lib/registry.ts index 218193657..731de6f32 100644 --- a/spd/app/frontend/src/lib/registry.ts +++ b/spd/app/frontend/src/lib/registry.ts @@ -23,7 +23,13 @@ const DEFAULT_ENTITY_PROJECT = "goodfire/spd"; export const CANONICAL_RUNS: RegistryEntry[] = [ { wandbRunId: "goodfire/spd/s-55ea3f9b", - notes: "Primary canonical run candidate", + notes: "Jose. pile_llama_simple_mlp-4L", + clusterMappings: [ + { + path: "/mnt/polished-lake/artifacts/mechanisms/spd/clustering/runs/c-70b28465/cluster_mapping.json", + notes: "All layers, 9100 iterations", + }, + ], }, { wandbRunId: "goodfire/spd/s-275c8f21", From 85929e58107f821863bfea8606a235421834de3e Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 2 Mar 2026 11:46:39 +0000 Subject: [PATCH 075/102] Autointerp lazy loading, worker pool, shared cost tracking, postprocess improvements Autointerp: - Lazy component loading: use get_summary() + per-component get_component() instead of loading all 38k components upfront (was 24min, now instant start) - Worker pool concurrency in map_llm_calls: bounded job queue replaces semaphore+done_callback+active set. try/finally guarantees sentinel delivery. 
- CostTracker is now public; graph_interp shares one across all 3 passes - Resume support: --autointerp_subrun_id to continue from existing subrun - Add Pile dataset description to compact_skeptical strategy Postprocess: - Add graph_interp to PostprocessConfig (with attributions validation) - Add --dependency flag to spd-postprocess CLI (argparse, not fire) - Thread dependency_job_id through submit_harvest - Add s-82ffb969 postprocess config Other: - Harvest config: activation_examples_per_component 1000 -> 400 - Harvest DB: remove stale debug logging - App TODO: remove resolved SQLite immutable audit item - App + harvest + autointerp repo changes (pre-existing) Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/TODO.md | 1 - spd/app/backend/routers/runs.py | 2 + .../src/components/DataSourcesTab.svelte | 317 ++++++++++-------- .../src/components/ModelGraphTab.svelte | 53 --- .../frontend/src/components/RunView.svelte | 28 +- .../investigations/ArtifactGraph.svelte | 18 +- spd/app/frontend/src/lib/api/correlations.ts | 2 +- spd/app/frontend/src/lib/api/runs.ts | 1 + spd/app/frontend/src/lib/useRun.svelte.ts | 3 + spd/app/frontend/vite.config.ts | 2 +- spd/app/run_app.py | 2 +- spd/autointerp/config.py | 6 +- spd/autointerp/interpret.py | 32 +- spd/autointerp/llm_api.py | 96 +++--- spd/autointerp/repo.py | 2 + spd/autointerp/scripts/run_interpret.py | 18 +- spd/graph_interp/interpret.py | 7 +- spd/harvest/config.py | 4 +- spd/harvest/db.py | 3 - spd/harvest/repo.py | 2 + spd/postprocess/config.py | 7 + spd/postprocess/s-82ffb969.yaml | 71 ++++ 22 files changed, 371 insertions(+), 306 deletions(-) delete mode 100644 spd/app/frontend/src/components/ModelGraphTab.svelte create mode 100644 spd/postprocess/s-82ffb969.yaml diff --git a/spd/app/TODO.md b/spd/app/TODO.md index a3c7eb1aa..f7658851f 100644 --- a/spd/app/TODO.md +++ b/spd/app/TODO.md @@ -1,3 +1,2 @@ # App TODOs -- Audit SQLite access pragma stuff — `immutable=1` in `HarvestDB` causes "database disk image 
is malformed" errors when the app reads a harvest DB mid-write (WAL not yet checkpointed). Investigate whether to check for WAL file existence, use normal locking mode, or add another safeguard. See `spd/harvest/db.py:79`. diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index daac71373..b0e323cc2 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -45,6 +45,7 @@ class LoadedRun(BaseModel): dataset_attributions_available: bool dataset_search_enabled: bool graph_interp_available: bool + autointerp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ -172,6 +173,7 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: # TODO(oli): Remove MOCK_MODE import after real data available graph_interp_available=manager.run_state.graph_interp is not None or _GRAPH_INTERP_MOCK_MODE, + autointerp_available=manager.run_state.interp is not None, ) diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte index 6dd4d31c1..c54a1fa0a 100644 --- a/spd/app/frontend/src/components/DataSourcesTab.svelte +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -29,7 +29,7 @@ }); function formatConfigValue(value: unknown): string { - if (value === null || value === undefined) return "—"; + if (value === null || value === undefined) return "\u2014"; if (typeof value === "object") return JSON.stringify(value); return String(value); } @@ -50,115 +50,149 @@ } -
- {#if runState.run.status === "loaded" && runState.run.data.config_yaml} -
-

Run Config

-
{runState.run.data.config_yaml}
-
- {/if} - - - {#if pretrainData.status === "loading"} -
-

Target Model

-

Loading target model info...

-
- {:else if pretrainData.status === "loaded"} - {@const pt = pretrainData.data} -
-

Target Model

-
- Architecture - {pt.summary} - - {#if pt.pretrain_wandb_path} - Pretrain run - {pt.pretrain_wandb_path} - {/if} -
+
+ +
+ {#if runState.run.status === "loaded" && runState.run.data.config_yaml} +
+

Run Config

+
{runState.run.data.config_yaml}
+
+ {/if} - {#if pt.topology} -
-

Topology

- +
+

Target Model

+ {#if pretrainData.status === "loading"} +

Loading...

+ {:else if pretrainData.status === "loaded"} + {@const pt = pretrainData.data} +
+ Architecture + {pt.summary} + {#if pt.pretrain_wandb_path} + Pretrain run + {pt.pretrain_wandb_path} + {/if}
+ {#if pt.topology} +
+ +
+ {/if} + {#if pt.pretrain_config} +
+ Pretraining config +
{formatPretrainConfigYaml(pt.pretrain_config)}
+
+ {/if} + {:else if pretrainData.status === "error"} +

Failed to load target model info

{/if} - - {#if pt.pretrain_config} -
- Pretraining config -
{formatPretrainConfigYaml(pt.pretrain_config)}
-
- {/if} -
- {:else if pretrainData.status === "error"} -
-

Target Model

-

Failed to load target model info

- {/if} - - {#if data.status === "loading"} -

Loading data sources...

- {:else if data.status === "error"} -

Failed to load data sources: {data.error}

- {:else if data.status === "loaded"} - {@const { harvest, autointerp, attributions, graph_interp } = data.data} - - {#if !harvest && !autointerp && !attributions && !graph_interp} -

No pipeline data available for this run.

- {/if} - - {#if harvest} -
-

Harvest

+
+ + +
+ +
+
+ +

Harvest

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.harvest} + {@const harvest = data.data.harvest}
Subrun {harvest.subrun_id} - Components {harvest.n_components.toLocaleString()} - Intruder eval {harvest.has_intruder_scores ? "yes" : "no"} - {#each Object.entries(harvest.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
- {#if attributions} -
-

Dataset Attributions

+ +
+
+ +

Autointerp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.autointerp} + {@const autointerp = data.data.autointerp}
Subrun - {attributions.subrun_id} + {autointerp.subrun_id} + Interpretations + {autointerp.n_interpretations.toLocaleString()} + Eval scores + + {#if autointerp.eval_scores.length > 0} + {autointerp.eval_scores.join(", ")} + {:else} + none + {/if} + + {#each Object.entries(autointerp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ +
+
+ +

Dataset Attributions

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.attributions} + {@const attributions = data.data.attributions} +
+ Subrun + {attributions.subrun_id} Tokens {attributions.n_tokens_processed.toLocaleString()} - CI threshold {attributions.ci_threshold}
-
- {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
- {#if graph_interp} -
-

Graph Interp

+ +
+
+ +

Graph Interp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.graph_interp} + {@const graph_interp = data.data.graph_interp}
Subrun {graph_interp.subrun_id} - {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} {key} labels {value.toLocaleString()} {/each} - {#if graph_interp.config} {#each Object.entries(graph_interp.config) as [key, value] (key)} {key} @@ -166,86 +200,99 @@ {/each} {/if}
-
- {/if} - - {#if autointerp} -
-

Autointerp

-
- Subrun - {autointerp.subrun_id} - - Interpretations - {autointerp.n_interpretations.toLocaleString()} - - Eval scores - - {#if autointerp.eval_scores.length > 0} - {autointerp.eval_scores.join(", ")} - {:else} - none - {/if} - - - {#each Object.entries(autointerp.config) as [key, value] (key)} - {key} - {formatConfigValue(value)} - {/each} -
-
- {/if} - {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
+
diff --git a/spd/app/frontend/src/components/ModelGraphTab.svelte b/spd/app/frontend/src/components/ModelGraphTab.svelte deleted file mode 100644 index 2e95e315d..000000000 --- a/spd/app/frontend/src/components/ModelGraphTab.svelte +++ /dev/null @@ -1,53 +0,0 @@ - - -
- {#if data.status === "loading"} -
Loading model graph...
- {:else if data.status === "error"} -
Failed to load graph: {String(data.error)}
- {:else if data.status === "loaded"} - - {:else} -
Initializing...
- {/if} -
- - diff --git a/spd/app/frontend/src/components/RunView.svelte b/spd/app/frontend/src/components/RunView.svelte index 717977274..b505c576d 100644 --- a/spd/app/frontend/src/components/RunView.svelte +++ b/spd/app/frontend/src/components/RunView.svelte @@ -7,7 +7,6 @@ import DatasetExplorerTab from "./DatasetExplorerTab.svelte"; import InvestigationsTab from "./InvestigationsTab.svelte"; import DataSourcesTab from "./DataSourcesTab.svelte"; - import ModelGraphTab from "./ModelGraphTab.svelte"; import PromptAttributionsTab from "./PromptAttributionsTab.svelte"; import DisplaySettingsDropdown from "./ui/DisplaySettingsDropdown.svelte"; @@ -17,8 +16,6 @@ runState.run?.status === "loaded" && runState.run.data.dataset_search_enabled, ); - const graphInterpAvailable = $derived(runState.graphInterpAvailable); - let activeTab = $state< "prompts" | "components" | "dataset-search" | "model-graph" | "data-sources" | "investigations" | "clusters" | null >(null); @@ -61,14 +58,6 @@ > Investigations - {#if runState.run.status === "loaded" && runState.run.data} - {/if} - {#if graphInterpAvailable} - {/if} diff --git a/spd/app/frontend/src/components/prompt-attr/types.ts b/spd/app/frontend/src/components/prompt-attr/types.ts index 61ac6d946..cbe1d59ee 100644 --- a/spd/app/frontend/src/components/prompt-attr/types.ts +++ b/spd/app/frontend/src/components/prompt-attr/types.ts @@ -23,12 +23,6 @@ export type StoredGraph = { interventionRuns: InterventionRunSummary[]; }; -/** Transient UI state for the intervention composer, keyed by graph ID */ -export type ComposerState = { - selection: Set; // currently selected node keys - activeRunId: number | null; // which run is selected (for restoring selection) -}; - export type PromptCard = { id: number; // database prompt ID tokens: string[]; diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 88187c5d2..d2d810283 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ 
b/spd/app/frontend/src/lib/api/index.ts @@ -55,3 +55,4 @@ export * from "./investigations"; export * from "./dataSources"; export * from "./graphInterp"; export * from "./pretrainInfo"; +export * from "./runRegistry"; diff --git a/spd/app/frontend/src/lib/api/intervention.ts b/spd/app/frontend/src/lib/api/intervention.ts index 689c29cc1..dfd76783d 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -66,3 +66,4 @@ export async function deleteForkedInterventionRun(forkId: number): Promise throw new Error(error.detail || "Failed to delete forked intervention run"); } } + diff --git a/spd/app/frontend/src/lib/api/runRegistry.ts b/spd/app/frontend/src/lib/api/runRegistry.ts new file mode 100644 index 000000000..a55ab5a3d --- /dev/null +++ b/spd/app/frontend/src/lib/api/runRegistry.ts @@ -0,0 +1,24 @@ +/** + * API client for /api/run_registry endpoint. + */ + +import { fetchJson } from "./index"; + +export type DataAvailability = { + harvest: boolean; + autointerp: boolean; + attributions: boolean; + graph_interp: boolean; +}; + +export type RegistryRunInfo = { + wandb_run_id: string; + name: string | null; + notes: string | null; + architecture: string | null; + availability: DataAvailability; +}; + +export async function fetchRunRegistry(): Promise { + return fetchJson("/api/run_registry"); +} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index c364be243..821a46ee0 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -33,6 +33,7 @@ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) result: InterventionResponse; + masked_predictions: MaskedPredictions; created_at: string; forked_runs?: ForkedInterventionRunSummary[]; // child runs with modified tokens }; @@ -42,5 +43,96 @@ export type RunInterventionRequest = { graph_id: number; text: 
string; selected_nodes: string[]; - top_k?: number; + top_k: number; + adv_pgd_n_steps: number; + adv_pgd_step_size: number; }; + +export type TokenPred = { + token: string; + prob: number; +}; + +export type MaskedPredictions = { + ci: TokenPred[][]; + stochastic: TokenPred[][]; + adversarial: TokenPred[][]; +}; + +// --- Frontend-only run lifecycle types --- + +import { SvelteSet } from "svelte/reactivity"; +import { isInterventableNode } from "./promptAttributionsTypes"; + +/** Base run: synthesized from the graph's own data. All interventable nodes selected. Not editable. */ +export type BaseRun = { + kind: "base"; +}; + +/** Draft run: cloned from a parent, editable node selection. No forwarded results yet. */ +export type DraftRun = { + kind: "draft"; + parentId: "base" | number; + selectedNodes: SvelteSet; +}; + +/** Baked run: forwarded and immutable. Wraps a persisted InterventionRunSummary. */ +export type BakedRun = { + kind: "baked"; + id: number; + selectedNodes: Set; + result: InterventionResponse; + maskedPredictions: MaskedPredictions; + createdAt: string; +}; + +export type InterventionRun = BaseRun | DraftRun | BakedRun; + +export type InterventionState = { + runs: InterventionRun[]; + activeIndex: number; +}; + +/** Get the effective node selection for a run */ +export function getRunSelection(run: InterventionRun, allInterventableNodes: Set): Set { + switch (run.kind) { + case "base": + return allInterventableNodes; + case "draft": + return run.selectedNodes; + case "baked": + return run.selectedNodes; + } +} + +/** Whether a run's selection is editable */ +export function isRunEditable(run: InterventionRun): run is DraftRun { + return run.kind === "draft"; +} + +/** Build initial InterventionState from persisted runs */ +export function buildInterventionState(persistedRuns: InterventionRunSummary[]): InterventionState { + const runs: InterventionRun[] = [ + { kind: "base" }, + ...persistedRuns.map( + (r): BakedRun => ({ + kind: "baked", + id: 
r.id, + selectedNodes: new Set(r.selected_nodes), + result: r.result, + maskedPredictions: r.masked_predictions, + createdAt: r.created_at, + }), + ), + ]; + return { runs, activeIndex: 0 }; +} + +/** Get all interventable node keys from a nodeCiVals record */ +export function getInterventableNodes(nodeCiVals: Record): Set { + const nodes = new Set(); + for (const nodeKey of Object.keys(nodeCiVals)) { + if (isInterventableNode(nodeKey)) nodes.add(nodeKey); + } + return nodes; +} diff --git a/spd/app/frontend/src/lib/registry.ts b/spd/app/frontend/src/lib/registry.ts index a2fa17e85..dbe9e2841 100644 --- a/spd/app/frontend/src/lib/registry.ts +++ b/spd/app/frontend/src/lib/registry.ts @@ -1,39 +1,33 @@ /** * Registry of canonical SPD runs for quick access in the app. + * + * Static data renders instantly; availability + architecture are hydrated + * lazily from the backend via /api/run_registry. */ export type RegistryEntry = { - /** Full wandb run id (e.g., "goodfire/spd/jyo9duz5") */ wandbRunId: string; - /** Optional notes about the run */ + name?: string; notes?: string; - /** Optional cluster mappings for the run */ - clusterMappings?: { - path: string; - notes: string; - }[]; }; const DEFAULT_ENTITY_PROJECT = "goodfire/spd"; -/** - * Canonical runs registry - add new entries here. - * These appear in the dropdown for quick selection. 
- */ export const CANONICAL_RUNS: RegistryEntry[] = [ { - wandbRunId: "goodfire/spd/runs/s-82ffb969", - notes: "Thomas - pile_llama_simple_mlp-4L", + name: "Thomas", + wandbRunId: "goodfire/spd/s-82ffb969", + notes: "pile_llama_simple_mlp-4L", }, { + name: "Jose", wandbRunId: "goodfire/spd/s-55ea3f9b", - notes: "Jose - pile_llama_simple_mlp-4L", - clusterMappings: [ - { - path: "/mnt/polished-lake/artifacts/mechanisms/spd/clustering/runs/c-70b28465/cluster_mapping.json", - notes: "All layers, 9100 iterations", - }, - ], + notes: "pile_llama_simple_mlp-4L", + }, + { + name: "finetune", + wandbRunId: "goodfire/spd/s-17805b61", + notes: "finetune", }, { wandbRunId: "goodfire/spd/s-275c8f21", @@ -47,10 +41,6 @@ export const CANONICAL_RUNS: RegistryEntry[] = [ wandbRunId: "goodfire/spd/s-892f140b", notes: "Lucius run, Jan 22", }, - { - wandbRunId: "goodfire/spd/s-7884efcc", - notes: "Lucius' new run, Jan 8", - }, ]; /** @@ -60,7 +50,6 @@ export const CANONICAL_RUNS: RegistryEntry[] = [ */ export function formatRunIdForDisplay(wandbRunId: string): string { if (wandbRunId.startsWith(`${DEFAULT_ENTITY_PROJECT}/`)) { - // Extract just the run id (last segment) const parts = wandbRunId.split("/"); return parts[parts.length - 1]; } diff --git a/spd/graph_interp/db.py b/spd/graph_interp/db.py index 0af8c796f..052b06f26 100644 --- a/spd/graph_interp/db.py +++ b/spd/graph_interp/db.py @@ -5,6 +5,8 @@ from spd.graph_interp.schemas import LabelResult, PromptEdge +DONE_MARKER = ".done" + _SCHEMA = """\ CREATE TABLE IF NOT EXISTS output_labels ( component_key TEXT PRIMARY KEY, @@ -60,8 +62,12 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: self._conn = sqlite3.connect(str(db_path), check_same_thread=False) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.executescript(_SCHEMA) + self._db_path = db_path self._conn.row_factory = sqlite3.Row + def mark_done(self) -> None: + (self._db_path.parent / DONE_MARKER).touch() + # -- Output labels 
--------------------------------------------------------- def save_output_label(self, result: LabelResult) -> None: diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py index 6ded42200..e10b1ccd2 100644 --- a/spd/graph_interp/interpret.py +++ b/spd/graph_interp/interpret.py @@ -301,6 +301,7 @@ async def _run() -> None: f"{db.get_label_count('input_labels')} input, " f"{db.get_label_count('unified_labels')} unified labels -> {db_path}" ) + db.mark_done() try: asyncio.run(_run()) diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py index 9906ff138..6667c4e1e 100644 --- a/spd/graph_interp/repo.py +++ b/spd/graph_interp/repo.py @@ -11,7 +11,7 @@ import yaml -from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.db import DONE_MARKER, GraphInterpDB from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir @@ -31,7 +31,11 @@ def open(cls, run_id: str) -> "GraphInterpRepo | None": if not base_dir.exists(): return None candidates = sorted( - [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("ti-")], + [ + d + for d in base_dir.iterdir() + if d.is_dir() and d.name.startswith("ti-") and (d / DONE_MARKER).exists() + ], key=lambda d: d.name, ) if not candidates: diff --git a/spd/harvest/intruder.py b/spd/harvest/intruder.py index f91a5e0c2..4fb9b2052 100644 --- a/spd/harvest/intruder.py +++ b/spd/harvest/intruder.py @@ -175,7 +175,9 @@ async def run_intruder_scoring( jobs: list[LLMJob] = [] ground_truth: dict[str, _TrialGroundTruth] = {} - for component in remaining: + for i, component in enumerate(remaining): + if i > 0 and i % 1000 == 0: + logger.info(f"Building trials: {i}/{len(remaining)} components") for trial_idx in range(n_trials): real_examples = rng.sample(component.activation_examples, n_real) intruder = _sample_intruder(component, density_index, rng, density_tolerance) @@ -194,6 +196,7 @@ async def run_intruder_scoring( component_key=component.component_key, 
correct_answer=correct_answer, ) + logger.info(f"Built {len(jobs)} trials") component_trials: defaultdict[str, list[IntruderTrial]] = defaultdict(list) component_errors: defaultdict[str, int] = defaultdict(int) diff --git a/spd/harvest/scripts/run_intruder.py b/spd/harvest/scripts/run_intruder.py index 4dbef810a..1870d1a06 100644 --- a/spd/harvest/scripts/run_intruder.py +++ b/spd/harvest/scripts/run_intruder.py @@ -9,6 +9,7 @@ from spd.harvest.db import HarvestDB from spd.harvest.intruder import run_intruder_scoring from spd.harvest.repo import HarvestRepo +from spd.log import logger def main( @@ -28,7 +29,9 @@ def main( harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) score_db = HarvestDB(harvest._dir / "harvest.db") + logger.info("Loading components from harvest DB...") components = harvest.get_all_components() + logger.info(f"Loaded {len(components)} components") asyncio.run( run_intruder_scoring( From 2d0cc5d01b8ba37c0c525a124bdbdb88dbc696b1 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 10:07:19 +0000 Subject: [PATCH 081/102] wip: Add KL metrics to masked predictions and auto-save base intervention runs --- Screenshot 2026-03-02 at 21.13.33.png | Bin 0 -> 16141 bytes spd/app/CLAUDE.md | 15 +- spd/app/backend/compute.py | 152 ++++++--- spd/app/backend/database.py | 22 +- spd/app/backend/routers/graphs.py | 147 ++++++++- spd/app/backend/routers/intervention.py | 24 +- spd/app/backend/routers/run_registry.py | 69 +--- .../components/PromptAttributionsTab.svelte | 82 +---- .../src/components/RunSelector.svelte | 9 +- .../prompt-attr/InterventionsView.svelte | 295 +++++++----------- .../prompt-attr/OptimizationParams.svelte | 6 +- spd/app/frontend/src/lib/api/intervention.ts | 1 - spd/app/frontend/src/lib/api/runRegistry.ts | 12 +- spd/app/frontend/src/lib/interventionTypes.ts | 64 ++-- .../src/lib/promptAttributionsTypes.ts | 8 +- spd/app/frontend/src/lib/registry.ts | 10 +- 
spd/app/frontend/vite.config.ts | 2 +- spd/autointerp/db.py | 4 - .../autointerp-gemini-pro-s-55ea3f9b.yaml | 12 + .../autointerp-gemini-pro-s-82ffb969.yaml | 12 + 20 files changed, 501 insertions(+), 445 deletions(-) create mode 100644 Screenshot 2026-03-02 at 21.13.33.png create mode 100644 spd/postprocess/autointerp-gemini-pro-s-55ea3f9b.yaml create mode 100644 spd/postprocess/autointerp-gemini-pro-s-82ffb969.yaml diff --git a/Screenshot 2026-03-02 at 21.13.33.png b/Screenshot 2026-03-02 at 21.13.33.png new file mode 100644 index 0000000000000000000000000000000000000000..a15c2a3b4045e94b88976fb302ca49a52ba4bda7 GIT binary patch literal 16141 zcmeIZWmsIxwl0h%IKcu54naa_BxrC5?hcKo(V$ImmtY}{OB39KOMu`O8VL~GLa+pP zcfFIfvi9CVId(Q;mXNMsv{wx zfRK=osW9#XPYPL;M}W*LD+vizISC0GRY$0Wm904v($g51Xi-^|H^kkoH<+vs$;c?k z@RCum`5BT?c;2;r0j&?>r>(k3d*0MFzN;yqeI4mwAKe9auzFV@jDeM!JpHmRZ)@_> zDvw0iz3GAR+~vVC@=j(uxCcAa_Pd2E8*^z%nOqC0`%h^M-?N9~CSi(OCvu;LjB`8 z3KCMN6%yL-I?BNP?iUT5cWwT>qsE3Hp#wQW;B?PK`KvYxC=>OsG%^)Xh9s^bAtwji zHB24N&F!6DL!DFi=jVY3m=3bKPDn^340k88oI3p;K!3tY6YLCDQWP+S+Oa{*peE*Q z?sg7$bdZGH1%RZTxif^u-OkqDNx)s0?vEM*K>F@6I~~m*Rh(^v>A*^=G!jrpa~fVY z4mJ)t5iA-S8X-qB3juXW>EF$PoG{&MXJ-cic6K*6H#RpeHmIW|`*VJNes&H{c1})K zpa!dxhrKhzoz>py$)802q9bYUWa?<;;A{o8r@5mGF@d@`3)9ivG5W{vPdUxqt^Uc$ z-syL@fDW?Xm9Rf&<6!>>ZJ?>p-BSToD|d5ST}dlDKxY6C5zZI9LVwi%4<-NP_%AKN z|7^*_^S74&Qu6ODwVccyC7^ZyOJ|XP4%ctC|62IFp%D9B&;Laef2#S9r+}VCu!PwE zF=irIOf)@#z%Wu+Nvdc9cVLqJ@rwekPl5C94xE&ROuH+Yz@1r6Qe4v=c`FUm`{8(V z_pKZlrI#$0n-@HdrKR72Rdb?SOZT?&{nN@Y^y7zi&)SVB5pXrgtH+*`hpk*MY zieDmI`c`aOhhgb6oh%`Zf`WqV%UC|kh58kv9ftazG_L{0mdPmksu}*<^Bd(AZ!R~B zSoO`OAW}3MBxDRRBozFL4`NvDCT|4d@Bf+i_g6rIWIeen@<&3Wc_oY-)OqplDgK{X zB;=rAk>4%-{g2V`Nj9Hki~T_esI5Lx;u+T%w1)miQ=l4H2?_E^fa-l9K`j4h6KvNMqtf1db0ihoJe-f&M0AA{6wEHop)IL$(SkR)`?ANya%=PZT^yW+1sKv{P0Pd^uHxRY# zgSIg_Tod=B?Fh)Gfyc=M#K&W_%`Z~N_kec(zBKJ`D(%gFx(_uieukF;J1Tj}qGW+R?|rz#ug;I%Cf`HJlQF_aL=93Z{5gvAl&ciLjQB*D9V6Xymf!{TJ)fpitlLunWKkOeknf5XCfXR 
z9(7Gk`iBo6#?G-Bs?!DK)eH7sA~44UnobjU6Km|I_XM?$< zAF7iut!h&~*t)2==L=Co#QbYBya+!go0l|aXh+NB!d|BdHkSR^+YWQ02TN(CEaq}I zn^_)4-f_+lHP?^qmE(e?z6ecA9Y@EgJrn9F3sf{L*RA-H63{6~AuCS8{&eR4+!qse zjIHl25|x}>Jzucv-q>oaW{Oh1Uo0)b3J@bI@FS>gxsi&PC$F`SUAM7xZ`qY2C2?^z zUw>P_6NyW9i+x(HE*Iv+Dy79>sUs96=!v=kQb?Jwn+euJ$8&IH5MUf!Et=oaueFxK zOOF$AMimGxGvuH4daVw!vqC$1O!9gDVt|&%m)+>LY1L9kGesLi;mw;*6a|Iw0Yib} zWOs%nF58tNuuXv$Sv5W(o z8VIZSuw1S`iJF%z4dzZ*@*C%;0rgm_z+8z`Co?xvEXJNtj74w-h#<0|Vd|gOWNzp{ z7J|&}Ch%5!!w2!!WRPLa*jp`tFnxwF>qV4UW-x;u1%?9mv!0aD3i_zyEWHubJUw&> zKQS{jah(F61j=7^{0d(T{5B>BdV&J^^yHbaYLPWQiBr-VQ0nw}sL~IE26WT&3_pPm zCh6-F!^=jh$8*(LDk1}X6t1BFEVUg465{iNGNfw2PAkV1+^!=%xj-&mJ3sU>ribuOi7_fbe z8I&WjLdJiUFb!BIgF4*@CBfk`KS?st<#_`BO}#XrE-%J*6CspZB%phkL7fJS0_ec? zezh$Va=P=T&|D&VchYt$22A&+!3C&d6(pr0ccTOJGS>T~dQB81prw@mzgd$=Dh#6$6GF1@b&GYycU!BC}4*Awk5!;irm#CeZAWJu@|lFktnH z&UnBT=k=5_3I=w&$$yjj-A7p6FdzSC!{djsM?rFH_VArN0g0f=7;0Y8{eO}BG8CzO zA)^FE4Vc9;;egoVG@upEVPM{`aAum{0#oxUPK)N56nN{M+JpEavcHcVdxrD$<~yUx z9NWiEfxmXT$D2JJGDCrOA{|t8r|gPp)TTN}ru~OhVDN zKu@EQQA59m8_zR786YQld)kpB0Jo&EA6-TOTbxqxJ(@QT;S zCi;BhqSvt_`R!5F;8GvE+k3)91fHNGmG4@>5^=&$ZNL?u-40hPTGGaV-p zrexZTgiN_|2L<7F@}Gw0p2>E1C+?ZC(+szv(jW2@x&D7TnazI4d2seD{MiC{N^J zOH?9KJ0bh?2`DpK6ER8WUISOY?BH>#?yJn%QUy1%UwU3LMyA0T}#W}>J(_a>-zf6K`p#I+8kOM zJul+!SCG)7@lpP^>Z0eT_-fEf(oL5IM7P1VUg+eENL^c-vC(xGo-XWJS0`=Aq23im zP87lhuK+oTNUKvb2p-&A(l}QA4HyIx*JoZMod&I6FqZJRC|LFQ;R`jJ+z#>T{X;$g|t;_JFLX zZ~B==d8>AB-)SGJ6IE=+%@J|^YT5Lx%a{!DAt_tz!@fC;=GB=m2=lSvC+6*)!>E|0 zA3C76Bye>ZO{}!dFzqruR%}vXra>W|L55M*oD?$)SL_^O1uzMl9%zsa{B##&Os418 zmPzY498CH(kBAMgpX)k;>ZukbVBOsRv~eTsKFs+o)~b+M7Q{*#(`8ZbgSY+_6hZXf zU3sFd1|nCZGLAd^sLOm+79k`%KRzq?dU}f1pQbl1HT{y9Jy#LsGTuq?X|+9jomLH< zpKJbz0>PxUnc!n@ZIIk?3)%}jor|ca3}Rp;GsV`lXKy_y5~}^FYs%9x&GCqhK644jzkj2KBtFiPoHDB)2DpjQsY)}>8KPM zD7&9Vi_e(a&Q?SUZ(lHJM1nb#Erv>lI-=q=81w9X1wB%BcZw?Yl+VSwrQ?NFeu&;n z@XM!X_fZ1qqra|JP)se$dN>{talYB7pE7E=rZJKug$F)!uqdFcS!c_-;3=Em3k#?6 zdh_8!!Sd$K+PiOgi9LF&d^cNLpi@@;)Tq$=4}?{Wyp|dGcj`=YuoP|s0?kvMhS{hw 
z$Ma|QPQysxzLtU+36YfjGB0s+pkrhkmhLk_VUF8N#pYtSRWEnN{~?Yng#WxC&itZH z-+YeBT{Cg(P9vhe!rc26US78Xd=xeMyrPS~Q8xeq&^RonhS9O!_40M; z*>vnqJNK>!b>@*TTA+o``d7-YAdz);>3 zF1QmLLGAab_g6=pGF<=TY*0l|jP*swnF3>P_keD<{o}_F@+lo7wr=loCTh%JH3$mI z7usk*&el^5{?odEgTu~5nR`JMlAQ;~-@xtJ6OW!d#(S0a&brO9JxqA$INDz;pC3uN zyv~&2z*pgMm|?zD@SJn?X3DW^J-_!!lrWo$$Z?6{vG7-X?kHt=<&%qUpZ|il!Ld9KtcN_8HvK%Oq@*&_w!u+&sC>*%`!LIXM7$f zIrr2~g@JiiYXin(0O_YOl^MD86HnD_V;7bdCYVMEcQpssd!5eUZ#7!ZF%1e1Iu#Zc z#yroF>G|<_vrgfuWd8mEnW z^+@XA2A$d=QFU>j&sxdl@(kW2l*DX8(Pbb;hK!gD!PyUU zUB3^yVM$2?V(>{L7Q?HV~5k z5?%h-JOyc;+g##Ic?Fn4uF@v&4wV+=3!2BXoG8_9H&Y1BfW@#dv8WoI%g+T7id)1*I)r~$Q_gf32QwvOWpS89wK?Z#pQ)% zjOc@M)C~b&De#(k6=;VXF$%Elk8Przwymype(&LhMf*$?!AG5f9x0Q3nftBLPF>2e z+1Yf31qIPZuIrIaxf82jPB9r(^);Sc_MIwB*Vr)#3hJx3E*LI;b8{*vi3hsKf2ct? z!f$k~22-$g3a3}PSGM?|>Z0^b zjrUYlEtr?YKE5~kS@Y8@;|RHF!R`l)56<*ta@A;z`8_c88fFf|`HbsbtDOrTiZ z4RqmIvNP3KZXjrr9x1T3d2@X!b&U_#$Ef`Bxl0BIWYJE z8{JW2Lk7LC$l3W%cK`$1c^zH1uio?i;)SeX7i7U)6Fi0|Bae-)hgV%>de;}6WNbM#* zlEfIl0ZXLzm(+JqG*UyJaN=k+X1Kn6F%`wnI{mELfVOX<^TE#V`Qr9RyJNw8WJ(of%+RNqe$=jAO z*tJ>8a}1p80h)EUqok3;g7h7-d-v|`{QS<4V-xPiF^!Fq*U6B5zdtpV!tF3PI5-v# zXPo71^cbd?4i|qlxFf7?XvjK$*2QO1^2HWIj3%~i;BA^ZE03f5DqakXP{1y@q9Dqj z*M641_I7_KXorH_h9cnhyoY=JdUkPk{@J)#QrzXS@fZ6U5-URkhUpQEk7-Tjtxhy7fMZVoVo+%ZhDub|17^oJ4EWM!!I@ zzPeHA6d4>WD|lau4_{a?(XE4g0^MG1cc_^uL@;vf6;zoUWDM_Ipj#K^IVBE6__DJw zM*anZ@qt-B$Hv6O)J0U~NfW7d=Mb-op7D`AW~q6Fvay=iWZgzncHv2IeuZkpr&Hh~(0MSHJ z6tq(V_>SKa@rP$Xu#gRueMP=2E80T`=u-3F`*=s+=1%j$16q{Uwja?F*XWeW$AR$muwXL?Rf7n3s4|(NRqt6CKv!0;?%H212&1! 
z2Qj|}ZbLn1-4WdDQ&ph`yzA;g%c_}VRV5mpOdb~Bgz7I6m@5(dqYK@nKxl34OAJpa z*6u6ct<#t&F`zBg`&@}?##Jl^PR5L`t$m83^bmeJ(Zqf_kjh%#dhYnmRL1w_wx`pS zsrSHbE;^D@kPZ~KraPJL`q@doDdC&ZwHO_5;*lKsw?;Jj+eQgiziQaGBckf_YL4NsX__NFI( zsNv0pPQ@kGb@G^DkyI&sHv(P~QSqH1h&os*rqQBTQ%*@KhQBq*9Q~_>c0dI;5e*co zhL-#cgf-I7H_OO8@M&?!%RZ`iD|~bSzL~EKr9l9axv;LL zZlpyupTY_zhkT+z%SUU)>brg*LSH;*_(Z|js}nD}zXSzh#V9s6myN+K1E%s>dUKBO zvROVSC8g*nDcQoI-4|@ zj+qB?(LH|=GxOut;_=OxY++fUeLKmw7GT5CXI#6Z1F2u!NjjLky-sPWc-=Dh(AZB3 z>`h7C$&@7K6l8u@Oh0g+A3Yv&fG2~<&iDbWe$GM$gvdkIS*YWk# zU+%aZG+uq$O*+xLyc9IMb{r64dfq<~ntjF0!jhwvQI^Gr#Vg4lHxQh`zu?%gw8ZQ@ ziY>ZsfUYKDz44Qod-q$tD@lZs$i`@`N<7zw`x&s8)@$66l9!XqB8S$MEmA&mkd*Sc z*Xd!ht(aVx1iZ7u{rb!q3){XWp|YxSE#T6Z+<8^zwj$hn!oy|u^C5F7bUyR)L$&&~ z(ECp}a=)O2f=eUU7zMZNJh-3ZM+8430&my1-PBr3e4Pi&eW{MVE9Cd*+N>BWIh7@w zL!$y9n7xG(EYzvKqCQ$z4@%0Iw}2 z;+T~4WvG1K4T!9kWFibM*8DVlZ$#HWnJa8}H^h9iae6I;9GXS&L^&sMQW%OqyQdNJ8mwLFn{MQ3F5EnV=X5Nw9v@?=>its4 z=V(rQW4imQ*o_19E2HW=5t@Jvc%pfau=!*o!>Y6B){ZoLax>3L!*Zx0QG2v402L!7 z&rtK1U7g)l>BMpcnde20WxHToTl35-FkGlIKK*P_<7jCyYPbW&_(+!DS)LHLFSl@X zJ>7UliKxr?k+c76;BG@; zSA^hg5JLZ`v2$s^W2klBhY6<-Q!e<^^F|%s1f4_;LsP>?UMcl3a-6;`ex2z<`dlMJ zHOmlg#KPhETFP6LUUJ@&St74t1Z+L&A&b>CT*+i5sI8)!g^Aa{?G2iMDL6Sfxghfu z6kPjmTuZOKHud6E6f?^^q zE&C1;4v_R$fBGQ&wQ$J#){uijLDiChb1FFwjC{!<2NF|&jRTV$k9MlYXseKfdGkeJ zBTPc}I6xL_NI(Mb>&yGdk4PR1)H5nPBH_vnf;^W%2NKR=B%DBluR}dNy6qJoXc=%| zj|JK)K4hU*MvEhbNQEeu^$>=1;xI$N--eKZ1Z8wE4hjc|)7eo84iJ|GR2hPI$HgCH zhVKuF$%?b7ydy4uN8JC8iy3Vo5)KR}%^>cMxH`Zk3-yi*4fbQSS75WXrDPC5Jo=6c z<*Pd`x?gFK5OfkulV*3sfZtdz$`vVlOkk5q8VDqr9VNcwGT$UA9p zEiWy7_ab@taH@5A@tFdK10GacM`vujbZg5t)Jxip=BE(~I`)IusAM|dgw98Zk+Ct! 
z^UH(6k`V<#n46l9MxfU?&nq*`baw$^8{)iE#hfN_kEMP0t&(`W&dyFv9UYl4xiGkg zl`d4Qw@fi>YQa!nLqXW{C==K(@(L6b8_J^}r?ha!oJmIdbHXUHQb^|uxoG^~ zd4Dk)0>pszAwaSM=P$OeD3BL%U>9Cg8^e7yRl;6a4cXm$s6()HU9}m%1Ze4pp zRnTvJ_9)TBXkH*29#20)RC*2Cd$RaFDIN6Vc|lMTff4Z7BXM}CNsill&7tuY6&+9Y zWpzBf@t$C@(SAO6H{> zk(5s=*ob&g>F8uTzbz^TISn@SM(!M*>%lDIN7ic~336c(5y5zBBBgwvJmT&iH zUWp*59ZzE*fcKKpEFwlyep81!hvRP`A)_HiFAoRPCj}~xj*cdF&$6TZFZ_Pq2H@Rr zANg2Pwi#19NyaLOSzd=gV!YB{!1-}trcWEk2uS63(C)>_t=Z|K?@jhm4x14IYj|PG z_maCMITD+}GVpB={ZUN`>e=RU zmh%lpUYk1yv-q-SD)SX5W;5$1(v4^RlVrs^ui-+J!UL@RBX6>nY*jgug67~ZNPf*T z1WJs6b4qFm>n8)eziZ7vk0{*Q=JMS3dfGMUp+hKtV_ycEyr~Kr^ii}yc z=0LaD-6Z4W06+E7-lQr7@HPT$@u|32jB!KLm3TkP69uz-0IQ(W;QOSD0K6L|srS*hnUq@FEX>7j4Fg+H zK5L6PrQ#|(75Dw0SSZSm;SSbTidDy9ygd2$DI(q&cn$Exk8{DOejHuC4sxSs#DP6| z)Hs6)Sdr|@FR8+(*2zAd57v8ibsT;y=ue2F`0ZCbN7_eAEt>x|rv3=dmd@53=Q&#B zF*wp%+8A-7ucJrr3&cFC5rCyV$$Di!bHvNqbc^h1xp@CQ47NBXRa+8>c9Sjs#6t() zp|>JPLjeM&cwxSiy2thYv|~71>4D%-$_?8q@DbMok{E0TXq+m?_`M(4OiEc9Gmc#S zNy0D<(%DKp!ZXR2n3qKz;UAu7EF@JEQ{+)O2&`kB+*YjGMfv{Y_TmNhdD3>7P5&W~8w`<$-f zsxV1&R&`M<%CpfQGThhL{Mz8DM`Qfkdb4%v*Nf9B>+F7>!GB1ieg+GPSn4El7^(Oz z2U}8MQnlxee8YC9;M%4{BjuVJl#~PH=}}7xs0Es+&S7kH%XG0^PC6ob(-Nn_lyQuv zW@zZMS$UiI;Zb&OF8!^RG_c0y?(tz!#jd;lQ@;Gk%$C5l_4>Z`3m(3an}ch-qk>xO zMM^J;Wd6|BKgYJoIA&S^CJygw+uWy2`w}AvwjI`^SbEF zR}pVR%^E0?nQ#>f^so>MZR6tS!$PakIrY8OncY zy8SXgylU(%Zj+(kQ>dY&gqsi#uVK6N+wX9H-R6IDQyhLVARzk-N-RV_x{dL|3jA_h z{SgVeayg6ib@azH&IEg~QygmZ`ffXJ$nq-)ZK>`h3h>fsh?S3R=)rKnJA~A`F*8A;iu8OF)3b`N?DzGq%Fc5#+*dPpSH$~qd=;XD9Bb9VNGVeUijc9 zA0{ld>j3O^LrK`P%vFTsRBrb%lZD=x9KCX-l&v(11C(y%3n|hC+spp3Bj-KNMUdTnx9EJwqG6RDAp3k|+F2cVqlvo(QSD2$3TD zhmGN7C>GXJ=ywj~sfVnauLS&b+PY{ui2ky?jtn+KYJXs_HeomY+QFU;m+VEwrdakE zDyckP)ZhK&9qY%!I0`S~4O*h^_Gt7AFMCY?$6fy4iVoEJ1iB^__ho!E=UK`$_}eO+ zu~HV%f!Q9Z_HSVb9B*yq(|Z?#h$ymH)50mD(SRp{t;KfuB_ATfK8+ERW_{?FY(5=V zm52>3x;(Ck+xEzuuxGhss*#%;{jMl62g+)%@~YjLwm`R`%`1pB*-Tmcnn~}ap{^^; zpm{EP%;R##o$PngG$FAU*zhoTf*puwo;tt!*D$J~6mtFw0gK=F&?>fWL3=QM8lHEs zn>RGP1<&p$CwfNCY;#Ed49Qx|h^`#GbT 
z*`}Ld<8y5YR!IwEhAUviG!;)lPFYJse-GlsYWYBU`5)Ee%(Xqi_r(mg_PDu60}e0M zT`XWzYj1L6?l*1bb>)JeUHZPala1{qd%C@uQ*EjDv9y5`)WdMB>1nO7l4;O82Jv^`|*&mDS1G>ZOiu0s+=&xOB+HUm`Jui=}6_3asM6}uAPJ|r@**uxw4lubne}^ zW6Q~zNC*k+F5_PE_WsUGe`bWvf zPg)ihsi{LcGeVN)fko5x?>Ynm;S7ix;flR&{ln~VXtB zrpHlmaw2MOS&$=hJdhRGPW==>`)NiT#o*~V$G?Z#or4wMiR4)WvsW?;^x5U|j=5jy z^FP!jQbc9Sxagz&R#C9s>pu3ebM%W7B15HJI@jHq-77uZkK%T>DdW_*gr85>ZgqxR zPe6VvboQu@jC_-~mu{j5$$ii@V z{oBlA5s_5h&lQo5B1M@#i=zc)D%1SjOad)jQ9P|TTZ`4J#e)SUYe>s@q6x0M7eT%k zLyo;Yl%CnD-(67vt$UFEUF(_C%!aJ@bQEIU%{xM8|;xS76xl$ z^OuArVRaih_EQBiZqr<#z?%3NPmpfEALi+(O{;mVi13KT0<^+(sv6rimJFf zOm2jd#Ir+w?wZ{SG^4{ogr1|GqS5X)mvo+pvaibQoO)`t%~;tLz8bbJf> z(!hb~khL`X_eC?n9tw_MS5;Xo=uRG|*Bd`z%Bzc1zQ}xwRjMzb#m`!6I3_4H_N045 zZB{izs<~Wp)<|G4@MqMmBgJjaF>n0wnd7qm3DamLSr1p6@-?2nnx_$#;axq3BBRn zBY}B3d+ddb5Rs7_FI3Y`ct5&5{fu@(ZL8b-!TxJX9xcOIRISkVmMqIW!!GKu{=5IPexevFN7cDH#8)>uIy&ZWYi1hLevV2+$x?Pe%j5( z@e+BjM;ZxHBT^Z4rT)C)o;w|vQ5U!s?&~&R?7ESP6WvWQ?JJt=OlRJ!Hc5%qZ$D|g~cp!hm6B}S%_^VbUcd8 zprD|M_Q!R8KRE&K5i`Kb5~ywrO`E*FoThZv3>}c_R0}{uarv#{Ma+VWYwa89_eWYJ zrZ@^38bqarEkScW=lo>8hX&TlR+nH0n(SoMPR*+oUh9H_YQD}6&w|N!zv7>tC?3#v zX#BECDcra{`BAAgbmp#mmK-l#XY;-$6YoOM37Y_fZwUSWlEQ}jQSWp z=G;r3Dnq5xCq$bk0fEr+@bbo9;KohXPR>>~k`Aq5vwu9yAGC*(mWU!Fcj{k<5R#J0 z0N+)WHVGCXFkimXc!c(2lr~)A6E=pk7s|%1im$jxf_B1iWHzPdyxH}6GKP2fbC3}4 zF`}`i2HxJRp=IdtVT_1neG2vI#6?#5qJ~p->@)cXY8gkq@$oUt1eCFnpR1P03v5+z%*Hsvk!P;S>D^59y=;bWz=GnEfxP(BBjQgeWA!BL4&% zf$RUlDl`^-A18Bex*t5@d2xjO7Zmvu4?v-Ih6r;0LQQpu-tl`@qv(JQSfAfV|6T3i z4(6J~fAUwQNg{v9o92#3G{1ZHl?lL8}r0Fd;U(A57o^h%QfxM^?P=P17+ zQ3qiF-I`PgM}GS!tm^L#NO8c^0emy{hwj>r0Yw?GKfjOh-zGr7ExH+|l9$zl9PqDZ PBqTX0Wy!LauipMIAx 0 (the graph's rounded CI mask). Persisted as an `intervention_run` so predictions are available synchronously — no lazy fetching. 
--- diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 3ac6c453f..c51a0cd90 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -17,7 +17,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.optim_cis import ( CISnapshotCallback, - KLLossConfig, + LossConfig, OptimCIConfig, OptimizationMetrics, _compute_recon_loss, @@ -681,6 +681,11 @@ def extract_node_subcomp_acts( MASKED_PRED_TOP_K = 3 +# Default eval PGD settings (distinct from optimization PGD which is a training regularizer) +EVAL_PGD_N_STEPS = 4 +EVAL_PGD_STEP_SIZE = 1.0 + + @dataclass class MaskedPredictionsResult: """Top-k predictions under CI, stochastic, and adversarial masking.""" @@ -688,6 +693,9 @@ class MaskedPredictionsResult: ci: TopKPreds stochastic: TopKPreds adversarial: TopKPreds + ci_kl: float # KL(target || ci_masked), mean over positions + stochastic_kl: float + adversarial_kl: float def _extract_topk( @@ -708,20 +716,84 @@ def _extract_topk( return result +def _mean_kl_from_target( + masked_logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], +) -> float: + """Mean KL divergence from target across all positions.""" + target_probs = torch.softmax(target_logits, dim=-1) + pred_log_probs = torch.log_softmax(masked_logits, dim=-1) + # KL per position, then mean + kl_per_pos = torch.nn.functional.kl_div(pred_log_probs, target_probs, reduction="none").sum( + dim=-1 + ) # [1, seq] + return float(kl_per_pos.mean().item()) + + +def _run_eval_pgd( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + ci_masks: dict[str, Float[Tensor, "1 seq C"]], + target_logits: Float[Tensor, "1 seq vocab"], + n_steps: int, + step_size: float, + pgd_loss_config: LossConfig | None, +) -> Float[Tensor, "1 seq vocab"]: + """Run PGD to find adversarial sources maximizing loss, return final logits. + + If pgd_loss_config is provided (optimized graph), ascends that specific loss. 
+ Otherwise, ascends mean KL divergence from target over all positions. + """ + adv_sources: dict[str, Tensor] = {} + for layer_name, ci in ci_masks.items(): + source = torch.zeros_like(ci) + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + ci_detached = {k: v.detach() for k, v in ci_masks.items()} + + for _ in range(n_steps): + adv_mask_infos = make_mask_infos( + _interpolate_masks(ci_detached, adv_sources), routing_masks="all" + ) + with bf16_autocast(): + out = model(tokens, mask_infos=adv_mask_infos) + + if pgd_loss_config is not None: + loss = _compute_recon_loss(out, pgd_loss_config, target_logits, str(tokens.device)) + else: + target_probs = torch.softmax(target_logits, dim=-1) + pred_log_probs = torch.log_softmax(out, dim=-1) + loss = torch.nn.functional.kl_div(pred_log_probs, target_probs, reduction="batchmean") + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (_layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(step_size * grad.sign()) + source.clamp_(0.0, 1.0) + + with torch.no_grad(), bf16_autocast(): + final_adv_masks = _interpolate_masks( + ci_detached, {k: v.detach() for k, v in adv_sources.items()} + ) + adv_mask_infos = make_mask_infos(final_adv_masks, routing_masks="all") + return model(tokens, mask_infos=adv_mask_infos) + + def compute_masked_predictions( model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] tokenizer: AppTokenizer, - adv_n_steps: int, - adv_step_size: float, + pgd_n_steps: int, + pgd_step_size: float, + pgd_loss_config: LossConfig | None = None, ) -> MaskedPredictionsResult: - """Compute top-k predictions under CI, stochastic, and adversarial masking. + """Compute top-k predictions + KL under CI, stochastic, and adversarial masking. 
- Given a set of active nodes (mask=1, rest=0): - - CI: mask = selection (binary) - - Stochastic: mask = selection + (1-selection) * rand - - Adversarial: mask = selection + (1-selection) * PGD-optimized source + PGD always runs. If pgd_loss_config is provided (optimized graph), PGD ascends that + specific loss. Otherwise, PGD ascends mean KL divergence from target. """ seq_len = tokens.shape[1] device = tokens.device @@ -734,10 +806,11 @@ def compute_masked_predictions( ci_masks[layer][0, seq_pos, c_idx] = 1.0 with torch.no_grad(), bf16_autocast(): + target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) + # CI-masked forward ci_mask_infos = make_mask_infos(ci_masks, routing_masks="all") ci_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ci_mask_infos) - ci_preds = _extract_topk(ci_logits, tokenizer) # Stochastic forward: mask = ci + (1-ci) * rand stoch_sources = { @@ -750,45 +823,32 @@ def compute_masked_predictions( } stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - stoch_preds = _extract_topk(stoch_logits, tokenizer) - - # Adversarial forward (needs gradients for PGD) - loss_config = KLLossConfig(coeff=1.0, position=seq_len - 1) - - with torch.no_grad(), bf16_autocast(): - target_out = model(tokens) - - adv_sources: dict[str, Tensor] = {} - for layer_name, ci in ci_masks.items(): - source = torch.zeros_like(ci) - source.requires_grad_(True) - adv_sources[layer_name] = source - - source_list = list(adv_sources.values()) - ci_detached = {k: v.detach() for k, v in ci_masks.items()} - - for _ in range(adv_n_steps): - adv_mask_infos = make_mask_infos( - _interpolate_masks(ci_detached, adv_sources), routing_masks="all" - ) - with bf16_autocast(): - out = model(tokens, mask_infos=adv_mask_infos) - loss = _compute_recon_loss(out, loss_config, target_out, str(device)) - grads = torch.autograd.grad(loss, source_list) - with torch.no_grad(): - 
for (_layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): - source.add_(adv_step_size * grad.sign()) - source.clamp_(0.0, 1.0) - with torch.no_grad(), bf16_autocast(): - final_adv_masks = _interpolate_masks( - ci_detached, {k: v.detach() for k, v in adv_sources.items()} - ) - adv_mask_infos = make_mask_infos(final_adv_masks, routing_masks="all") - adv_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=adv_mask_infos) - adv_preds = _extract_topk(adv_logits, tokenizer) + # Adversarial forward (PGD always runs) + adv_logits = _run_eval_pgd( + model=model, + tokens=tokens, + ci_masks=ci_masks, + target_logits=target_logits, + n_steps=pgd_n_steps, + step_size=pgd_step_size, + pgd_loss_config=pgd_loss_config, + ) - return MaskedPredictionsResult(ci=ci_preds, stochastic=stoch_preds, adversarial=adv_preds) + # Extract top-k predictions and KL metrics + with torch.no_grad(): + ci_kl = _mean_kl_from_target(ci_logits, target_logits) + stoch_kl = _mean_kl_from_target(stoch_logits, target_logits) + adv_kl = _mean_kl_from_target(adv_logits, target_logits) + + return MaskedPredictionsResult( + ci=_extract_topk(ci_logits, tokenizer), + stochastic=_extract_topk(stoch_logits, tokenizer), + adversarial=_extract_topk(adv_logits, tokenizer), + ci_kl=ci_kl, + stochastic_kl=stoch_kl, + adversarial_kl=adv_kl, + ) @dataclass diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index b6893699a..9e0832558 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -61,6 +61,11 @@ class PromptRecord(BaseModel): is_custom: bool = False +class PgdConfig(BaseModel): + n_steps: int + step_size: float + + class OptimizationParams(BaseModel): """Optimization parameters that affect graph computation.""" @@ -70,8 +75,7 @@ class OptimizationParams(BaseModel): beta: float mask_type: MaskType loss: LossConfig - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + pgd: PgdConfig | None = None # Computed 
metrics (persisted for display on reload) ci_masked_label_prob: float | None = None stoch_masked_label_prob: float | None = None @@ -458,8 +462,12 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: mask_type = graph.optimization_params.mask_type loss_config_json = graph.optimization_params.loss.model_dump_json() loss_config_hash = hashlib.sha256(loss_config_json.encode()).hexdigest() - adv_pgd_n_steps = graph.optimization_params.adv_pgd_n_steps - adv_pgd_step_size = graph.optimization_params.adv_pgd_step_size + adv_pgd_n_steps = ( + graph.optimization_params.pgd.n_steps if graph.optimization_params.pgd else None + ) + adv_pgd_step_size = ( + graph.optimization_params.pgd.step_size if graph.optimization_params.pgd else None + ) ci_masked_label_prob = graph.optimization_params.ci_masked_label_prob stoch_masked_label_prob = graph.optimization_params.stoch_masked_label_prob adv_pgd_label_prob = graph.optimization_params.adv_pgd_label_prob @@ -571,6 +579,9 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: loss_config = CELossConfig(**loss_config_data) else: loss_config = KLLossConfig(**loss_config_data) + pgd = None + if row["adv_pgd_n_steps"] is not None: + pgd = PgdConfig(n_steps=row["adv_pgd_n_steps"], step_size=row["adv_pgd_step_size"]) opt_params = OptimizationParams( imp_min_coeff=row["imp_min_coeff"], steps=row["steps"], @@ -578,8 +589,7 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: beta=row["beta"], mask_type=row["mask_type"], loss=loss_config, - adv_pgd_n_steps=row["adv_pgd_n_steps"], - adv_pgd_step_size=row["adv_pgd_step_size"], + pgd=pgd, ci_masked_label_prob=row["ci_masked_label_prob"], stoch_masked_label_prob=row["stoch_masked_label_prob"], adv_pgd_label_prob=row["adv_pgd_label_prob"], diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index d00846c9b..cdf40a5fc 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -19,11 +19,21 @@ from spd.app.backend.app_tokenizer 
import AppTokenizer from spd.app.backend.compute import ( + EVAL_PGD_N_STEPS, + EVAL_PGD_STEP_SIZE, Edge, + compute_intervention_forward, + compute_masked_predictions, compute_prompt_attributions, compute_prompt_attributions_optimized, ) -from spd.app.backend.database import GraphType, OptimizationParams, StoredGraph +from spd.app.backend.database import ( + GraphType, + OptimizationParams, + PgdConfig, + PromptAttrDB, + StoredGraph, +) from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.optim_cis import ( AdvPGDConfig, @@ -34,12 +44,106 @@ MaskType, OptimCIConfig, ) +from spd.app.backend.routers.intervention import ( + InterventionResponse, + MaskedPredictionsResponse, + TokenPred, + TokenPrediction, +) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors from spd.configs import ImportanceMinimalityLossConfig from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device +NON_INTERVENTABLE_LAYERS = {"embed", "output"} + + +def _save_base_intervention_run( + graph_id: int, + model: ComponentModel, + tokens: torch.Tensor, + node_ci_vals: dict[str, float], + tokenizer: AppTokenizer, + topology: TransformerTopology, + db: PromptAttrDB, + pgd_loss_config: LossConfig | None = None, +) -> None: + """Compute masked predictions for all interventable nodes and save as an intervention run.""" + # Get all interventable node keys with CI > 0 + interventable_keys = [ + k + for k, ci in node_ci_vals.items() + if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 + ] + if not interventable_keys: + return + + # Parse to (concrete_path, seq, c_idx) tuples + active_nodes: list[tuple[str, int, int]] = [] + for key in interventable_keys: + canon_layer, seq_str, cidx_str = key.split(":") + concrete_path = topology.canon_to_target(canon_layer) + active_nodes.append((concrete_path, 
int(seq_str), int(cidx_str))) + + # Compute intervention forward pass + intervention_result = compute_intervention_forward( + model=model, + tokens=tokens, + active_nodes=active_nodes, + top_k=10, + tokenizer=tokenizer, + ) + + intervention_response = InterventionResponse( + input_tokens=intervention_result.input_tokens, + predictions_per_position=[ + [ + TokenPrediction( + token=token, + token_id=token_id, + spd_prob=spd_prob, + target_prob=target_prob, + logit=logit, + target_logit=target_logit, + ) + for token, token_id, spd_prob, logit, target_prob, target_logit in pos_preds + ] + for pos_preds in intervention_result.predictions_per_position + ], + ) + + masked_result = compute_masked_predictions( + model=model, + tokens=tokens, + active_nodes=active_nodes, + tokenizer=tokenizer, + pgd_n_steps=EVAL_PGD_N_STEPS, + pgd_step_size=EVAL_PGD_STEP_SIZE, + pgd_loss_config=pgd_loss_config, + ) + + def to_preds(preds: list[list[tuple[str, float]]]) -> list[list[TokenPred]]: + return [[TokenPred(token=t, prob=p) for t, p in pos] for pos in preds] + + masked_response = MaskedPredictionsResponse( + ci=to_preds(masked_result.ci), + stochastic=to_preds(masked_result.stochastic), + adversarial=to_preds(masked_result.adversarial), + ci_kl=masked_result.ci_kl, + stochastic_kl=masked_result.stochastic_kl, + adversarial_kl=masked_result.adversarial_kl, + ) + + db.save_intervention_run( + graph_id=graph_id, + selected_nodes=interventable_keys, + result_json=intervention_response.model_dump_json(), + masked_predictions_json=masked_response.model_dump_json(), + ) + class EdgeData(BaseModel): """Edge in the attribution graph.""" @@ -115,8 +219,7 @@ class OptimizationResult(BaseModel): mask_type: MaskType loss: CELossResult | KLLossResult metrics: OptimizationMetricsResult - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + pgd: PgdConfig | None = None class GraphDataWithOptimization(GraphData): @@ -509,6 +612,18 @@ def work( ) logger.info(f"[perf] 
save_graph: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + ) + logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() fg = filter_graph_for_display( raw_edges=result.edges, @@ -653,8 +768,9 @@ def compute_graph_optimized_stream( beta=beta, mask_type=mask_type, loss=loss_config, - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, ) optim_config = OptimCIConfig( @@ -715,6 +831,17 @@ def work( ), ) + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + pgd_loss_config=loss_config, + ) + fg = filter_graph_for_display( raw_edges=result.edges, node_ci_vals=result.node_ci_vals, @@ -766,8 +893,9 @@ def work( adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, ), ) @@ -931,8 +1059,9 @@ def stored_graph_to_response( stoch_masked_label_prob=opt.stoch_masked_label_prob, adv_pgd_label_prob=opt.adv_pgd_label_prob, ), - adv_pgd_n_steps=opt.adv_pgd_n_steps, - adv_pgd_step_size=opt.adv_pgd_step_size, + pgd=PgdConfig(n_steps=opt.pgd.n_steps, step_size=opt.pgd.step_size) + if opt.pgd is not None + else None, ), ) diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index da9f67fed..4ad77ee49 100644 --- 
a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -49,6 +49,11 @@ class InterventionResponse(BaseModel): predictions_per_position: list[list[TokenPrediction]] +class AdvPgdParams(BaseModel): + n_steps: int + step_size: float + + class RunInterventionRequest(BaseModel): """Request to run and save an intervention.""" @@ -56,8 +61,7 @@ class RunInterventionRequest(BaseModel): text: str selected_nodes: list[str] # node keys (layer:seq:cIdx) top_k: int - adv_pgd_n_steps: int - adv_pgd_step_size: float + adv_pgd: AdvPgdParams class ForkedInterventionRunSummary(BaseModel): @@ -78,6 +82,9 @@ class MaskedPredictionsResponse(BaseModel): ci: list[list[TokenPred]] stochastic: list[list[TokenPred]] adversarial: list[list[TokenPred]] + ci_kl: float + stochastic_kl: float + adversarial_kl: float class InterventionRunSummary(BaseModel): @@ -238,13 +245,19 @@ def run_and_save_intervention( loaded=loaded, ) + pgd_loss_config = None + graph_record = db.get_graph(request.graph_id) + if graph_record is not None and graph_record[0].optimization_params is not None: + pgd_loss_config = graph_record[0].optimization_params.loss + masked_result = compute_masked_predictions( model=loaded.model, tokens=tokens, active_nodes=active_nodes, tokenizer=loaded.tokenizer, - adv_n_steps=request.adv_pgd_n_steps, - adv_step_size=request.adv_pgd_step_size, + pgd_n_steps=request.adv_pgd.n_steps, + pgd_step_size=request.adv_pgd.step_size, + pgd_loss_config=pgd_loss_config, ) def _to_token_preds(preds: list[list[tuple[str, float]]]) -> list[list[TokenPred]]: @@ -254,6 +267,9 @@ def _to_token_preds(preds: list[list[tuple[str, float]]]) -> list[list[TokenPred ci=_to_token_preds(masked_result.ci), stochastic=_to_token_preds(masked_result.stochastic), adversarial=_to_token_preds(masked_result.adversarial), + ci_kl=masked_result.ci_kl, + stochastic_kl=masked_result.stochastic_kl, + adversarial_kl=masked_result.adversarial_kl, ) run_id = db.save_intervention_run( diff --git 
a/spd/app/backend/routers/run_registry.py b/spd/app/backend/routers/run_registry.py index 5f5d99d61..8e29e2134 100644 --- a/spd/app/backend/routers/run_registry.py +++ b/spd/app/backend/routers/run_registry.py @@ -1,7 +1,7 @@ """Run registry endpoint. -Returns canonical SPD runs with lightweight data availability checks -so the run picker can show what post-processing data exists at a glance. +Returns architecture and data availability for requested SPD runs. +The canonical run list lives in the frontend; the backend just hydrates it. """ from pathlib import Path @@ -18,42 +18,6 @@ router = APIRouter(prefix="/api/run_registry", tags=["run_registry"]) -class RegistryEntry(BaseModel): - wandb_run_id: str - name: str | None = None - notes: str | None = None - - -CANONICAL_RUNS: list[RegistryEntry] = [ - RegistryEntry( - name="Thomas", - wandb_run_id="goodfire/spd/runs/s-82ffb969", - notes="pile_llama_simple_mlp-4L", - ), - RegistryEntry( - name="Jose", - wandb_run_id="goodfire/spd/s-55ea3f9b", - notes="pile_llama_simple_mlp-4L", - ), - RegistryEntry( - wandb_run_id="goodfire/spd/s-275c8f21", - notes="Lucius' pile run Feb 11", - ), - RegistryEntry( - wandb_run_id="goodfire/spd/s-eab2ace8", - notes="Oli's PPGD run, great metrics", - ), - RegistryEntry( - wandb_run_id="goodfire/spd/s-892f140b", - notes="Lucius run, Jan 22", - ), - RegistryEntry( - wandb_run_id="goodfire/spd/s-7884efcc", - notes="Lucius' new run, Jan 8", - ), -] - - class DataAvailability(BaseModel): harvest: bool autointerp: bool @@ -61,10 +25,8 @@ class DataAvailability(BaseModel): graph_interp: bool -class RegistryRunInfo(BaseModel): +class RunInfoResponse(BaseModel): wandb_run_id: str - name: str | None - notes: str | None architecture: str | None availability: DataAvailability @@ -114,23 +76,18 @@ def _get_architecture_summary(wandb_path: str) -> str | None: return None -@router.get("") +@router.post("") @log_errors -def get_run_registry() -> list[RegistryRunInfo]: - """Return all canonical runs with 
data availability.""" - results: list[RegistryRunInfo] = [] - for entry in CANONICAL_RUNS: - _, _, run_id = parse_wandb_run_path(entry.wandb_run_id) - availability = _check_availability(run_id) - architecture = _get_architecture_summary(entry.wandb_run_id) - +def get_run_info(wandb_run_ids: list[str]) -> list[RunInfoResponse]: + """Return architecture and availability for the requested runs.""" + results: list[RunInfoResponse] = [] + for wandb_run_id in wandb_run_ids: + _, _, run_id = parse_wandb_run_path(wandb_run_id) results.append( - RegistryRunInfo( - wandb_run_id=entry.wandb_run_id, - name=entry.name, - notes=entry.notes, - architecture=architecture, - availability=availability, + RunInfoResponse( + wandb_run_id=wandb_run_id, + architecture=_get_architecture_summary(wandb_run_id), + availability=_check_availability(run_id), ) ) return results diff --git a/spd/app/frontend/src/components/PromptAttributionsTab.svelte b/spd/app/frontend/src/components/PromptAttributionsTab.svelte index 838d66e14..f26ce2c05 100644 --- a/spd/app/frontend/src/components/PromptAttributionsTab.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsTab.svelte @@ -22,13 +22,7 @@ type TabViewState, type ViewSettings, } from "./prompt-attr/types"; - import { - buildInterventionState, - getInterventableNodes, - getRunSelection, - type BakedRun, - type InterventionState, - } from "../lib/interventionTypes"; + import { buildInterventionState, type BakedRun, type InterventionState } from "../lib/interventionTypes"; import { SvelteSet } from "svelte/reactivity"; import ViewControls from "./prompt-attr/ViewControls.svelte"; import ViewTabs from "./prompt-attr/ViewTabs.svelte"; @@ -351,58 +345,6 @@ function handleViewChange(view: "graph" | "interventions") { if (!activeCard) return; promptCards = promptCards.map((card) => (card.id === activeCard.id ? 
{ ...card, activeView: view } : card)); - - // Auto-bake the base run when switching to interventions view - if (view === "interventions" && activeGraph) { - const state = getInterventionState(activeGraph.id, activeGraph); - const firstRun = state.runs[0]; - if (firstRun.kind === "base") { - autoBakeBaseRun(activeCard, activeGraph, state); - } - } - } - - async function autoBakeBaseRun(card: PromptCard, graph: StoredGraph, state: InterventionState) { - const allInterventable = getInterventableNodes(graph.data.nodeCiVals); - const text = card.tokens.join(""); - const selectedNodes = Array.from(allInterventable); - - runningIntervention = true; - try { - const run = await api.runAndSaveIntervention({ - graph_id: graph.id, - text, - selected_nodes: selectedNodes, - top_k: 10, - adv_pgd_n_steps: 4, - adv_pgd_step_size: 1.0, - }); - - const baked: BakedRun = { - kind: "baked", - id: run.id, - selectedNodes: new Set(run.selected_nodes), - result: run.result, - maskedPredictions: run.masked_predictions, - createdAt: run.created_at, - }; - state.runs[0] = baked; - state.activeIndex = 0; - - promptCards = promptCards.map((c) => { - if (c.id !== card.id) return c; - return { - ...c, - graphs: c.graphs.map((g) => - g.id === graph.id ? 
{ ...g, interventionRuns: [...g.interventionRuns, run] } : g, - ), - }; - }); - - interventionStates = { ...interventionStates }; - } finally { - runningIntervention = false; - } } // Update draft selection for the active graph @@ -417,7 +359,7 @@ } // Forward a draft run: call API, replace draft with baked - async function handleForwardDraft(advNSteps: number, advStepSize: number) { + async function handleForwardDraft(advPgd: { n_steps: number; step_size: number }) { if (!activeCard || !activeGraph) return; const state = interventionStates[activeGraph.id]; if (!state) throw new Error("No intervention state for active graph"); @@ -434,8 +376,7 @@ text, selected_nodes: selectedNodes, top_k: 10, - adv_pgd_n_steps: advNSteps, - adv_pgd_step_size: advStepSize, + adv_pgd: advPgd, }); // Replace the draft with a baked run @@ -483,13 +424,12 @@ if (!state) throw new Error("No intervention state for active graph"); const activeRun = state.runs[state.activeIndex]; - const allInterventable = getInterventableNodes(activeGraph.data.nodeCiVals); - const parentSelection = getRunSelection(activeRun, allInterventable); + if (activeRun.kind !== "baked") throw new Error("Can only clone baked runs"); const draft = { kind: "draft" as const, - parentId: activeRun.kind === "baked" ? 
activeRun.id : ("base" as const), - selectedNodes: new SvelteSet(parentSelection), + parentId: activeRun.id, + selectedNodes: new SvelteSet(activeRun.selectedNodes), }; state.runs.push(draft); state.activeIndex = state.runs.length - 1; @@ -541,8 +481,9 @@ const state = interventionStates[activeGraph.id]; if (!state) throw new Error("No intervention state"); const activeRun = state.runs[state.activeIndex]; - const allInterventable = getInterventableNodes(activeGraph.data.nodeCiVals); - const selection = getRunSelection(activeRun, allInterventable); + if (activeRun.kind !== "baked" && activeRun.kind !== "draft") + throw new Error("Can only generate subgraph from baked or draft runs"); + const selection = activeRun.selectedNodes; if (selection.size === 0) throw new Error("handleGenerateGraphFromSelection called with empty selection"); @@ -703,12 +644,13 @@ }); } + const runs = await api.getInterventionRuns(data.id); const newGraph: StoredGraph = { id: data.id, label: getGraphLabel(data), data, viewSettings: { ...defaultViewSettings }, - interventionRuns: [], + interventionRuns: runs, }; getInterventionState(data.id, newGraph); @@ -757,7 +699,7 @@ }; }); - // Intervention state stays as-is — base run auto-reflects new node keys via getRunSelection + // Intervention state stays as-is — base run's selectedNodes are from the persisted run } finally { refetchingGraphId = null; } diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index 71a5b0423..ef75c771f 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -1,7 +1,7 @@
@@ -61,13 +53,6 @@ 2, )})
- {#if singlePosEntry.adv_pgd_prob !== null && singlePosEntry.adv_pgd_logit !== null} -
- Adversarial: {(singlePosEntry.adv_pgd_prob * 100).toFixed(1)}% (logit: {singlePosEntry.adv_pgd_logit.toFixed( - 2, - )}) -
- {/if}

Position: @@ -83,10 +68,6 @@ Logit Target Logit - {#if hasAdvPgd} - Adv - Logit - {/if} @@ -97,10 +78,6 @@ {pos.logit.toFixed(2)} {(pos.target_prob * 100).toFixed(2)}% {pos.target_logit.toFixed(2)} - {#if hasAdvPgd} - {pos.adv_pgd_prob !== null ? (pos.adv_pgd_prob * 100).toFixed(2) + "%" : "—"} - {pos.adv_pgd_logit !== null ? pos.adv_pgd_logit.toFixed(2) : "—"} - {/if} {/each} diff --git a/spd/app/frontend/src/lib/api/intervention.ts b/spd/app/frontend/src/lib/api/intervention.ts index 689c29cc1..e9104b189 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -3,7 +3,6 @@ */ import type { - ForkedInterventionRunSummary, InterventionRunSummary, RunInterventionRequest, } from "../interventionTypes"; @@ -39,30 +38,3 @@ export async function deleteInterventionRun(runId: number): Promise { throw new Error(error.detail || "Failed to delete intervention run"); } } - -export async function forkInterventionRun( - runId: number, - tokenReplacements: [number, number][], - topK: number = 10, -): Promise { - const response = await fetch(`/api/intervention/runs/${runId}/fork`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ token_replacements: tokenReplacements, top_k: topK }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to fork intervention run"); - } - return (await response.json()) as ForkedInterventionRunSummary; -} - -export async function deleteForkedInterventionRun(forkId: number): Promise { - const response = await fetch(`/api/intervention/forks/${forkId}`, { - method: "DELETE", - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to delete forked intervention run"); - } -} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index d092832a3..d0715ae0a 100644 --- 
a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -4,67 +4,41 @@ export const EVAL_PGD_N_STEPS = 4; export const EVAL_PGD_STEP_SIZE = 1.0; -export type InterventionNode = { - layer: string; - seq_pos: number; - component_idx: number; -}; - export type TokenPrediction = { token: string; token_id: number; - spd_prob: number; - target_prob: number; + prob: number; logit: number; + target_prob: number; target_logit: number; }; -export type InterventionResponse = { +export type InterventionResult = { input_tokens: string[]; - predictions_per_position: TokenPrediction[][]; -}; - -/** A forked intervention run with modified tokens */ -export type ForkedInterventionRunSummary = { - id: number; - token_replacements: [number, number][]; // [(seq_pos, new_token_id), ...] - result: InterventionResponse; - created_at: string; + ci: TokenPrediction[][]; + stochastic: TokenPrediction[][]; + adversarial: TokenPrediction[][]; + ci_loss: number; + stochastic_loss: number; + adversarial_loss: number; }; /** Persisted intervention run from the server */ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) - result: InterventionResponse; - masked_predictions: MaskedPredictions; + result: InterventionResult; created_at: string; - forked_runs?: ForkedInterventionRunSummary[]; // child runs with modified tokens }; /** Request to run and save an intervention */ export type RunInterventionRequest = { graph_id: number; - text: string; selected_nodes: string[]; top_k: number; adv_pgd: { n_steps: number; step_size: number }; }; -export type TokenPred = { - token: string; - prob: number; -}; - -export type MaskedPredictions = { - ci: TokenPred[][]; - stochastic: TokenPred[][]; - adversarial: TokenPred[][]; - ci_kl: number; - stochastic_kl: number; - adversarial_kl: number; -}; - // --- Frontend-only run lifecycle types --- import { SvelteSet } from "svelte/reactivity"; @@ -82,8 +56,7 @@ 
export type BakedRun = { kind: "baked"; id: number; selectedNodes: Set; - result: InterventionResponse; - maskedPredictions: MaskedPredictions; + result: InterventionResult; createdAt: string; }; @@ -109,7 +82,6 @@ export function buildInterventionState(persistedRuns: InterventionRunSummary[]): id: r.id, selectedNodes: new Set(r.selected_nodes), result: r.result, - maskedPredictions: r.masked_predictions, createdAt: r.created_at, }), ); diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index 85648efa8..4d0769ac4 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -28,8 +28,6 @@ export type OutputProbability = { logit: number; // CI-masked (SPD model) raw logit target_prob: number; // Target model probability target_logit: number; // Target model raw logit - adv_pgd_prob: number | null; // Adversarial PGD probability - adv_pgd_logit: number | null; // Adversarial PGD raw logit token: string; }; diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index fb708ceb5..6d546c86a 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -30,11 +30,11 @@ def _init_adv_sources( match pgd_config.mask_scope: case "unique_per_datapoint": shape = torch.Size([*batch_dims, mask_c]) - source = _get_pgd_init_tensor(pgd_config.init, shape, device) + source = get_pgd_init_tensor(pgd_config.init, shape, device) case "shared_across_batch": singleton_batch_dims = [1 for _ in batch_dims] shape = torch.Size([*singleton_batch_dims, mask_c]) - source = broadcast_tensor(_get_pgd_init_tensor(pgd_config.init, shape, device)) + source = broadcast_tensor(get_pgd_init_tensor(pgd_config.init, shape, device)) adv_sources[module_name] = source.requires_grad_(True) return adv_sources @@ -86,7 +86,7 @@ def _construct_mask_infos_from_adv_sources( adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()} return 
make_mask_infos( - component_masks=_interpolate_component_mask(ci, adv_sources_components), + component_masks=interpolate_pgd_mask(ci, adv_sources_components), weight_deltas_and_masks=weight_deltas_and_masks, routing_masks=routing_masks, ) @@ -171,7 +171,7 @@ def calc_multibatch_pgd_masked_recon_loss( mask_c = module_c if not use_delta_component else module_c + 1 shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = broadcast_tensor( - _get_pgd_init_tensor(pgd_config.init, shape, device) + get_pgd_init_tensor(pgd_config.init, shape, device) ).requires_grad_(True) fwd_bwd_fn = partial( @@ -301,11 +301,12 @@ def _multibatch_pgd_fwd_bwd( return pgd_step_accum_sum_loss, pgd_step_accum_n_examples, pgd_step_accum_grads -def _get_pgd_init_tensor( +def get_pgd_init_tensor( init: PGDInitStrategy, shape: tuple[int, ...], device: torch.device | str, ) -> Float[Tensor, "... shape"]: + """Create initial PGD source tensor (random, ones, or zeroes). Shared by training PGD and app eval PGD.""" match init: case "random": return torch.rand(shape, device=device) @@ -315,7 +316,7 @@ def _get_pgd_init_tensor( return torch.zeros(shape, device=device) -def _interpolate_component_mask( +def interpolate_pgd_mask( ci: dict[str, Float[Tensor, "*batch_dims C"]], adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]], ) -> dict[str, Float[Tensor, "*batch_dims C"]]: diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index cc5ed5a0e..289197efc 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -16,6 +16,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB from spd.app.backend.routers import graphs as graphs_router +from spd.app.backend.routers import intervention as intervention_router from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import RunState, StateManager @@ -54,6 +55,7 @@ 
def app_with_state(): # Patch DEVICE in all router modules to use CPU for tests with ( mock.patch.object(graphs_router, "DEVICE", DEVICE), + mock.patch.object(intervention_router, "DEVICE", DEVICE), mock.patch.object(runs_router, "DEVICE", DEVICE), ): db = PromptAttrDB(db_path=Path(":memory:"), check_same_thread=False) @@ -232,6 +234,47 @@ def test_compute_graph(app_with_prompt: tuple[TestClient, int]): assert "outputProbs" in data +def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClient, int]): + """Run-and-save intervention should use graph-linked prompt tokens (no text in request).""" + client, prompt_id = app_with_prompt + + graph_response = client.post( + "/api/graphs", + params={"prompt_id": prompt_id, "normalize": "none", "ci_threshold": 0.0}, + ) + assert graph_response.status_code == 200 + events = [line for line in graph_response.text.strip().split("\n") if line.startswith("data:")] + final_data = json.loads(events[-1].replace("data: ", "")) + graph_data = final_data["data"] + graph_id = graph_data["id"] + + selected_nodes = [ + key + for key, ci in graph_data["nodeCiVals"].items() + if not key.startswith("embed:") and not key.startswith("output:") and ci > 0 + ] + assert len(selected_nodes) > 0 + + request = { + "graph_id": graph_id, + "selected_nodes": selected_nodes[: min(5, len(selected_nodes))], + "top_k": 5, + "adv_pgd": {"n_steps": 1, "step_size": 1.0}, + } + response = client.post("/api/intervention/run", json=request) + assert response.status_code == 200 + body = response.json() + assert body["selected_nodes"] == request["selected_nodes"] + result = body["result"] + assert len(result["input_tokens"]) > 0 + assert len(result["ci"]) > 0 + assert len(result["stochastic"]) > 0 + assert len(result["adversarial"]) > 0 + assert "ci_loss" in result + assert "stochastic_loss" in result + assert "adversarial_loss" in result + + # ----------------------------------------------------------------------------- # Streaming: Prompt 
Generation # ----------------------------------------------------------------------------- From dadedba372cd9038fd61d096e260eec62de08fc7 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 15:57:10 +0000 Subject: [PATCH 085/102] Add target-sans masking strategy to interventions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Target-sans (T\S) shows the full target model with only the unselected alive nodes ablated — answering "what do the deselected components contribute?" Complements CI (only selected active) by showing the inverse view. Mask: everything=1 except alive-but-unselected=0. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 22 ++++++++++++- .../prompt-attr/InterventionsView.svelte | 33 +++++++++++-------- spd/app/frontend/src/lib/interventionTypes.ts | 2 ++ tests/app/test_server_api.py | 2 ++ 4 files changed, 44 insertions(+), 15 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index d04b1ca50..8295f7d02 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -684,15 +684,17 @@ class TokenPrediction(BaseModel): class InterventionResult(BaseModel): - """Unified result of an intervention evaluation (CI, stochastic, adversarial masking).""" + """Unified result of an intervention evaluation under multiple masking regimes.""" input_tokens: list[str] ci: list[list[TokenPrediction]] stochastic: list[list[TokenPrediction]] adversarial: list[list[TokenPrediction]] + target_sans: list[list[TokenPrediction]] ci_loss: float stochastic_loss: float adversarial_loss: float + target_sans_loss: float # Default eval PGD settings (distinct from optimization PGD which is a training regularizer) @@ -788,6 +790,14 @@ def compute_intervention( f"Selected node {layer}:{seq_pos}:{c_idx} is not alive in the graph" ) + # Target-sans masks: everything=1 except alive-but-unselected=0 + target_sans_masks: dict[str, Float[Tensor, "1 seq 
C"]] = {} + for layer_name in ci_masks: + mask = torch.ones_like(ci_masks[layer_name]) + alive_unselected = graph_alive_masks[layer_name] & (ci_masks[layer_name] == 0) + mask[alive_unselected] = 0.0 + target_sans_masks[layer_name] = mask + with torch.no_grad(), bf16_autocast(): # Target forward (unmasked) target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) @@ -804,6 +814,10 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) + # Target-sans forward: full model minus unselected alive nodes + ts_mask_infos = make_mask_infos(target_sans_masks, routing_masks="all") + ts_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ts_mask_infos) + # Adversarial: PGD optimizes alive-but-unselected components adv_sources = run_adv_pgd( model=model, @@ -830,6 +844,7 @@ def compute_intervention( ci_preds = _extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) ci_loss = float( compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() @@ -840,6 +855,9 @@ def compute_intervention( adv_loss = float( compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() ) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() + ) input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) @@ -848,7 +866,9 @@ def compute_intervention( ci=ci_preds, stochastic=stoch_preds, adversarial=adv_preds, + target_sans=ts_preds, ci_loss=ci_loss, stochastic_loss=stoch_loss, adversarial_loss=adv_loss, + target_sans_loss=ts_loss, ) diff --git 
a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index cf0c66fea..34d3dfa38 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -139,6 +139,7 @@ if (interventionResult.adversarial.length > 0) rows.push({ label: "Adv", preds: interventionResult.adversarial }); if (interventionResult.stochastic.length > 0) rows.push({ label: "Stoch", preds: interventionResult.stochastic }); rows.push({ label: "CI", preds: interventionResult.ci }); + if (interventionResult.target_sans.length > 0) rows.push({ label: "T\\S", preds: interventionResult.target_sans }); return rows; }); @@ -663,21 +664,21 @@ > {runningIntervention ? "Forwarding..." : "Forward"} + + + + {:else} {/if} - - - - @@ -743,10 +744,10 @@ {#each row.preds as preds, seqIdx (seqIdx)} {@const colX = layout.seqXStarts[seqIdx]} {@const colW = layout.seqWidths[seqIdx]} - {@const chipW = Math.min(48, Math.floor((colW - 2) / Math.max(preds.length, 1)))} + {@const chipW = 48} {@const chipH = PRED_ROW_HEIGHT} {@const chipGap = 1} - {@const maxChips = Math.max(1, Math.floor((colW - 2 + chipGap) / (chipW + chipGap)))} + {@const maxChips = Math.min(preds.length, Math.max(1, Math.floor((colW - 2 + chipGap) / (chipW + chipGap))))} {#each preds.slice(0, maxChips) as pred, rank (rank)} {@const cx = colX + rank * (chipW + chipGap)} @@ -1019,6 +1020,10 @@ adv {run.result.adversarial_loss.toFixed(3)} +

+ T\S + {run.result.target_sans_loss.toFixed(3)} +
metric {lossLabel} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index d0715ae0a..23e2f996e 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -18,9 +18,11 @@ export type InterventionResult = { ci: TokenPrediction[][]; stochastic: TokenPrediction[][]; adversarial: TokenPrediction[][]; + target_sans: TokenPrediction[][]; ci_loss: number; stochastic_loss: number; adversarial_loss: number; + target_sans_loss: number; }; /** Persisted intervention run from the server */ diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 289197efc..dc9818c64 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -270,9 +270,11 @@ def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClien assert len(result["ci"]) > 0 assert len(result["stochastic"]) > 0 assert len(result["adversarial"]) > 0 + assert len(result["target_sans"]) > 0 assert "ci_loss" in result assert "stochastic_loss" in result assert "adversarial_loss" in result + assert "target_sans_loss" in result # ----------------------------------------------------------------------------- From c1d2e05507bb11a59d48fbcbd3c620d2910b5604 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 16:41:30 +0000 Subject: [PATCH 086/102] Fix target-sans to use weight deltas for exact target reconstruction With weight deltas, components + delta = exact target model output. Without them, all-ones mask has small reconstruction error vs target. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 8295f7d02..73d1c3092 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -790,13 +790,18 @@ def compute_intervention( f"Selected node {layer}:{seq_pos}:{c_idx} is not alive in the graph" ) - # Target-sans masks: everything=1 except alive-but-unselected=0 + # Target-sans: full target model with alive-but-unselected components zeroed out. + # Includes weight deltas so components + delta = exact target reconstruction. target_sans_masks: dict[str, Float[Tensor, "1 seq C"]] = {} for layer_name in ci_masks: mask = torch.ones_like(ci_masks[layer_name]) alive_unselected = graph_alive_masks[layer_name] & (ci_masks[layer_name] == 0) mask[alive_unselected] = 0.0 target_sans_masks[layer_name] = mask + weight_deltas = model.calc_weight_deltas() + ts_weight_deltas_and_masks = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } with torch.no_grad(), bf16_autocast(): # Target forward (unmasked) @@ -814,8 +819,12 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - # Target-sans forward: full model minus unselected alive nodes - ts_mask_infos = make_mask_infos(target_sans_masks, routing_masks="all") + # Target-sans forward: target model with unselected alive nodes ablated + ts_mask_infos = make_mask_infos( + target_sans_masks, + routing_masks="all", + weight_deltas_and_masks=ts_weight_deltas_and_masks, + ) ts_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ts_mask_infos) # Adversarial: PGD optimizes alive-but-unselected components From abbca53094c19a8a593427995cd4c51980e198d9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 
16:55:29 +0000 Subject: [PATCH 087/102] Disable bf16 autocast for target-sans forward pass Exact target reconstruction requires full precision. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 73d1c3092..9d2a05fe5 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -819,7 +819,8 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - # Target-sans forward: target model with unselected alive nodes ablated + # Target-sans forward: no bf16 autocast — exact target reconstruction matters here + with torch.no_grad(): ts_mask_infos = make_mask_infos( target_sans_masks, routing_masks="all", From 82c73d8493c5ab6f5b5a2acecb3b8d871f61e377 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 17:32:31 +0000 Subject: [PATCH 088/102] =?UTF-8?q?Revert=20bf16=20removal=20=E2=80=94=20w?= =?UTF-8?q?eights=20are=20already=20bf16,=20autocast=20is=20irrelevant?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 9d2a05fe5..73d1c3092 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -819,8 +819,7 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - # Target-sans forward: no bf16 autocast — exact target reconstruction matters here - with torch.no_grad(): + # Target-sans forward: target model with unselected alive nodes ablated ts_mask_infos = make_mask_infos( 
target_sans_masks, routing_masks="all", From a93a29af5e96201215962daa415a717b282e4892 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 18:14:11 +0000 Subject: [PATCH 089/102] Highlight CE label token in intervention predictions Add LabelPredictions to InterventionResult: per-regime prediction stats for the CE-optimized token. Frontend highlights matching topk chips with amber border, and appends a dashed-border chip when the label token falls outside topk. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 44 ++++++++++++++++ .../prompt-attr/InterventionsView.svelte | 50 ++++++++++++++++--- spd/app/frontend/src/lib/interventionTypes.ts | 9 ++++ 3 files changed, 95 insertions(+), 8 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 73d1c3092..e35fa3c88 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -18,6 +18,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.optim_cis import ( AdvPGDConfig, + CELossConfig, CISnapshotCallback, LossConfig, OptimCIConfig, @@ -683,6 +684,16 @@ class TokenPrediction(BaseModel): target_logit: float +class LabelPredictions(BaseModel): + """Prediction stats for the CE label token at the optimized position, per masking regime.""" + + position: int + ci: TokenPrediction + stochastic: TokenPrediction + adversarial: TokenPrediction + target_sans: TokenPrediction + + class InterventionResult(BaseModel): """Unified result of an intervention evaluation under multiple masking regimes.""" @@ -695,6 +706,7 @@ class InterventionResult(BaseModel): stochastic_loss: float adversarial_loss: float target_sans_loss: float + label: LabelPredictions | None # Default eval PGD settings (distinct from optimization PGD which is a training regularizer) @@ -758,6 +770,26 @@ def _extract_topk_predictions( return result +def _extract_label_prediction( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, 
"1 seq vocab"], + tokenizer: AppTokenizer, + position: int, + label_token: int, +) -> TokenPrediction: + """Extract the prediction for a specific token at a specific position.""" + probs = torch.softmax(logits[0, position], dim=-1) + target_probs = torch.softmax(target_logits[0, position], dim=-1) + return TokenPrediction( + token=tokenizer.get_tok_display(label_token), + token_id=label_token, + prob=float(probs[label_token].item()), + logit=float(logits[0, position, label_token].item()), + target_prob=float(target_probs[label_token].item()), + target_logit=float(target_logits[0, position, label_token].item()), + ) + + def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], @@ -868,6 +900,17 @@ def compute_intervention( compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() ) + label: LabelPredictions | None = None + if isinstance(loss_config, CELossConfig): + pos, tid = loss_config.position, loss_config.label_token + label = LabelPredictions( + position=pos, + ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), + stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), + adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), + target_sans=_extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid), + ) + input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) return InterventionResult( @@ -880,4 +923,5 @@ def compute_intervention( stochastic_loss=stoch_loss, adversarial_loss=adv_loss, target_sans_loss=ts_loss, + label=label, ) diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index 34d3dfa38..f832661f8 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -131,15 +131,16 @@ return null; }); - // 
Prediction rows for rendering: [{label, preds}] ordered top-to-bottom (Adv, Stoch, CI) - type PredRow = { label: string; preds: TokenPrediction[][] }; + // Prediction rows for rendering: [{label, preds, labelPred}] ordered top-to-bottom (Adv, Stoch, CI) + type PredRow = { label: string; preds: TokenPrediction[][]; labelPred: TokenPrediction | null }; const predRows = $derived.by((): PredRow[] | null => { if (!interventionResult) return null; + const lbl = interventionResult.label; const rows: PredRow[] = []; - if (interventionResult.adversarial.length > 0) rows.push({ label: "Adv", preds: interventionResult.adversarial }); - if (interventionResult.stochastic.length > 0) rows.push({ label: "Stoch", preds: interventionResult.stochastic }); - rows.push({ label: "CI", preds: interventionResult.ci }); - if (interventionResult.target_sans.length > 0) rows.push({ label: "T\\S", preds: interventionResult.target_sans }); + if (interventionResult.adversarial.length > 0) rows.push({ label: "Adv", preds: interventionResult.adversarial, labelPred: lbl?.adversarial ?? null }); + if (interventionResult.stochastic.length > 0) rows.push({ label: "Stoch", preds: interventionResult.stochastic, labelPred: lbl?.stochastic ?? null }); + rows.push({ label: "CI", preds: interventionResult.ci, labelPred: lbl?.ci ?? null }); + if (interventionResult.target_sans.length > 0) rows.push({ label: "T\\S", preds: interventionResult.target_sans, labelPred: lbl?.target_sans ?? null }); return rows; }); @@ -747,9 +748,13 @@ {@const chipW = 48} {@const chipH = PRED_ROW_HEIGHT} {@const chipGap = 1} + {@const isLabelPos = interventionResult?.label != null && seqIdx === interventionResult.label.position} + {@const labelTokenId = isLabelPos ? row.labelPred?.token_id ?? 
null : null} + {@const labelInTopk = labelTokenId != null && preds.some((p) => p.token_id === labelTokenId)} {@const maxChips = Math.min(preds.length, Math.max(1, Math.floor((colW - 2 + chipGap) / (chipW + chipGap))))} {#each preds.slice(0, maxChips) as pred, rank (rank)} {@const cx = colX + rank * (chipW + chipGap)} + {@const isLabel = labelTokenId != null && pred.token_id === labelTokenId} handlePredMouseEnter(e, pred, row.label, seqIdx)} @@ -762,8 +767,8 @@ height={chipH} rx="2" fill={getNextTokenProbBgColor(pred.prob)} - stroke="#ddd" - stroke-width="0.5" + stroke={isLabel ? "#f59e0b" : "#ddd"} + stroke-width={isLabel ? "1.5" : "0.5"} /> {/each} + + {#if isLabelPos && !labelInTopk && row.labelPred} + {@const cx = colX + maxChips * (chipW + chipGap) + chipGap} + + handlePredMouseEnter(e, row.labelPred!, row.label, seqIdx)} + onmouseleave={handlePredMouseLeave} + > + + 0.5 ? "white" : colors.textPrimary}>{row.labelPred.token} + + {/if} {/each} {/each} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index 23e2f996e..a0bb2c473 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -13,6 +13,14 @@ export type TokenPrediction = { target_logit: number; }; +export type LabelPredictions = { + position: number; + ci: TokenPrediction; + stochastic: TokenPrediction; + adversarial: TokenPrediction; + target_sans: TokenPrediction; +}; + export type InterventionResult = { input_tokens: string[]; ci: TokenPrediction[][]; @@ -23,6 +31,7 @@ export type InterventionResult = { stochastic_loss: number; adversarial_loss: number; target_sans_loss: number; + label: LabelPredictions | null; }; /** Persisted intervention run from the server */ From 03206b2fa4310d03c53b29af1de38e485f52553d Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 18:29:27 +0000 Subject: [PATCH 090/102] Compute alive masks from natural CI, not graph's node_ci_vals 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Alive masks (the PGD degrees of freedom and target-sans ablation set) should always be the model's full binarized CI — not the graph's potentially sparse optimized CI. compute_intervention now recomputes natural CI internally via one forward pass + calc_causal_importances. Removes build_graph_alive_masks (no longer needed) and the graph_alive_masks parameter from compute_intervention. Callers now pass sampling type instead. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/CLAUDE.md | 2 +- spd/app/backend/compute.py | 63 ++++++++++--------------- spd/app/backend/routers/graphs.py | 12 ++--- spd/app/backend/routers/intervention.py | 8 +--- spd/app/backend/routers/mcp.py | 8 +--- 5 files changed, 33 insertions(+), 60 deletions(-) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 720fc6dd2..48e8e041e 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -235,7 +235,7 @@ Returns `InterventionResult` with top-k `TokenPrediction`s per position for each This loss is used for two things: (1) what PGD maximizes during adversarial evaluation, and (2) the `ci_loss`/`stochastic_loss`/`adversarial_loss` metrics reported in `InterventionResult`. -**Alive masks**: `build_graph_alive_masks()` constructs boolean masks from the graph's `node_ci_vals` (CI > 0). These define the PGD degrees of freedom — PGD can only manipulate alive-but-unselected components. +**Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. **Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, `adv_pgd_step_size`) are a *training* regularizer — they make CI optimization robust. 
The PGD in `compute_intervention` is an *eval* metric — it measures worst-case leakage for a given node selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index e35fa3c88..da0cee23a 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -713,34 +713,6 @@ class InterventionResult(BaseModel): DEFAULT_EVAL_PGD_CONFIG = AdvPGDConfig(n_steps=4, step_size=1.0, init="random") -def build_graph_alive_masks( - node_ci_vals: dict[str, float], - model: ComponentModel, - topology: TransformerTopology, - seq_len: int, - device: str | torch.device, -) -> dict[str, Bool[Tensor, "1 seq C"]]: - """Build alive masks from graph's node_ci_vals (CI > 0), excluding embed/output.""" - alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = {} - for layer_name, C in model.module_to_c.items(): - alive_masks[layer_name] = torch.zeros(1, seq_len, C, device=device, dtype=torch.bool) - - for key, ci in node_ci_vals.items(): - if ci <= 0: - continue - canon_layer, seq_str, cidx_str = key.split(":") - if canon_layer in ("embed", "output"): - continue - concrete_path = topology.canon_to_target(canon_layer) - assert concrete_path in alive_masks, ( - f"node_ci_vals has layer {canon_layer!r} (concrete: {concrete_path!r}) " - f"not in model.module_to_c" - ) - alive_masks[concrete_path][0, int(seq_str), int(cidx_str)] = True - - return alive_masks - - def _extract_topk_predictions( logits: Float[Tensor, "1 seq vocab"], target_logits: Float[Tensor, "1 seq vocab"], @@ -794,32 +766,47 @@ def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], - graph_alive_masks: dict[str, Bool[Tensor, "1 seq C"]], tokenizer: AppTokenizer, adv_pgd_config: AdvPGDConfig, loss_config: LossConfig, + sampling: SamplingType, top_k: int, ) -> InterventionResult: - """Unified intervention evaluation: CI, stochastic, and adversarial masking. 
+ """Unified intervention evaluation: CI, stochastic, adversarial, and target-sans masking. + + Computes the model's natural CI to determine alive masks (CI > 0). PGD optimizes + alive-but-unselected components; non-alive get uniform random. Target-sans ablates + only the unselected alive nodes from the full target model. Args: active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. - graph_alive_masks: The graph's CI > 0 mask (all alive components). PGD optimizes - alive-but-unselected components; non-alive get uniform random. - loss_config: Loss for PGD adversary to maximize. + loss_config: Loss for PGD adversary to maximize and for reporting metrics. + sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. """ seq_len = tokens.shape[1] device = tokens.device + # Compute natural CI alive masks (the model's own binarized CI, independent of graph) + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = { + k: v > 0 for k, v in ci_outputs.lower_leaky.items() + } + # Build binary CI masks from active nodes (selected = 1, rest = 0) ci_masks: dict[str, Float[Tensor, "1 seq C"]] = {} for layer_name, C in model.module_to_c.items(): ci_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) for layer, seq_pos, c_idx in active_nodes: ci_masks[layer][0, seq_pos, c_idx] = 1.0 - assert graph_alive_masks[layer][0, seq_pos, c_idx], ( - f"Selected node {layer}:{seq_pos}:{c_idx} is not alive in the graph" + assert alive_masks[layer][0, seq_pos, c_idx], ( + f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) # Target-sans: full target model with alive-but-unselected components zeroed out. 
@@ -827,7 +814,7 @@ def compute_intervention( target_sans_masks: dict[str, Float[Tensor, "1 seq C"]] = {} for layer_name in ci_masks: mask = torch.ones_like(ci_masks[layer_name]) - alive_unselected = graph_alive_masks[layer_name] & (ci_masks[layer_name] == 0) + alive_unselected = alive_masks[layer_name] & (ci_masks[layer_name] == 0) mask[alive_unselected] = 0.0 target_sans_masks[layer_name] = mask weight_deltas = model.calc_weight_deltas() @@ -864,7 +851,7 @@ def compute_intervention( model=model, tokens=tokens, ci=ci_masks, - alive_masks=graph_alive_masks, + alive_masks=alive_masks, adv_config=adv_pgd_config, target_out=target_logits, loss_config=loss_config, @@ -873,7 +860,7 @@ def compute_intervention( adv_masks = interpolate_pgd_mask(ci_masks, adv_sources) with torch.no_grad(): for layer in adv_masks: - non_alive = ~graph_alive_masks[layer] + non_alive = ~alive_masks[layer] adv_masks[layer][non_alive] = torch.rand(int(non_alive.sum().item()), device=device) with torch.no_grad(), bf16_autocast(): adv_mask_infos = make_mask_infos(adv_masks, routing_masks="all") diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 85f374023..3da8f4c97 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -21,7 +21,6 @@ from spd.app.backend.compute import ( DEFAULT_EVAL_PGD_CONFIG, Edge, - build_graph_alive_masks, compute_intervention, compute_prompt_attributions, compute_prompt_attributions_optimized, @@ -46,7 +45,7 @@ ) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors -from spd.configs import ImportanceMinimalityLossConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType from spd.log import logger from spd.models.component_model import ComponentModel from spd.topology import TransformerTopology @@ -63,6 +62,7 @@ def _save_base_intervention_run( tokenizer: AppTokenizer, topology: TransformerTopology, db: PromptAttrDB, + sampling: 
SamplingType, loss_config: LossConfig | None = None, ) -> None: """Compute intervention for all interventable nodes and save as an intervention run.""" @@ -79,10 +79,6 @@ def _save_base_intervention_run( concrete_path = topology.canon_to_target(canon_layer) active_nodes.append((concrete_path, int(seq_str), int(cidx_str))) - device = str(tokens.device) - seq_len = tokens.shape[1] - graph_alive_masks = build_graph_alive_masks(node_ci_vals, model, topology, seq_len, device) - effective_loss_config: LossConfig = ( loss_config if loss_config is not None else MeanKLLossConfig() ) @@ -91,10 +87,10 @@ def _save_base_intervention_run( model=model, tokens=tokens, active_nodes=active_nodes, - graph_alive_masks=graph_alive_masks, tokenizer=tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=effective_loss_config, + sampling=sampling, top_k=10, ) @@ -569,6 +565,7 @@ def work( tokenizer=loaded.tokenizer, topology=loaded.topology, db=db, + sampling=loaded.config.sampling, ) logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s") @@ -783,6 +780,7 @@ def work( tokenizer=loaded.tokenizer, topology=loaded.topology, db=db, + sampling=loaded.config.sampling, loss_config=loss_config, ) diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 194de172b..2650a6701 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -6,7 +6,6 @@ from spd.app.backend.compute import ( InterventionResult, - build_graph_alive_masks, compute_intervention, ) from spd.app.backend.dependencies import DepDB, DepLoadedRun, DepStateManager @@ -99,11 +98,6 @@ def run_and_save_intervention( ) tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - # Build alive masks from graph's CI values - graph_alive_masks = build_graph_alive_masks( - graph.node_ci_vals, loaded.model, loaded.topology, len(token_ids), str(DEVICE) - ) - # Use graph's loss config if optimized, else mean KL 
loss_config: LossConfig = ( graph.optimization_params.loss @@ -115,7 +109,6 @@ def run_and_save_intervention( model=loaded.model, tokens=tokens, active_nodes=active_nodes, - graph_alive_masks=graph_alive_masks, tokenizer=loaded.tokenizer, adv_pgd_config=AdvPGDConfig( n_steps=request.adv_pgd.n_steps, @@ -123,6 +116,7 @@ def run_and_save_intervention( init="random", ), loss_config=loss_config, + sampling=loaded.config.sampling, top_k=request.top_k, ) diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py index 719e3572e..c1cc1fa54 100644 --- a/spd/app/backend/routers/mcp.py +++ b/spd/app/backend/routers/mcp.py @@ -897,21 +897,15 @@ def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") active_nodes.append((layer, int(seq_str), int(cidx_str))) - # Build trivial alive masks (all nodes alive — no graph context here) - alive_masks = { - layer_name: torch.ones(1, len(token_ids), C, device=DEVICE, dtype=torch.bool) - for layer_name, C in loaded.model.module_to_c.items() - } - with manager.gpu_lock(): result = compute_intervention( model=loaded.model, tokens=tokens, active_nodes=active_nodes, - graph_alive_masks=alive_masks, tokenizer=loaded.tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=MeanKLLossConfig(), + sampling=loaded.config.sampling, top_k=top_k, ) From d68a9e0ab86e7b4425068a6a34a585635ffec1ac Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 18:57:45 +0000 Subject: [PATCH 091/102] Target-sans: frontend sends explicit sans_nodes instead of backend inferring The frontend computes sans_nodes = allGraphNodes - selectedNodes and sends it in the request. This avoids the backend needing to know about graph-level vs natural-CI-level alive sets. Base intervention runs pass sans_nodes=[] (nothing to ablate). MCP ablation tool also passes sans_nodes=[]. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 17 ++++++++--------- spd/app/backend/routers/graphs.py | 1 + spd/app/backend/routers/intervention.py | 5 +++++ spd/app/backend/routers/mcp.py | 1 + .../src/components/PromptAttributionsTab.svelte | 4 ++++ spd/app/frontend/src/lib/interventionTypes.ts | 1 + tests/app/test_server_api.py | 5 ++++- 7 files changed, 24 insertions(+), 10 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index da0cee23a..02e6a0db9 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -766,6 +766,7 @@ def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], + sans_nodes: list[tuple[str, int, int]], tokenizer: AppTokenizer, adv_pgd_config: AdvPGDConfig, loss_config: LossConfig, @@ -774,12 +775,11 @@ def compute_intervention( ) -> InterventionResult: """Unified intervention evaluation: CI, stochastic, adversarial, and target-sans masking. - Computes the model's natural CI to determine alive masks (CI > 0). PGD optimizes - alive-but-unselected components; non-alive get uniform random. Target-sans ablates - only the unselected alive nodes from the full target model. - Args: active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. + Used for CI, stochastic, and adversarial masking. + sans_nodes: (concrete_path, seq_pos, component_idx) tuples for nodes to ablate + in target-sans. The frontend computes this as all_graph_nodes - selected_nodes. loss_config: Loss for PGD adversary to maximize and for reporting metrics. sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. @@ -809,14 +809,13 @@ def compute_intervention( f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - # Target-sans: full target model with alive-but-unselected components zeroed out. + # Target-sans: full target model with sans_nodes ablated. 
# Includes weight deltas so components + delta = exact target reconstruction. target_sans_masks: dict[str, Float[Tensor, "1 seq C"]] = {} for layer_name in ci_masks: - mask = torch.ones_like(ci_masks[layer_name]) - alive_unselected = alive_masks[layer_name] & (ci_masks[layer_name] == 0) - mask[alive_unselected] = 0.0 - target_sans_masks[layer_name] = mask + target_sans_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in sans_nodes: + target_sans_masks[layer][0, seq_pos, c_idx] = 0.0 weight_deltas = model.calc_weight_deltas() ts_weight_deltas_and_masks = { k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 3da8f4c97..9deadd648 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -87,6 +87,7 @@ def _save_base_intervention_run( model=model, tokens=tokens, active_nodes=active_nodes, + sans_nodes=[], tokenizer=tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=effective_loss_config, diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 2650a6701..7cf44d44f 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -29,6 +29,7 @@ class RunInterventionRequest(BaseModel): graph_id: int selected_nodes: list[str] # node keys (layer:seq:cIdx) + sans_nodes: list[str] # node keys to ablate in target-sans top_k: int adv_pgd: AdvPgdParams @@ -96,6 +97,9 @@ def run_and_save_intervention( active_nodes = _parse_and_validate_active_nodes( request.selected_nodes, loaded.topology, len(token_ids) ) + sans_nodes = _parse_and_validate_active_nodes( + request.sans_nodes, loaded.topology, len(token_ids) + ) tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) # Use graph's loss config if optimized, else mean KL @@ -109,6 +113,7 @@ def run_and_save_intervention( model=loaded.model, 
tokens=tokens, active_nodes=active_nodes, + sans_nodes=sans_nodes, tokenizer=loaded.tokenizer, adv_pgd_config=AdvPGDConfig( n_steps=request.adv_pgd.n_steps, diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py index c1cc1fa54..ffbcd7745 100644 --- a/spd/app/backend/routers/mcp.py +++ b/spd/app/backend/routers/mcp.py @@ -902,6 +902,7 @@ def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: model=loaded.model, tokens=tokens, active_nodes=active_nodes, + sans_nodes=[], tokenizer=loaded.tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=MeanKLLossConfig(), diff --git a/spd/app/frontend/src/components/PromptAttributionsTab.svelte b/spd/app/frontend/src/components/PromptAttributionsTab.svelte index 939510739..168cb1e19 100644 --- a/spd/app/frontend/src/components/PromptAttributionsTab.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsTab.svelte @@ -369,10 +369,14 @@ runningIntervention = true; try { const selectedNodes = Array.from(activeRun.selectedNodes); + const baseRun = state.runs[0]; + if (baseRun.kind !== "baked") throw new Error("First run must be baked base run"); + const sansNodes = Array.from(baseRun.selectedNodes).filter((n) => !activeRun.selectedNodes.has(n)); const run = await api.runAndSaveIntervention({ graph_id: activeGraph.id, selected_nodes: selectedNodes, + sans_nodes: sansNodes, top_k: 10, adv_pgd: advPgd, }); diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index a0bb2c473..917bc6909 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -46,6 +46,7 @@ export type InterventionRunSummary = { export type RunInterventionRequest = { graph_id: number; selected_nodes: string[]; + sans_nodes: string[]; top_k: number; adv_pgd: { n_steps: number; step_size: number }; }; diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index dc9818c64..058b61cb1 100644 --- 
a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -255,9 +255,12 @@ def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClien ] assert len(selected_nodes) > 0 + subset = selected_nodes[: min(5, len(selected_nodes))] + sans = [n for n in selected_nodes if n not in subset] request = { "graph_id": graph_id, - "selected_nodes": selected_nodes[: min(5, len(selected_nodes))], + "selected_nodes": subset, + "sans_nodes": sans, "top_k": 5, "adv_pgd": {"n_steps": 1, "step_size": 1.0}, } From cc6b4e76dec32128b429324c68b2ed5d2d3be844 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 3 Mar 2026 19:03:07 +0000 Subject: [PATCH 092/102] =?UTF-8?q?Make=20sans=5Fnodes=20optional=20?= =?UTF-8?q?=E2=80=94=20skip=20target-sans=20when=20None?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sans_nodes: list | None (no default). Callers pass None explicitly when target-sans isn't needed (base intervention runs, MCP). Frontend sends sans_nodes only when the user has deselected nodes. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/backend/compute.py | 69 ++++++++++--------- spd/app/backend/routers/graphs.py | 2 +- spd/app/backend/routers/intervention.py | 8 ++- spd/app/backend/routers/mcp.py | 2 +- .../components/PromptAttributionsTab.svelte | 2 +- .../prompt-attr/InterventionsView.svelte | 12 ++-- spd/app/frontend/src/lib/interventionTypes.ts | 8 +-- tests/app/test_server_api.py | 9 +-- 8 files changed, 60 insertions(+), 52 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 02e6a0db9..da015c885 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -691,7 +691,7 @@ class LabelPredictions(BaseModel): ci: TokenPrediction stochastic: TokenPrediction adversarial: TokenPrediction - target_sans: TokenPrediction + target_sans: TokenPrediction | None class InterventionResult(BaseModel): @@ -701,11 +701,11 @@ class InterventionResult(BaseModel): ci: list[list[TokenPrediction]] stochastic: list[list[TokenPrediction]] adversarial: list[list[TokenPrediction]] - target_sans: list[list[TokenPrediction]] + target_sans: list[list[TokenPrediction]] | None ci_loss: float stochastic_loss: float adversarial_loss: float - target_sans_loss: float + target_sans_loss: float | None label: LabelPredictions | None @@ -766,20 +766,21 @@ def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], - sans_nodes: list[tuple[str, int, int]], + sans_nodes: list[tuple[str, int, int]] | None, tokenizer: AppTokenizer, adv_pgd_config: AdvPGDConfig, loss_config: LossConfig, sampling: SamplingType, top_k: int, ) -> InterventionResult: - """Unified intervention evaluation: CI, stochastic, adversarial, and target-sans masking. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally target-sans. Args: active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. Used for CI, stochastic, and adversarial masking. 
- sans_nodes: (concrete_path, seq_pos, component_idx) tuples for nodes to ablate - in target-sans. The frontend computes this as all_graph_nodes - selected_nodes. + sans_nodes: If provided, nodes to ablate in target-sans (full target model minus these). + The frontend computes this as all_graph_nodes - selected_nodes. + If None, target-sans is skipped. loss_config: Loss for PGD adversary to maximize and for reporting metrics. sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. @@ -809,18 +810,6 @@ def compute_intervention( f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - # Target-sans: full target model with sans_nodes ablated. - # Includes weight deltas so components + delta = exact target reconstruction. - target_sans_masks: dict[str, Float[Tensor, "1 seq C"]] = {} - for layer_name in ci_masks: - target_sans_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) - for layer, seq_pos, c_idx in sans_nodes: - target_sans_masks[layer][0, seq_pos, c_idx] = 0.0 - weight_deltas = model.calc_weight_deltas() - ts_weight_deltas_and_masks = { - k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() - } - with torch.no_grad(), bf16_autocast(): # Target forward (unmasked) target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) @@ -837,13 +826,22 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - # Target-sans forward: target model with unselected alive nodes ablated - ts_mask_infos = make_mask_infos( - target_sans_masks, - routing_masks="all", - weight_deltas_and_masks=ts_weight_deltas_and_masks, - ) - ts_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ts_mask_infos) + # Target-sans forward (only if sans_nodes provided) + ts_logits: Float[Tensor, "1 seq vocab"] | None = None + if sans_nodes is not None: + ts_masks: dict[str, 
Float[Tensor, "1 seq C"]] = {} + for layer_name in ci_masks: + ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in sans_nodes: + ts_masks[layer][0, seq_pos, c_idx] = 0.0 + weight_deltas = model.calc_weight_deltas() + ts_wd = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } + ts_mask_infos = make_mask_infos( + ts_masks, routing_masks="all", weight_deltas_and_masks=ts_wd + ) + ts_logits = model(tokens, mask_infos=ts_mask_infos) # Adversarial: PGD optimizes alive-but-unselected components adv_sources = run_adv_pgd( @@ -871,7 +869,6 @@ def compute_intervention( ci_preds = _extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) - ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) ci_loss = float( compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() @@ -882,19 +879,29 @@ def compute_intervention( adv_loss = float( compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() ) - ts_loss = float( - compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() - ) + + ts_preds: list[list[TokenPrediction]] | None = None + ts_loss: float | None = None + if ts_logits is not None: + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() + ) label: LabelPredictions | None = None if isinstance(loss_config, CELossConfig): pos, tid = loss_config.position, loss_config.label_token + ts_label = ( + _extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid) + if ts_logits is not None + else None + ) label = LabelPredictions( position=pos, ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, 
pos, tid), stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), - target_sans=_extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid), + target_sans=ts_label, ) input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 9deadd648..91ab27740 100644 --- a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -87,7 +87,7 @@ def _save_base_intervention_run( model=model, tokens=tokens, active_nodes=active_nodes, - sans_nodes=[], + sans_nodes=None, tokenizer=tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=effective_loss_config, diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 7cf44d44f..80f28ab1d 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -29,7 +29,7 @@ class RunInterventionRequest(BaseModel): graph_id: int selected_nodes: list[str] # node keys (layer:seq:cIdx) - sans_nodes: list[str] # node keys to ablate in target-sans + sans_nodes: list[str] | None = None # node keys to ablate in target-sans (omit to skip) top_k: int adv_pgd: AdvPgdParams @@ -97,8 +97,10 @@ def run_and_save_intervention( active_nodes = _parse_and_validate_active_nodes( request.selected_nodes, loaded.topology, len(token_ids) ) - sans_nodes = _parse_and_validate_active_nodes( - request.sans_nodes, loaded.topology, len(token_ids) + sans_nodes = ( + _parse_and_validate_active_nodes(request.sans_nodes, loaded.topology, len(token_ids)) + if request.sans_nodes is not None + else None ) tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py index ffbcd7745..043201378 100644 --- a/spd/app/backend/routers/mcp.py +++ 
b/spd/app/backend/routers/mcp.py @@ -902,7 +902,7 @@ def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: model=loaded.model, tokens=tokens, active_nodes=active_nodes, - sans_nodes=[], + sans_nodes=None, tokenizer=loaded.tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=MeanKLLossConfig(), diff --git a/spd/app/frontend/src/components/PromptAttributionsTab.svelte b/spd/app/frontend/src/components/PromptAttributionsTab.svelte index 168cb1e19..0ba2357e6 100644 --- a/spd/app/frontend/src/components/PromptAttributionsTab.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsTab.svelte @@ -376,7 +376,7 @@ const run = await api.runAndSaveIntervention({ graph_id: activeGraph.id, selected_nodes: selectedNodes, - sans_nodes: sansNodes, + sans_nodes: sansNodes.length > 0 ? sansNodes : undefined, top_k: 10, adv_pgd: advPgd, }); diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index f832661f8..2ab189191 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -140,7 +140,7 @@ if (interventionResult.adversarial.length > 0) rows.push({ label: "Adv", preds: interventionResult.adversarial, labelPred: lbl?.adversarial ?? null }); if (interventionResult.stochastic.length > 0) rows.push({ label: "Stoch", preds: interventionResult.stochastic, labelPred: lbl?.stochastic ?? null }); rows.push({ label: "CI", preds: interventionResult.ci, labelPred: lbl?.ci ?? null }); - if (interventionResult.target_sans.length > 0) rows.push({ label: "T\\S", preds: interventionResult.target_sans, labelPred: lbl?.target_sans ?? null }); + if (interventionResult.target_sans && interventionResult.target_sans.length > 0) rows.push({ label: "T\\S", preds: interventionResult.target_sans, labelPred: lbl?.target_sans ?? 
null }); return rows; }); @@ -1054,10 +1054,12 @@ adv {run.result.adversarial_loss.toFixed(3)}
-
- T\S - {run.result.target_sans_loss.toFixed(3)} -
+ {#if run.result.target_sans_loss != null} +
+ T\S + {run.result.target_sans_loss.toFixed(3)} +
+ {/if}
metric {lossLabel} diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index 917bc6909..262aa3dfd 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -18,7 +18,7 @@ export type LabelPredictions = { ci: TokenPrediction; stochastic: TokenPrediction; adversarial: TokenPrediction; - target_sans: TokenPrediction; + target_sans: TokenPrediction | null; }; export type InterventionResult = { @@ -26,11 +26,11 @@ export type InterventionResult = { ci: TokenPrediction[][]; stochastic: TokenPrediction[][]; adversarial: TokenPrediction[][]; - target_sans: TokenPrediction[][]; + target_sans: TokenPrediction[][] | null; ci_loss: number; stochastic_loss: number; adversarial_loss: number; - target_sans_loss: number; + target_sans_loss: number | null; label: LabelPredictions | null; }; @@ -46,7 +46,7 @@ export type InterventionRunSummary = { export type RunInterventionRequest = { graph_id: number; selected_nodes: string[]; - sans_nodes: string[]; + sans_nodes?: string[]; top_k: number; adv_pgd: { n_steps: number; step_size: number }; }; diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 058b61cb1..d2cd74057 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -255,12 +255,9 @@ def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClien ] assert len(selected_nodes) > 0 - subset = selected_nodes[: min(5, len(selected_nodes))] - sans = [n for n in selected_nodes if n not in subset] request = { "graph_id": graph_id, - "selected_nodes": subset, - "sans_nodes": sans, + "selected_nodes": selected_nodes[: min(5, len(selected_nodes))], "top_k": 5, "adv_pgd": {"n_steps": 1, "step_size": 1.0}, } @@ -273,11 +270,11 @@ def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClien assert len(result["ci"]) > 0 assert len(result["stochastic"]) > 0 assert len(result["adversarial"]) > 0 - 
assert len(result["target_sans"]) > 0 + assert result["target_sans"] is None assert "ci_loss" in result assert "stochastic_loss" in result assert "adversarial_loss" in result - assert "target_sans_loss" in result + assert result["target_sans_loss"] is None # ----------------------------------------------------------------------------- From 8b09ded1aa618a840e1714037fc3d2a793b36b25 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Wed, 4 Mar 2026 14:23:45 +0000 Subject: [PATCH 093/102] =?UTF-8?q?Graph=20spotlight=20mode,=20node=20inte?= =?UTF-8?q?raction=20state=20machine,=20rename=20sans=E2=86=92ablated?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Spotlight mode: when "hide unpinned edges" is on and hovering a node (no pinned nodes), hide all edges and non-connected nodes; show only graph neighbors colored by edge polarity with strength-based opacity and grey outline - Refactor node rendering from parallel booleans to a global InteractionMode state machine (spotlight | focusing | resting) with O(1) per-node role lookup via getNodeRole() - Rename sans_nodes → nodes_to_ablate, target_sans → ablated across full stack (backend, frontend, DB JSON migration) - Fix ablated tooltip (was describing necessity, actually measures sufficiency) - Delete orphaned graph with missing base intervention run Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/CLAUDE.md | 29 +- spd/app/TODO.md | 172 ++++++++++- spd/app/backend/compute.py | 26 +- spd/app/backend/database.py | 283 ++++++++++-------- spd/app/backend/routers/graphs.py | 2 +- spd/app/backend/routers/intervention.py | 12 +- spd/app/backend/routers/mcp.py | 2 +- .../components/PromptAttributionsGraph.svelte | 169 ++++++++--- .../components/PromptAttributionsTab.svelte | 4 +- .../prompt-attr/InterventionsView.svelte | 48 ++- .../components/prompt-attr/NodeTooltip.svelte | 42 ++- .../prompt-attr/OptimizationParams.svelte | 30 +- .../prompt-attr/OutputNodeCard.svelte | 
39 ++- spd/app/frontend/src/lib/interventionTypes.ts | 8 +- spd/app/frontend/vite.config.ts | 2 +- 15 files changed, 625 insertions(+), 243 deletions(-) diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 48e8e041e..795e55623 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -4,15 +4,15 @@ Web-based visualization and analysis tool for exploring neural network component - **Backend**: Python FastAPI (`backend/`) - **Frontend**: Svelte 5 + TypeScript (`frontend/`) -- **Database**: SQLite at `.data/app/prompt_attr.db` (relative to repo root) +- **Database**: SQLite at `SPD_OUT_DIR/app/prompt_attr.db` (shared across team via NFS) - **TODOs**: See `TODO.md` for open work items ## Project Context This is a **rapidly iterated research tool**. Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations, expect state can be deleted, etc. -- **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things +- **Please do not code for backwards compatibility**: Schema changes don't need migrations +- **Database is shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, accessible by multiple backends. Do not delete without checking with the team. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert - **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. 
@@ -53,7 +53,7 @@ backend/ ├── clusters.py # Component clustering ├── dataset_search.py # SimpleStories dataset search ├── agents.py # Various useful endpoints that AI agents should look at when helping - ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code + ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ├── dataset_search.py # Dataset search (reads dataset from run config) └── agents.py # Various useful endpoints that AI agents should look at when helping ``` @@ -223,6 +223,7 @@ Finds sparse CI mask that: ### Interventions (`compute.py → compute_intervention`) A single unified function evaluates a node selection under three masking regimes: + - **CI**: mask = selection (binary on/off) - **Stochastic**: mask = selection + (1-selection) × Uniform(0,1) - **Adversarial**: PGD optimizes alive-but-unselected components to maximize loss; non-alive get Uniform(0,1) @@ -230,6 +231,7 @@ A single unified function evaluates a node selection under three masking regimes Returns `InterventionResult` with top-k `TokenPrediction`s per position for each regime, plus per-regime loss values. **Loss context**: Every graph has an implied loss that interventions evaluate against: + - **Standard/manual graphs** → `MeanKLLossConfig` (mean KL divergence from target across all positions) - **Optimized graphs** → the graph's optimization loss (CE for a specific token at a position, or KL at a position) @@ -237,7 +239,10 @@ This loss is used for two things: (1) what PGD maximizes during adversarial eval **Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. 
-**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, `adv_pgd_step_size`) are a *training* regularizer — they make CI optimization robust. The PGD in `compute_intervention` is an *eval* metric — it measures worst-case leakage for a given node selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). +**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, +`adv_pgd_step_size`) are a _training_ regularizer — they make CI optimization robust. The PGD in +`compute_intervention` is an _eval_ metric — it measures worst-case performance for a given node +selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). **Base intervention run**: Created automatically during graph computation. Uses all interventable nodes with CI > 0. Persisted as an `intervention_run` so predictions are available synchronously. @@ -308,14 +313,14 @@ GET /api/dataset/results?page=1&page_size=20 ## Database Schema -Located at `.data/app/prompt_attr.db`. Delete this file if schema changes cause issues. +Located at `SPD_OUT_DIR/app/prompt_attr.db` (shared via NFS). Uses DELETE journal mode with `fcntl.flock` write locking for safe concurrent access from multiple backends. 
-| Table | Key | Purpose | -| ------------------ | ---------------------------------- | ------------------------------------------------- | -| `runs` | `wandb_path` | W&B run references | -| `prompts` | `(run_id, context_length)` | Token sequences | -| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | -| `intervention_runs`| `graph_id` | Saved `InterventionResult` JSON (single `result` column) | +| Table | Key | Purpose | +| ------------------- | ---------------------------------- | -------------------------------------------------------- | +| `runs` | `wandb_path` | W&B run references | +| `prompts` | `(run_id, context_length)` | Token sequences | +| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | +| `intervention_runs` | `graph_id` | Saved `InterventionResult` JSON (single `result` column) | Note: Activation contexts, correlations, token stats, and interpretations are loaded from pre-harvested data at `SPD_OUT_DIR/{harvest,autointerp}/` (see `spd/harvest/` and `spd/autointerp/`). diff --git a/spd/app/TODO.md b/spd/app/TODO.md index f7658851f..21a8ba1fd 100644 --- a/spd/app/TODO.md +++ b/spd/app/TODO.md @@ -1,2 +1,172 @@ -# App TODOs +# App Backend Review & Action Items +Review date: 2026-03-04. Scope: `spd/app/backend/` — all Python files. + +Context: the app is a **researcher-first local tool** (frontend + backend launched together, opened in browser). Errors should be loud, silent failures absent, the prompt DB is deletable short-term state, no backwards compatibility needed. + +## Overview + +The backend is ~8,000 lines across 18 Python files (per the inventory table below). The core architecture (FastAPI + SQLite + singleton state + SSE streaming) is sound. The main concerns are: a few real bugs, accumulated dead code, some silent failures that violate the "loud errors" principle, and a few design seams where complexity hides.
+ +### File size inventory + +| File | Lines | Risk | +|---|---|---| +| `routers/mcp.py` | 1637 | High — mixed concerns, largest file | +| `routers/graphs.py` | 1036 | Medium — streaming complexity | +| `compute.py` | 920 | Low — core algorithm, well-structured | +| `database.py` | 827 | Medium — manual serialization | +| `optim_cis.py` | 504 | Low | +| `routers/dataset_search.py` | 473 | Medium — hardcoded dataset names | +| `routers/correlations.py` | 386 | Low | +| `routers/graph_interp.py` | 373 | Low | +| `routers/investigations.py` | 317 | Low | +| `routers/pretrain_info.py` | 246 | Low | +| `server.py` | 212 | Low — clean | +| `routers/activation_contexts.py` | 207 | Low | +| `routers/runs.py` | 191 | Low | +| `routers/dataset_attributions.py` | 170 | Low | +| `routers/intervention.py` | 169 | Low — clean | +| `state.py` | 132 | Low — clean | +| `app_tokenizer.py` | 119 | Low | +| `routers/prompts.py` | 115 | Low | + +--- + +## Bugs + +### 1. `dataset_search.py:262` — KeyError on tokenized results + +`get_tokenized_results` accesses `result["story"]` but `search_dataset` stores results with key `"text"` (line 137: `results.append({"text": text, ...})`). This will crash with `KeyError: 'story'` whenever tokenized results are requested. + +**Fix:** Change line 262 from `result["story"]` to `result["text"]`. Also line 287: the metadata exclusion list references `"story"` — should be `"text"`. + +### 2. `dataset_search.py` — random endpoints hardcode SimpleStories + +`get_random_samples` (line 336) and `get_random_samples_with_loss` (line 415) both hardcode `load_dataset("lennart-finke/SimpleStories", ...)` and access `item_dict["story"]`. Since primary models are now Pile-trained, these endpoints are broken for current research. They also don't use `DepLoadedRun` to get the dataset name from the run config like `search_dataset` does. 
+ +**Fix:** Make them take `DepLoadedRun`, read `task_config.dataset_name` and `task_config.column_name`, and use those instead of hardcoded values. Or, if the random endpoints aren't used with Pile models, consider deleting them. + +--- + +## Dead code to delete + +### 3. `ForkedInterventionRunRecord` + `forked_intervention_runs` table + +`database.py:117-125` defines `ForkedInterventionRunRecord`. Lines 256-265 create the `forked_intervention_runs` table. Lines 744-827 implement 3 CRUD methods (`save_forked_intervention_run`, `get_forked_intervention_runs`, `delete_forked_intervention_run`). No router references any of these — the fork endpoints were removed. Delete all of it. + +Files: `database.py` + +### 4. `optim_cis.py:500-504` — `get_out_dir()` never called + +Dead utility function that creates a local `out/` directory. Nothing references it. + +Files: `optim_cis.py` + +### 5. Unused schemas in `graphs.py:188-209` + +`ComponentStats`, `PromptSearchQuery`, and `PromptSearchResponse` are defined but no endpoint uses them. They appear to be leftovers from a removed prompt search feature. The `PromptPreview` in `graphs.py:114` also duplicates the one in `prompts.py:25`. + +Files: `routers/graphs.py` + +### 6. `spd/app/TODO.md` was empty + +(This file — now repurposed for this review.) + +--- + +## Design issues + +### 7. `OptimizationParams` mixes config inputs with computed outputs + +`database.py:69-82` — Fields like `imp_min_coeff`, `steps`, `pnorm`, `beta` are optimization *inputs*. Fields like `ci_masked_label_prob`, `stoch_masked_label_prob`, `adv_pgd_label_prob` are computed *outputs*. These metrics are mutated in-place after construction in `graphs.py:759-761`. + +This makes the object's contract unclear — is it immutable config or mutable state? + +**Suggestion:** Either nest the metrics in a sub-model (`metrics: OptimMetrics | None`), or at minimum stop mutating after construction (compute the metrics before constructing `OptimizationParams`). 
+ +### 8. `StoredGraph.id = -1` sentinel value + +`database.py:90` uses `-1` as "unsaved graph". If a graph is accidentally used before being saved, that `-1` leaks into API responses or DB queries. `id: int | None = None` is more honest and lets the type system catch misuse. + +### 9. GPU lock accessed two different ways + +- `graphs.py:603,844` — `stream_computation(work, manager._gpu_lock)` reaches into the private lock directly +- `intervention.py:86` — `with manager.gpu_lock():` uses the context manager + +The stream pattern is inherently different (hold lock across SSE generator lifetime), but accessing `_gpu_lock` directly breaks encapsulation. + +**Suggestion:** Add a `stream_with_gpu_lock(work)` method on `StateManager` that encapsulates the lock acquisition + SSE streaming pattern. Then `graphs.py` calls `manager.stream_with_gpu_lock(work)` instead of reaching into privates. + +### 10. `load_run` returns untyped dicts + +`runs.py:96,139` returns `{"status": "loaded", "run_id": ...}` and `{"status": "already_loaded", ...}`. No response model, so the frontend has no type-safe contract for this endpoint. + +**Fix:** Define a `LoadRunResponse(BaseModel)` with `status`, `run_id`, `wandb_path`. + +### 11. Edge truncation is invisible to the user + +`graphs.py:903` logs a warning when edges exceed `GLOBAL_EDGE_LIMIT = 50_000` and are truncated, but this info only goes to server logs. The researcher never sees it. + +**Fix:** Add `edges_truncated: bool` (or `total_edge_count: int`) to `GraphData` so the frontend can show a notice. + +### 12. Module-level `DEVICE = get_device()` in multiple files + +`graphs.py:266`, `intervention.py:48`, `dataset_search.py`, `prompts.py:18` all call `get_device()` at import time. Fine in practice but makes testing and non-GPU imports impossible. + +**Suggestion:** Move to a function call or lazily-evaluated property when/if this becomes a testing bottleneck. Low priority. + +### 13. 
`_GRAPH_INTERP_MOCK_MODE` cross-router import + +`runs.py:13` imports `MOCK_MODE` from `routers/graph_interp.py` and uses it in the status endpoint (line 174). The TODO comment says to remove it. This cross-router dependency for a mock flag should be cleaned up — the mock mode should either be a config flag on `StateManager` or deleted entirely. + +--- + +## Silent failure patterns (violate "loud errors" principle) + +### 14. `compute.py:79-86` — output node capping is silent + +`compute_layer_alive_info` caps output nodes to `MAX_OUTPUT_NODES_PER_POS = 15` per position without any logging or indication. If a researcher has >15 high-probability output tokens at a position, they silently lose some. + +At minimum, log when capping occurs. + +### 15. `correlations.py:291,302` — token stats returns `None` silently + +`get_component_token_stats` returns `None` when token stats haven't been harvested. This means the endpoint returns a `200 null` response, which the frontend has to special-case. An explicit 404 with a message is more honest. + +### 16. `correlations.py:112,260` — interpretations/intruder scores return `{}` silently + +`get_all_interpretations` and `get_intruder_scores` return empty dicts when data isn't available. This is defensible for bulk endpoints (the frontend can check emptiness), but it means the researcher has no way to distinguish "no interpretations exist" from "interpretations not yet generated." Consider logging or adding a `has_interpretations` flag to `LoadedRun`. + +Note: `LoadedRun.autointerp_available` already partially addresses this. But the endpoints themselves don't use it — they independently check `loaded.interp is None`. + +--- + +## Lower priority / nice-to-haves + +### 17. `extract_node_ci_vals` Python double loop + +`compute.py:640-648` iterates every `(seq_pos, component_idx)` pair in Python. For large models (39K components × 512 seq), this is a lot of Python overhead. Could be vectorized to only extract non-zero entries. 
+ +### 18. `database.py` manual graph get-or-create race + +Lines 528-539: catches `IntegrityError` on manual graph save, then re-queries. There's a small race window between the failed insert and the re-query. Acceptable for a single-user local app but worth noting. + +### 19. `mcp.py` is 1637 lines + +The MCP router is the largest file, mixing tool definitions, implementation logic, and JSON-RPC handling. It has module-level global state (`_investigation_config`). This file would benefit from being split, but it's also likely to be rewritten when MCP tooling matures, so the ROI of refactoring now is debatable. + +--- + +## Suggested priority order for implementation + +1. Fix `result["story"]` KeyError (bug #1) — 2 min +2. Delete dead code (items #3-5) — 10 min +3. Fix random dataset endpoints or delete if unused (#2) — 15 min +4. Add `edges_truncated` to GraphData (#11) — 10 min +5. Type the `load_run` response (#10) — 5 min +6. Clean up `_GRAPH_INTERP_MOCK_MODE` (#13) — 5 min +7. Deduplicate `MAX_OUTPUT_NODES_PER_POS` (#5 partial) — 2 min +8. `StoredGraph.id` sentinel → `None` (#8) — 10 min +9. Split `OptimizationParams` (#7) — 20 min +10. 
GPU lock encapsulation (#9) — 15 min diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index da015c885..55e31d37b 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -691,7 +691,7 @@ class LabelPredictions(BaseModel): ci: TokenPrediction stochastic: TokenPrediction adversarial: TokenPrediction - target_sans: TokenPrediction | None + ablated: TokenPrediction | None class InterventionResult(BaseModel): @@ -701,11 +701,11 @@ class InterventionResult(BaseModel): ci: list[list[TokenPrediction]] stochastic: list[list[TokenPrediction]] adversarial: list[list[TokenPrediction]] - target_sans: list[list[TokenPrediction]] | None + ablated: list[list[TokenPrediction]] | None ci_loss: float stochastic_loss: float adversarial_loss: float - target_sans_loss: float | None + ablated_loss: float | None label: LabelPredictions | None @@ -766,21 +766,21 @@ def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], - sans_nodes: list[tuple[str, int, int]] | None, + nodes_to_ablate: list[tuple[str, int, int]] | None, tokenizer: AppTokenizer, adv_pgd_config: AdvPGDConfig, loss_config: LossConfig, sampling: SamplingType, top_k: int, ) -> InterventionResult: - """Unified intervention evaluation: CI, stochastic, adversarial, and optionally target-sans. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally ablated. Args: active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. Used for CI, stochastic, and adversarial masking. - sans_nodes: If provided, nodes to ablate in target-sans (full target model minus these). + nodes_to_ablate: If provided, nodes to ablate in ablated (full target model minus these). The frontend computes this as all_graph_nodes - selected_nodes. - If None, target-sans is skipped. + If None, ablated is skipped. loss_config: Loss for PGD adversary to maximize and for reporting metrics. 
sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. @@ -826,13 +826,13 @@ def compute_intervention( stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) - # Target-sans forward (only if sans_nodes provided) + # Target-sans forward (only if nodes_to_ablate provided) ts_logits: Float[Tensor, "1 seq vocab"] | None = None - if sans_nodes is not None: + if nodes_to_ablate is not None: ts_masks: dict[str, Float[Tensor, "1 seq C"]] = {} for layer_name in ci_masks: ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) - for layer, seq_pos, c_idx in sans_nodes: + for layer, seq_pos, c_idx in nodes_to_ablate: ts_masks[layer][0, seq_pos, c_idx] = 0.0 weight_deltas = model.calc_weight_deltas() ts_wd = { @@ -901,7 +901,7 @@ def compute_intervention( ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), - target_sans=ts_label, + ablated=ts_label, ) input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) @@ -911,10 +911,10 @@ def compute_intervention( ci=ci_preds, stochastic=stoch_preds, adversarial=adv_preds, - target_sans=ts_preds, + ablated=ts_preds, ci_loss=ci_loss, stochastic_loss=stoch_loss, adversarial_loss=adv_loss, - target_sans_loss=ts_loss, + ablated_loss=ts_loss, label=label, ) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index b6d4f3ef6..dba51a934 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -6,11 +6,13 @@ Interpretations are stored separately at SPD_OUT_DIR/autointerp//. 
""" +import fcntl import hashlib import io import json import os import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Literal @@ -19,13 +21,11 @@ from spd.app.backend.compute import Edge, Node from spd.app.backend.optim_cis import CELossConfig, KLLossConfig, MaskType, PositionalLossConfig -from spd.settings import REPO_ROOT +from spd.settings import SPD_OUT_DIR GraphType = Literal["standard", "optimized", "manual"] -# Persistent data directories -_APP_DATA_DIR = REPO_ROOT / ".data" / "app" -_DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = SPD_OUT_DIR / "app" / "prompt_attr.db" def get_default_db_path() -> Path: @@ -34,7 +34,7 @@ def get_default_db_path() -> Path: Checks env vars in order: 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db 2. SPD_APP_DB_PATH - explicit override - 3. Default: .data/app/prompt_attr.db + 3. Default: SPD_OUT_DIR/app/prompt_attr.db """ investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") if investigation_dir: @@ -137,6 +137,7 @@ class PromptAttrDB: def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): self.db_path = db_path or get_default_db_path() + self._lock_path = self.db_path.with_suffix(".db.lock") self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None @@ -160,6 +161,16 @@ def __enter__(self) -> "PromptAttrDB": def __exit__(self, *args: object) -> None: self.close() + @contextmanager + def _write_lock(self): + """Acquire an exclusive file lock for write operations (NFS-safe).""" + with open(self._lock_path, "w") as lock_fd: + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + # ------------------------------------------------------------------------- # Schema initialization # ------------------------------------------------------------------------- @@ -167,7 +178,7 @@ def __exit__(self, *args: object) -> None: def init_schema(self) -> 
None: """Initialize the database schema. Safe to call multiple times.""" conn = self._get_conn() - conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA journal_mode=DELETE") conn.execute("PRAGMA foreign_keys=ON") conn.executescript(""" CREATE TABLE IF NOT EXISTS runs ( @@ -272,15 +283,16 @@ def init_schema(self) -> None: def create_run(self, wandb_path: str) -> int: """Create a new run. Returns the run ID.""" - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO runs (wandb_path) VALUES (?)", - (wandb_path,), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO runs (wandb_path) VALUES (?)", + (wandb_path,), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_run_by_wandb_path(self, wandb_path: str) -> Run | None: """Get a run by its wandb path.""" @@ -338,19 +350,20 @@ def add_custom_prompt( Returns: The prompt ID (existing or newly created). 
""" - existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) - if existing_id is not None: - return existing_id + with self._write_lock(): + existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) + if existing_id is not None: + return existing_id - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None - conn.commit() - return prompt_id + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", + (run_id, json.dumps(token_ids), context_length), + ) + prompt_id = cursor.lastrowid + assert prompt_id is not None + conn.commit() + return prompt_id def get_prompt(self, prompt_id: int) -> PromptRecord | None: """Get a prompt by ID.""" @@ -475,68 +488,69 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: included_nodes_json = json.dumps(sorted(graph.included_nodes)) included_nodes_hash = hashlib.sha256(included_nodes_json.encode()).hexdigest() - try: - cursor = conn.execute( - """INSERT INTO graphs - (prompt_id, graph_type, - imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, loss_config_hash, - adv_pgd_n_steps, adv_pgd_step_size, - ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, - included_nodes, included_nodes_hash, - edges_data, output_logits, node_ci_vals, node_subcomp_acts) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - graph.graph_type, - imp_min_coeff, - steps, - pnorm, - beta, - mask_type, - loss_config_json, - loss_config_hash, - adv_pgd_n_steps, - adv_pgd_step_size, - ci_masked_label_prob, - stoch_masked_label_prob, - adv_pgd_label_prob, - included_nodes_json, - included_nodes_hash, - edges_json, - output_logits_blob, - node_ci_vals_json, - 
node_subcomp_acts_json, - ), - ) - conn.commit() - graph_id = cursor.lastrowid - assert graph_id is not None - return graph_id - except sqlite3.IntegrityError as e: - match graph.graph_type: - case "standard": - raise ValueError( - f"Standard graph already exists for prompt_id={prompt_id}. " - "Use get_graphs() to retrieve existing graph or delete it first." - ) from e - case "optimized": - raise ValueError( - f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." - ) from e - case "manual": - # Get-or-create semantics: return existing graph ID - conn.rollback() - row = conn.execute( - """SELECT id FROM graphs - WHERE prompt_id = ? AND graph_type = 'manual' - AND included_nodes_hash = ?""", - (prompt_id, included_nodes_hash), - ).fetchone() - if row: - return row["id"] - # Should not happen if constraint triggered - raise ValueError("A manual graph with the same nodes already exists.") from e + with self._write_lock(): + try: + cursor = conn.execute( + """INSERT INTO graphs + (prompt_id, graph_type, + imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, loss_config_hash, + adv_pgd_n_steps, adv_pgd_step_size, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, + included_nodes, included_nodes_hash, + edges_data, output_logits, node_ci_vals, node_subcomp_acts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + prompt_id, + graph.graph_type, + imp_min_coeff, + steps, + pnorm, + beta, + mask_type, + loss_config_json, + loss_config_hash, + adv_pgd_n_steps, + adv_pgd_step_size, + ci_masked_label_prob, + stoch_masked_label_prob, + adv_pgd_label_prob, + included_nodes_json, + included_nodes_hash, + edges_json, + output_logits_blob, + node_ci_vals_json, + node_subcomp_acts_json, + ), + ) + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + match graph.graph_type: + case "standard": + raise ValueError( + f"Standard 
graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." + ) from e + case "optimized": + raise ValueError( + f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." + ) from e + case "manual": + conn.rollback() + row = conn.execute( + """SELECT id FROM graphs + WHERE prompt_id = ? AND graph_type = 'manual' + AND included_nodes_hash = ?""", + (prompt_id, included_nodes_hash), + ).fetchone() + if row: + return row["id"] + raise ValueError( + "A manual graph with the same nodes already exists." + ) from e def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: """Convert a database row to a StoredGraph.""" @@ -648,21 +662,23 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: def delete_graphs_for_prompt(self, prompt_id: int) -> int: """Delete all graphs for a prompt. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.commit() + return cursor.rowcount def delete_graphs_for_run(self, run_id: int) -> int: """Delete all graphs for all prompts in a run. 
Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """DELETE FROM graphs + WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", + (run_id,), + ) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Intervention run operations @@ -684,16 +700,17 @@ def save_intervention_run( Returns: The intervention run ID. """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO intervention_runs (graph_id, selected_nodes, result) - VALUES (?, ?, ?)""", - (graph_id, json.dumps(selected_nodes), result_json), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) + VALUES (?, ?, ?)""", + (graph_id, json.dumps(selected_nodes), result_json), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. @@ -726,16 +743,18 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: def delete_intervention_run(self, run_id: int) -> None: """Delete an intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) + conn.commit() def delete_intervention_runs_for_graph(self, graph_id: int) -> int: """Delete all intervention runs for a graph. 
Returns count deleted.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Forked intervention run operations @@ -757,16 +776,17 @@ def save_forked_intervention_run( Returns: The forked intervention run ID. """ - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) + VALUES (?, ?, ?)""", + (intervention_run_id, json.dumps(token_replacements), result_json), + ) + conn.commit() + fork_id = cursor.lastrowid + assert fork_id is not None + return fork_id def get_forked_intervention_runs( self, intervention_run_id: int @@ -822,6 +842,7 @@ def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: def delete_forked_intervention_run(self, fork_id: int) -> None: """Delete a forked intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) + conn.commit() diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index 91ab27740..14b49980f 100644 --- a/spd/app/backend/routers/graphs.py +++ 
b/spd/app/backend/routers/graphs.py @@ -87,7 +87,7 @@ def _save_base_intervention_run( model=model, tokens=tokens, active_nodes=active_nodes, - sans_nodes=None, + nodes_to_ablate=None, tokenizer=tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=effective_loss_config, diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 80f28ab1d..e26a73462 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -29,7 +29,7 @@ class RunInterventionRequest(BaseModel): graph_id: int selected_nodes: list[str] # node keys (layer:seq:cIdx) - sans_nodes: list[str] | None = None # node keys to ablate in target-sans (omit to skip) + nodes_to_ablate: list[str] | None = None # node keys to ablate in ablated (omit to skip) top_k: int adv_pgd: AdvPgdParams @@ -97,9 +97,11 @@ def run_and_save_intervention( active_nodes = _parse_and_validate_active_nodes( request.selected_nodes, loaded.topology, len(token_ids) ) - sans_nodes = ( - _parse_and_validate_active_nodes(request.sans_nodes, loaded.topology, len(token_ids)) - if request.sans_nodes is not None + nodes_to_ablate = ( + _parse_and_validate_active_nodes( + request.nodes_to_ablate, loaded.topology, len(token_ids) + ) + if request.nodes_to_ablate is not None else None ) tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) @@ -115,7 +117,7 @@ def run_and_save_intervention( model=loaded.model, tokens=tokens, active_nodes=active_nodes, - sans_nodes=sans_nodes, + nodes_to_ablate=nodes_to_ablate, tokenizer=loaded.tokenizer, adv_pgd_config=AdvPGDConfig( n_steps=request.adv_pgd.n_steps, diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py index 043201378..9c212d5ea 100644 --- a/spd/app/backend/routers/mcp.py +++ b/spd/app/backend/routers/mcp.py @@ -902,7 +902,7 @@ def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: model=loaded.model, tokens=tokens, active_nodes=active_nodes, - sans_nodes=None, 
+ nodes_to_ablate=None, tokenizer=loaded.tokenizer, adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, loss_config=MeanKLLossConfig(), diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index da0d0e935..5b4dfca42 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -1,6 +1,6 @@
+ {#if displaySettings.showEdgeAttributions && wteOutgoing.length > 0} + {}} + /> + {/if} {:else if isOutput} - + {:else if !hideNodeCard} {#key `${hoveredNode.layer}:${hoveredNode.cIdx}`} diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte index 5b20ca254..e335514cb 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte @@ -21,15 +21,17 @@
- steps{optimization.steps} - imp_min{optimization.imp_min_coeff} - pnorm{optimization.pnorm} - beta{optimization.beta} - mask{optimization.mask_type} - + steps{optimization.steps} + imp_min{optimization.imp_min_coeff} + pnorm{optimization.pnorm} + beta{optimization.beta} + mask{optimization.mask_type} + {optimization.loss.type}{optimization.loss.coeff} - + pos {optimization.loss.position} {#if tokenAtPos !== null} @@ -37,30 +39,30 @@ {/if} {#if optimization.loss.type === "ce"} - + label({optimization.loss.label_str}) {/if} {#if optimization.pgd} - + pgd_steps{optimization.pgd.n_steps} - + pgd_lr{optimization.pgd.step_size} {/if} - + L0{optimization.metrics.l0_total.toFixed(1)} {#if optimization.loss.type === "ce"} - + CI prob{formatProb(optimization.metrics.ci_masked_label_prob)} - + stoch prob{formatProb(optimization.metrics.stoch_masked_label_prob)} - + adv prob{formatProb(optimization.metrics.adv_pgd_label_prob)} {/if} diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index 76d37c037..328e5449d 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,14 +1,38 @@ @@ -106,21 +112,22 @@
- @@ -129,58 +136,69 @@
-
- {#if displaySettings.centerOnPeak} -
- {#each paginatedIndices as idx (idx)} - {@const fp = firingPositions[idx]} -
-
- -
-
- + {#if loading} +
+
+ {#each Array(pageSize) as _, i (i)} +
+ {/each} +
+
+ {:else} + {@const d = loaded!} +
+ {#if displaySettings.centerOnPeak} +
+ {#each paginatedIndices as idx (idx)} + {@const fp = firingPositions[idx]} +
+
+ +
+
+ +
+
+ +
-
+ {/each} +
+ {:else} +
+ {#each paginatedIndices as idx (idx)} +
-
- {/each} -
- {:else} -
- {#each paginatedIndices as idx (idx)} -
- -
- {/each} -
- {/if} -
+ {/each} +
+ {/if} +
+ {/if}
diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index a82c71ce1..232e4cd39 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -1,12 +1,13 @@
@@ -427,19 +440,10 @@
- {#if componentData.componentDetail.status === "loading"} -
Loading component data...
- {:else if componentData.componentDetail.status === "loaded"} - - {:else if componentData.componentDetail.status === "error"} - Error loading component data: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading component data: {String(activationExamples.error)} {:else} - Something went wrong loading component data. + {/if} import { getContext, onMount } from "svelte"; import { computeMaxAbsComponentAct } from "../lib/colors"; + import { mapLoadable } from "../lib/index"; import { anyCorrelationStatsEnabled } from "../lib/displaySettings.svelte"; import { useComponentDataExpectCached } from "../lib/useComponentDataExpectCached.svelte"; import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; - import ActivationContextsPagedTable from "./ActivationContextsPagedTable.svelte"; + import ActivationContextsPagedTable, { type ActivationExamplesData } from "./ActivationContextsPagedTable.svelte"; import ComponentProbeInput from "./ComponentProbeInput.svelte"; import ComponentCorrelationMetrics from "./ui/ComponentCorrelationMetrics.svelte"; import DatasetAttributionsSection from "./ui/DatasetAttributionsSection.svelte"; @@ -80,6 +81,18 @@ if (componentData.componentDetail.status !== "loaded") return 1; return computeMaxAbsComponentAct(componentData.componentDetail.data.example_component_acts); }); + + const activationExamples = $derived( + mapLoadable( + componentData.componentDetail, + (d): ActivationExamplesData => ({ + tokens: d.example_tokens, + ci: d.example_ci, + componentActs: d.example_component_acts, + maxAbsComponentAct: computeMaxAbsComponentAct(d.example_component_acts), + }), + ), + );
@@ -103,21 +116,12 @@
- {#if componentData.componentDetail.status === "uninitialized"} - uninitialized - {:else if componentData.componentDetail.status === "loading"} - Loading details... - {:else if componentData.componentDetail.status === "loaded"} - {#if componentData.componentDetail.data.example_tokens.length > 0} - - {/if} - {:else if componentData.componentDetail.status === "error"} - Error loading details: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + {/if}
diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index 5b4dfca42..c9e9e0637 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -514,11 +514,15 @@ const p1 = positions[edge.src]; const p2 = positions[edge.tgt]; if (!p1 || !p2) continue; - const dy = Math.abs(p2.y - p1.y); - const curveOffset = Math.max(20, dy * 0.4); const path = new Path2D(); path.moveTo(p1.x, p1.y); - path.bezierCurveTo(p1.x, p1.y - curveOffset, p2.x, p2.y + curveOffset, p2.x, p2.y); + if (displaySettings.curvedEdges) { + const dy = Math.abs(p2.y - p1.y); + const curveOffset = Math.max(20, dy * 0.4); + path.bezierCurveTo(p1.x, p1.y - curveOffset, p2.x, p2.y + curveOffset, p2.x, p2.y); + } else { + path.lineTo(p2.x, p2.y); + } items.push({ edge, path, diff --git a/spd/app/frontend/src/components/TokenHighlights.svelte b/spd/app/frontend/src/components/TokenHighlights.svelte index 4916c1f00..a921450bb 100644 --- a/spd/app/frontend/src/components/TokenHighlights.svelte +++ b/spd/app/frontend/src/components/TokenHighlights.svelte @@ -25,6 +25,15 @@ return tokenNextProbs[i - 1]; } + function sanitizeToken(tok: string): string { + return tok + .replaceAll("\n", "↵") + .replaceAll("\r", "⏎") + .replaceAll("\t", "⇥") + .replaceAll("\v", "⇣") + .replaceAll("\f", "⇟"); + } + function getBgColor(ci: number): string { return getTokenHighlightBg(ci); } @@ -41,14 +50,14 @@ style="background-color:{getBgColor(tokenCi[i])};--underline-color:{getUnderlineColor( tokenComponentActs[i], )}" - data-tooltip={getTooltipText(tokenCi[i], tokenComponentActs[i], getProbAtPosition(i))}>{tok}{sanitizeToken(tok)}{/each} diff --git a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte index 077917234..0dc1f96f4 100644 --- 
a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte +++ b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte @@ -94,6 +94,14 @@ /> Center on peak +