diff --git a/.claude/skills/gpudash.md b/.claude/skills/gpudash.md new file mode 100644 index 000000000..5e60a2b6a --- /dev/null +++ b/.claude/skills/gpudash.md @@ -0,0 +1,12 @@ +--- +name: gpudash +description: Check GPU availability across the SLURM cluster +user_invocable: true +--- + +# gpudash + +Run the `gpudash` command to show GPU availability across the cluster. + +## Steps +1. Run `gpudash` and show the output to the user. diff --git a/.claude/worktrees/bold-elm-8kpb b/.claude/worktrees/bold-elm-8kpb new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bold-elm-8kpb @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/bright-fox-a4i0 b/.claude/worktrees/bright-fox-a4i0 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bright-fox-a4i0 @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/calm-owl-v4pj b/.claude/worktrees/calm-owl-v4pj new file mode 160000 index 000000000..dbe0668a4 --- /dev/null +++ b/.claude/worktrees/calm-owl-v4pj @@ -0,0 +1 @@ +Subproject commit dbe0668a4119885b7fe952ed820b4ba8b4a3d693 diff --git a/.claude/worktrees/cozy-frolicking-stream b/.claude/worktrees/cozy-frolicking-stream new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/cozy-frolicking-stream @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/stateless-dancing-blanket b/.claude/worktrees/stateless-dancing-blanket new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/stateless-dancing-blanket @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-owl-yep9 b/.claude/worktrees/swift-owl-yep9 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-owl-yep9 @@ -0,0 +1 @@ +Subproject commit 
356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-ray-amfs b/.claude/worktrees/swift-ray-amfs new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-ray-amfs @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/vectorized-wiggling-whisper b/.claude/worktrees/vectorized-wiggling-whisper new file mode 160000 index 000000000..cb18c86a7 --- /dev/null +++ b/.claude/worktrees/vectorized-wiggling-whisper @@ -0,0 +1 @@ +Subproject commit cb18c86a77720f94a292e7421a19694082813c8c diff --git a/.claude/worktrees/xenodochial-germain b/.claude/worktrees/xenodochial-germain new file mode 160000 index 000000000..5c9f344eb --- /dev/null +++ b/.claude/worktrees/xenodochial-germain @@ -0,0 +1 @@ +Subproject commit 5c9f344eb490e90bed9db5102325459d42c3c0f4 diff --git a/.gitignore b/.gitignore index 4780cbd03..3581e751a 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,7 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* +*.schema.json + +.claude/worktrees \ No newline at end of file diff --git a/.mcp.json b/.mcp.json index fefb52c9a..700113020 100644 --- a/.mcp.json +++ b/.mcp.json @@ -1,8 +1,3 @@ { - "mcpServers": { - "svelte-llm": { - "type": "http", - "url": "https://svelte-llm.stanislav.garden/mcp/mcp" - } - } + "mcpServers": {} } \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 6bb73e8b2..13bdcf4d0 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -3,14 +3,18 @@ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. ## Environment Setup + **IMPORTANT**: Always activate the virtual environment before running Python or git operations: + ```bash source .venv/bin/activate ``` -Repo requires `.env` file with WandB credentials (see `.env.example`) +If working in a worktree, make sure there's a local `.venv` first by running `uv sync` in the worktree directory. 
Do NOT `cd` to the main repo — all commands (including git) should run in the worktree. +Repo requires `.env` file with WandB credentials (see `.env.example`) ## Project Overview + SPD (Stochastic Parameter Decomposition) is a research framework for analyzing neural network components and their interactions through sparse parameter decomposition techniques. - Target model parameters are decomposed as a sum of `parameter components` @@ -36,6 +40,8 @@ The codebase supports three experimental domains: TMS (Toy Model of Superpositio - `ss_llama_simple_mlp`, `ss_llama_simple_mlp-1L`, `ss_llama_simple_mlp-2L` - Llama MLP-only variants - `ss_gpt2`, `ss_gpt2_simple`, `ss_gpt2_simple_noln` - Simple Stories GPT-2 variants - `ss_gpt2_simple-1L`, `ss_gpt2_simple-2L` - GPT-2 simple layer variants + - `pile_llama_simple_mlp-2L`, `pile_llama_simple_mlp-4L`, `pile_llama_simple_mlp-12L` - Pile Llama MLP-only variants + - `pile_gpt2_simple-2L_global_reverse` - Pile GPT-2 with global reverse - `gpt2` - Standard GPT-2 - `ts` - TinyStories @@ -46,7 +52,7 @@ This repository implements methods from two key research papers on parameter dec **Stochastic Parameter Decomposition (SPD)** - [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) -- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. +- A version of this repository was used to run the experiments in this paper. But we continue to develop on the code, so it no longer is limited to the implementation used for this paper. - Introduces the core SPD framework - Details the stochastic masking approach and optimization techniques used throughout the codebase - Useful reading for understanding the implementation details, though may be outdated. 
@@ -95,6 +101,7 @@ This repository implements methods from two key research papers on parameter dec ## Architecture Overview **Core SPD Framework:** + - `spd/run_spd.py` - Main SPD optimization logic called by all experiments - `spd/configs.py` - Pydantic config classes for all experiment types - `spd/registry.py` - Centralized experiment registry with all experiment configurations @@ -105,15 +112,17 @@ This repository implements methods from two key research papers on parameter dec - `spd/figures.py` - Figures for logging to WandB (e.g. CI histograms, Identity plots, etc.) **Terminology: Sources vs Masks:** + - **Sources** (`adv_sources`, `PPGDSources`, `self.sources`): The raw values that PGD optimizes adversarially. These are interpolated with CI to produce component masks: `mask = ci + (1 - ci) * source`. Used in both regular PGD (`spd/metrics/pgd_utils.py`) and persistent PGD (`spd/persistent_pgd.py`). - **Masks** (`component_masks`, `RoutingMasks`, `make_mask_infos`, `n_mask_samples`): The materialized per-component masks used during forward passes. These are produced from sources (in PGD) or from stochastic sampling, and are a general SPD concept across the whole codebase. 
**Experiment Structure:** Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: + - `models.py` - Experiment-specific model classes and pretrained loading - `*_decomposition.py` - Main SPD execution script -- `train_*.py` - Training script for target models +- `train_*.py` - Training script for target models - `*_config.yaml` - Configuration files - `plotting.py` - Visualization utilities @@ -127,7 +136,7 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: **Configuration System:** - YAML configs define all experiment parameters -- Pydantic models provide type safety and validation +- Pydantic models provide type safety and validation - WandB integration for experiment tracking and model storage - Supports both local paths and `wandb:project/runs/run_id` format for model loading - Centralized experiment registry (`spd/registry.py`) manages all experiment configurations @@ -137,8 +146,9 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: - `spd/harvest/` - Offline GPU pipeline for collecting component statistics (correlations, token stats, activation examples) - `spd/autointerp/` - LLM-based automated interpretation of components - `spd/dataset_attributions/` - Multi-GPU pipeline for computing component-to-component attribution strengths aggregated over training data -- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions}//` -- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, and `spd/dataset_attributions/CLAUDE.md` for details +- `spd/graph_interp/` - Context-aware component labeling using graph structure (attributions + correlations) +- Data stored at `SPD_OUT_DIR/{harvest,autointerp,dataset_attributions,graph_interp}/<run_id>/` +- See `spd/harvest/CLAUDE.md`, `spd/autointerp/CLAUDE.md`, `spd/dataset_attributions/CLAUDE.md`, and `spd/graph_interp/CLAUDE.md` for details **Output Directory (`SPD_OUT_DIR`):** @@ -160,12 +170,14 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: ├── 
scripts/ # Standalone utility scripts ├── tests/ # Test suite ├── spd/ # Main source code +│ ├── investigate/ # Agent investigation (see investigate/CLAUDE.md) │ ├── app/ # Web visualization app (see app/CLAUDE.md) │ ├── autointerp/ # LLM interpretation (see autointerp/CLAUDE.md) │ ├── clustering/ # Component clustering (see clustering/CLAUDE.md) │ ├── dataset_attributions/ # Dataset attributions (see dataset_attributions/CLAUDE.md) │ ├── harvest/ # Statistics collection (see harvest/CLAUDE.md) │ ├── postprocess/ # Unified postprocessing pipeline (harvest + attributions + autointerp) +│ ├── graph_interp/ # Context-aware interpretation (see graph_interp/CLAUDE.md) │ ├── pretrain/ # Target model pretraining (see pretrain/CLAUDE.md) │ ├── experiments/ # Experiment implementations │ │ ├── tms/ # Toy Model of Superposition @@ -201,14 +213,17 @@ Each experiment (`spd/experiments/{tms,resid_mlp,lm}/`) contains: | `spd-autointerp` | `spd/autointerp/scripts/run_slurm_cli.py` | Submit autointerp SLURM job | | `spd-attributions` | `spd/dataset_attributions/scripts/run_slurm_cli.py` | Submit dataset attribution SLURM job | | `spd-postprocess` | `spd/postprocess/cli.py` | Unified postprocessing pipeline (harvest + attributions + interpret + evals) | +| `spd-graph-interp` | `spd/graph_interp/scripts/run_slurm_cli.py` | Submit graph interpretation SLURM job | | `spd-clustering` | `spd/clustering/scripts/run_pipeline.py` | Clustering pipeline | | `spd-pretrain` | `spd/pretrain/scripts/run_slurm_cli.py` | Pretrain target models | +| `spd-investigate` | `spd/investigate/scripts/run_slurm_cli.py` | Launch investigation agent | ### Files to Skip When Searching Use `spd/` as the search root (not repo root) to avoid noise. **Always skip:** + - `.venv/` - Virtual environment - `__pycache__/`, `.pytest_cache/`, `.ruff_cache/` - Build artifacts - `node_modules/` - Frontend dependencies @@ -218,27 +233,37 @@ Use `spd/` as the search root (not repo root) to avoid noise. 
- `wandb/` - WandB local files **Usually skip unless relevant:** + - `tests/` - Test files (unless debugging test failures) - `papers/` - Research paper drafts ### Common Call Chains **Running Experiments:** + - `spd-run` → `spd/scripts/run.py` → `spd/utils/slurm.py` → SLURM → `spd/run_spd.py` - `spd-local` → `spd/scripts/run_local.py` → `spd/run_spd.py` directly **Harvest Pipeline:** + - `spd-harvest` → `spd/harvest/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/harvest/scripts/run.py` → `spd/harvest/harvest.py` **Autointerp Pipeline:** + - `spd-autointerp` → `spd/autointerp/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → `spd/autointerp/interpret.py` **Dataset Attributions Pipeline:** + - `spd-attributions` → `spd/dataset_attributions/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM array → `spd/dataset_attributions/harvest.py` **Clustering Pipeline:** + - `spd-clustering` → `spd/clustering/scripts/run_pipeline.py` → `spd/utils/slurm.py` → `spd/clustering/scripts/run_clustering.py` +**Investigation Pipeline:** + +- `spd-investigate` → `spd/investigate/scripts/run_slurm_cli.py` → `spd/utils/slurm.py` → SLURM → `spd/investigate/scripts/run_agent.py` → Claude Code + ## Common Usage Patterns ### Running Experiments Locally (`spd-local`) @@ -285,6 +310,28 @@ spd-autointerp # Submit SLURM job to interpret component Requires `OPENROUTER_API_KEY` env var. See `spd/autointerp/CLAUDE.md` for details. +### Agent Investigation (`spd-investigate`) + +Launch a Claude Code agent to investigate a specific question about an SPD model: + +```bash +spd-investigate "How does the model handle gendered pronouns?" +spd-investigate "What components are involved in verb agreement?" 
--time 4:00:00 +``` + +Each investigation: + +- Runs in its own SLURM job with 1 GPU +- Starts an isolated app backend instance +- Investigates the specific research question using SPD tools via MCP +- Writes findings to append-only JSONL files + +Output: `SPD_OUT_DIR/investigations/<inv_id>/` + +For parallel investigations, run the command multiple times with different prompts. + +See `spd/investigate/CLAUDE.md` for details. + ### Unified Postprocessing (`spd-postprocess`) Run all postprocessing steps for a completed SPD run with a single command: @@ -295,6 +342,7 @@ spd-postprocess --config custom_config.yaml # Use custom config ``` Defaults are defined in `PostprocessConfig` (`spd/postprocess/config.py`). Pass a custom YAML/JSON config to override. Set any section to `null` to skip it: + - `attributions: null` — skip dataset attributions - `autointerp: null` — skip autointerp entirely (interpret + evals) - `autointerp.evals: null` — skip evals but still run interpret @@ -323,6 +371,7 @@ spd-run # Run all experiments ``` All `spd-run` executions: + - Submit jobs to SLURM - Create a git snapshot for reproducibility - Create W&B workspace views @@ -343,6 +392,7 @@ spd-run --experiments --sweep --n_agents [--cpu] ``` Examples: + ```bash spd-run --experiments tms_5-2 --sweep --n_agents 4 # Run TMS 5-2 sweep with 4 GPU agents spd-run --experiments resid_mlp2 --sweep --n_agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents @@ -364,6 +414,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee - Default sweep parameters are loaded from `spd/scripts/sweep_params.yaml` - You can specify a custom sweep parameters file by passing its path to `--sweep` - Sweep parameters support both experiment-specific and global configurations: + ```yaml # Global parameters applied to all experiments global: @@ -376,7 +427,7 @@ spd-run --experiments tms_5-2 --sweep custom.yaml --n_agents 2 # Use custom swee # Experiment-specific parameters (override global) 
tms_5-2: seed: - values: [100, 200] # Overrides global seed + values: [100, 200] # Overrides global seed task_config: feature_probability: values: [0.05, 0.1] @@ -402,6 +453,7 @@ model = ComponentModel.from_run_info(run_info) # Local paths work too model = ComponentModel.from_pretrained("/path/to/checkpoint.pt") ``` + **Path Formats:** - WandB: `wandb:entity/project/run_id` or `wandb:entity/project/runs/run_id` @@ -415,12 +467,12 @@ Downloaded runs are cached in `SPD_OUT_DIR/runs/-/`. - This includes not setting off multiple sweeps/evals that total >8 GPUs - Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me` - ## Coding Guidelines & Software Engineering Principles **This is research code, not production. Prioritize simplicity and fail-fast over defensive programming.** Core principles: + - **Fail fast** - assert assumptions, crash on violations, don't silently recover - **No legacy support** - delete unused code, don't add fallbacks for old formats or migration shims - **Narrow types** - avoid `| None` unless null is semantically meaningful; use discriminated unions over bags of optional fields @@ -451,11 +503,12 @@ config = get_config(path) value = config.key ``` - ### Tests + - The point of tests in this codebase is to ensure that the code is working as expected, not to prevent production outages - there's no deployment here. Therefore, don't worry about lots of larger integration/end-to-end tests. These often require too much overhead for what it's worth in our case, and this codebase is interactively run so often that issues will likely be caught by the user at very little cost. ### Assertions and error handling + - If you have an invariant in your head, assert it. Are you afraid to assert? Sounds like your program might already be broken. Assert, assert, assert. Never soft fail. - Do not write: `if everythingIsOk: continueHappyPath()`. 
Instead do `assert everythingIsOk` - You should have a VERY good reason to handle an error gracefully. If your program isn't working like it should then it shouldn't be running, you should be fixing it. @@ -463,11 +516,13 @@ value = config.key - **Write for the golden path.** Never let edge cases bloat the code. Before handling them, just raise an exception. If an edge case becomes annoying enough, we'll handle it then — but write first and foremost for the common case. ### Control Flow + - Keep I/O as high up as possible. Make as many functions as possible pure. - Prefer `match` over `if/elif/else` chains when dispatching on conditions - more declarative and makes cases explicit - If you either have (a and b) or neither, don't make them both independently optional. Instead, put them in an optional tuple ### Types, Arguments, and Defaults + - Write your invariants into types as much as possible. - Use jaxtyping for tensor shapes (though for now we don't do runtime checking) - Always use the PEP 604 typing format of `|` for unions and `type | None` over `Optional`. @@ -483,12 +538,11 @@ value = config.key - Don't use `from __future__ import annotations` — use string quotes for forward references instead. ### Tensor Operations + - Try to use einops by default for clarity. - Assert shapes liberally - Document complex tensor manipulations - - ### Comments - Comments hide sloppy code. If you feel the need to write a comment, consider that you should instead @@ -498,10 +552,11 @@ value = config.key - separate an inlined computation into a meaningfully named variable - Don’t write dialogic / narrativised comments or code. Instead, write comments that describe the code as is, not the diff you're making. 
Examples of narrativising comments: - - `# the function now uses y instead of x` - - `# changed to be faster` - - `# we now traverse in reverse` + - `# the function now uses y instead of x` + - `# changed to be faster` + - `# we now traverse in reverse` - Here's an example of a bad diff, where the new comment makes reference to a change in code, not just the state of the code: + ``` 95 - # Reservoir states 96 - reservoir_states: list[ReservoirState] @@ -509,14 +564,15 @@ value = config.key 96 + reservoir: TensorReservoirState ``` - ### Other Important Software Development Practices + - Don't add legacy fallbacks or migration code - just change it and let old data be manually migrated if needed. -- Delete unused code. +- Delete unused code. - If an argument is always x, strongly consider removing as an argument and just inlining - **Update CLAUDE.md files** when changing code structure, adding/removing files, or modifying key interfaces. Update the CLAUDE.md in the same directory (or nearest parent) as the changed files. ### GitHub + - To view github issues and PRs, use the github cli (e.g. `gh issue view 28` or `gh pr view 30`). - When making PRs, use the github template defined in `.github/pull_request_template.md`. - Before committing, ALWAYS ensure you are on the correct branch and do not use `git add .` to add all unstaged files. Instead, add only the individual files you changed, don't commit all files. diff --git a/find_clean_facts.py b/find_clean_facts.py deleted file mode 100644 index e22ca906d..000000000 --- a/find_clean_facts.py +++ /dev/null @@ -1,572 +0,0 @@ -#!/usr/bin/env python3 -""" -Find the cleanest (most monosemantic) facts from the SPD analysis. - -A fact is "clean" if the components that fire on it are monosemantic. - -For down_proj: A component is monosemantic if it responds to a single label. 
-For up_proj: A component is monosemantic if it responds to: - - A single label, OR - - A single input element at position 0, 1, or 2 - -We score each fact based on how monosemantic its firing components are. -""" - -import re -from collections import Counter, defaultdict - - -def parse_analysis_file(filepath: str): - """Parse the analysis.txt file to extract component and fact information.""" - - with open(filepath) as f: - lines = f.readlines() - - # Parse component-to-facts mapping (from the COMPONENT ACTIVATION ANALYSIS section) - up_proj_components = defaultdict(list) # component_id -> list of (fact_idx, input, label) - down_proj_components = defaultdict(list) - - # Parse the per-fact analysis (from PER-FACT COMPONENT ANALYSIS section) - up_proj_per_fact = {} # fact_idx -> {inputs, label, components} - down_proj_per_fact = {} - - current_module = None - current_section = None # 'component_analysis' or 'per_fact' - current_component = None - - i = 0 - while i < len(lines): - line = lines[i].strip() - - # Detect section changes - if "COMPONENT ACTIVATION ANALYSIS" in line: - current_section = "component_analysis" - elif "PER-FACT COMPONENT ANALYSIS" in line: - current_section = "per_fact" - elif "SUMMARY STATISTICS" in line: - current_section = "summary" - - # Detect module changes - if "MODULE: block.mlp.up_proj" in line: - current_module = "up_proj" - elif "MODULE: block.mlp.down_proj" in line: - current_module = "down_proj" - - # Parse component activation analysis section - if current_section == "component_analysis" and current_module: - # Parse component header: [Rank X] Component Y (mean CI=Z): N facts above threshold - comp_match = re.match(r"\[Rank \d+\] Component (\d+)", line) - if comp_match: - current_component = int(comp_match.group(1)) - - # Parse fact line: Fact X: input=[a, b, c] → label=Y (CI=Z) - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match and current_component is not None: - 
fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - if current_module == "up_proj": - up_proj_components[current_component].append((fact_idx, inputs, label)) - else: - down_proj_components[current_component].append((fact_idx, inputs, label)) - - # Parse per-fact analysis section - if current_section == "per_fact" and current_module: - # Parse fact line - fact_match = re.match( - r"Fact\s+(\d+): input=\[(\d+), (\d+), (\d+)\] → label=(\d+)", line - ) - if fact_match: - fact_idx = int(fact_match.group(1)) - inputs = [ - int(fact_match.group(2)), - int(fact_match.group(3)), - int(fact_match.group(4)), - ] - label = int(fact_match.group(5)) - - # Look for components in the next lines - components = [] - j = i + 1 - while j < len(lines): - next_line = lines[j].strip() - - # Check if we've hit the next fact or section - if ( - next_line.startswith("Fact ") - or next_line.startswith("===") - or next_line.startswith("MODULE:") - ): - break - - # Parse component activations like C206(1.000) - comp_matches = re.findall(r"C(\d+)\(([\d.]+)\)", next_line) - for comp_id, ci_score in comp_matches: - components.append((int(comp_id), float(ci_score))) - - j += 1 - - if current_module == "up_proj": - up_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - "components": components, - } - else: - down_proj_per_fact[fact_idx] = { - "inputs": inputs, - "label": label, - "components": components, - } - - i += 1 - - return up_proj_components, down_proj_components, up_proj_per_fact, down_proj_per_fact - - -def compute_component_monosemanticity(component_facts: list) -> dict: - """ - Compute monosemanticity scores for a component. 
- """ - if not component_facts: - return None - - labels = [f[2] for f in component_facts] - pos0_vals = [f[1][0] for f in component_facts] - pos1_vals = [f[1][1] for f in component_facts] - pos2_vals = [f[1][2] for f in component_facts] - - label_counts = Counter(labels) - pos0_counts = Counter(pos0_vals) - pos1_counts = Counter(pos1_vals) - pos2_counts = Counter(pos2_vals) - - n = len(component_facts) - - dominant_label, dominant_label_count = label_counts.most_common(1)[0] - dominant_pos0, dominant_pos0_count = pos0_counts.most_common(1)[0] - dominant_pos1, dominant_pos1_count = pos1_counts.most_common(1)[0] - dominant_pos2, dominant_pos2_count = pos2_counts.most_common(1)[0] - - return { - "n_facts": n, - "n_unique_labels": len(label_counts), - "dominant_label": dominant_label, - "dominant_label_ratio": dominant_label_count / n, - "n_unique_pos0": len(pos0_counts), - "dominant_pos0": dominant_pos0, - "dominant_pos0_ratio": dominant_pos0_count / n, - "n_unique_pos1": len(pos1_counts), - "dominant_pos1": dominant_pos1, - "dominant_pos1_ratio": dominant_pos1_count / n, - "n_unique_pos2": len(pos2_counts), - "dominant_pos2": dominant_pos2, - "dominant_pos2_ratio": dominant_pos2_count / n, - } - - -def is_component_monosemantic(stats: dict, threshold: float = 0.9) -> tuple[bool, str]: - """ - Determine if a component is monosemantic based on its statistics. 
- Returns (is_monosemantic, reason) - """ - if stats is None: - return False, "no_data" - - # Check if it responds to a single label - if stats["dominant_label_ratio"] >= threshold: - return True, f"label_{stats['dominant_label']}" - - # Check if it responds to a single input element - if stats["dominant_pos0_ratio"] >= threshold: - return True, f"pos0_{stats['dominant_pos0']}" - if stats["dominant_pos1_ratio"] >= threshold: - return True, f"pos1_{stats['dominant_pos1']}" - if stats["dominant_pos2_ratio"] >= threshold: - return True, f"pos2_{stats['dominant_pos2']}" - - return False, "polysemantic" - - -def compute_monosemanticity_score(stats: dict) -> float: - """ - Compute a monosemanticity score from 0 to 1. - Higher score = more monosemantic. - """ - if stats is None: - return 0.0 - - # The score is the maximum of all the dominant ratios - return max( - stats["dominant_label_ratio"], - stats["dominant_pos0_ratio"], - stats["dominant_pos1_ratio"], - stats["dominant_pos2_ratio"], - ) - - -def score_fact( - fact_info: dict, - up_proj_mono_scores: dict, - down_proj_mono_scores: dict, - up_proj_stats: dict, - down_proj_stats: dict, -) -> tuple[float, dict]: - """ - Score a fact based on how monosemantic its firing components are. 
- Returns (score, details) - """ - up_components = fact_info.get("up_proj_components", []) - down_components = fact_info.get("down_proj_components", []) - - if not up_components and not down_components: - return 0.0, { - "reason": "no_components", - "up_proj_components": [], - "down_proj_components": [], - "n_components": 0, - } - - # For each component, get its monosemanticity score - up_scores = [] - for comp_id, ci_score in up_components: - mono_score = up_proj_mono_scores.get(comp_id, 0.0) - stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - up_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - down_scores = [] - for comp_id, ci_score in down_components: - mono_score = down_proj_mono_scores.get(comp_id, 0.0) - stats = down_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(stats, threshold=0.9) if stats else (False, "unknown") - ) - down_scores.append((comp_id, mono_score, ci_score, is_mono, reason)) - - # Compute fact score as minimum monosemanticity of all components - all_mono_scores = [s[1] for s in up_scores] + [s[1] for s in down_scores] - - if not all_mono_scores: - return 0.0, { - "reason": "no_scores", - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "n_components": 0, - } - - min_score = min(all_mono_scores) - mean_score = sum(all_mono_scores) / len(all_mono_scores) - - # Count how many components are monosemantic - n_mono = sum(1 for s in up_scores + down_scores if s[3]) - total = len(up_scores) + len(down_scores) - - return min_score, { - "up_proj_components": up_scores, - "down_proj_components": down_scores, - "min_mono_score": min_score, - "mean_mono_score": mean_score, - "n_components": total, - "n_mono_components": n_mono, - "mono_ratio": n_mono / total if total > 0 else 0, - } - - -def main(): - print("Parsing analysis.txt...") - up_proj_comps, down_proj_comps, up_proj_facts, 
down_proj_facts = parse_analysis_file( - "analysis.txt" - ) - - print(f"\nFound {len(up_proj_comps)} up_proj components with facts") - print(f"Found {len(down_proj_comps)} down_proj components with facts") - print(f"Found {len(up_proj_facts)} facts with up_proj info") - print(f"Found {len(down_proj_facts)} facts with down_proj info") - - # Sample check - if up_proj_facts: - sample_fact = list(up_proj_facts.items())[0] - print(f"\nSample up_proj fact: {sample_fact}") - if down_proj_facts: - sample_fact = list(down_proj_facts.items())[0] - print(f"Sample down_proj fact: {sample_fact}") - - # Compute monosemanticity for each component - print("\nComputing component monosemanticity...") - - up_proj_stats = {} - up_proj_mono_scores = {} - for comp_id, facts in up_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - up_proj_stats[comp_id] = stats - up_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - down_proj_stats = {} - down_proj_mono_scores = {} - for comp_id, facts in down_proj_comps.items(): - stats = compute_component_monosemanticity(facts) - down_proj_stats[comp_id] = stats - down_proj_mono_scores[comp_id] = compute_monosemanticity_score(stats) - - # Print some example monosemantic components - print("\n" + "=" * 80) - print("MONOSEMANTIC UP_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_up = [] - for comp_id, stats in up_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_up.append((comp_id, stats, reason)) - - mono_up.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_up[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... 
and {len(mono_up) - 20} more" if len(mono_up) > 20 else "") - - print(f"\nTotal monosemantic up_proj components: {len(mono_up)} / {len(up_proj_stats)}") - - print("\n" + "=" * 80) - print("MONOSEMANTIC DOWN_PROJ COMPONENTS (threshold >= 0.9)") - print("=" * 80) - mono_down = [] - for comp_id, stats in down_proj_stats.items(): - is_mono, reason = is_component_monosemantic(stats, threshold=0.9) - if is_mono: - mono_down.append((comp_id, stats, reason)) - - mono_down.sort(key=lambda x: compute_monosemanticity_score(x[1]), reverse=True) - for comp_id, stats, reason in mono_down[:20]: - print( - f" Component {comp_id}: {reason}, score={compute_monosemanticity_score(stats):.3f}, n_facts={stats['n_facts']}" - ) - print(f" ... and {len(mono_down) - 20} more" if len(mono_down) > 20 else "") - - print(f"\nTotal monosemantic down_proj components: {len(mono_down)} / {len(down_proj_stats)}") - - # Combine up_proj and down_proj info for each fact - print("\n" + "=" * 80) - print("SCORING FACTS BY MONOSEMANTICITY") - print("=" * 80) - - all_facts = set(up_proj_facts.keys()) | set(down_proj_facts.keys()) - fact_scores = [] - - for fact_idx in all_facts: - up_info = up_proj_facts.get(fact_idx, {}) - down_info = down_proj_facts.get(fact_idx, {}) - - # Get the inputs and label from either source - inputs = up_info.get("inputs") or down_info.get("inputs", []) - label = up_info.get("label", down_info.get("label", -1)) - - combined_info = { - "inputs": inputs, - "label": label, - "up_proj_components": up_info.get("components", []), - "down_proj_components": down_info.get("components", []), - } - - score, details = score_fact( - combined_info, - up_proj_mono_scores, - down_proj_mono_scores, - up_proj_stats, - down_proj_stats, - ) - - fact_scores.append( - { - "fact_idx": fact_idx, - "inputs": inputs, - "label": label, - "score": score, - "details": details, - } - ) - - # Sort by score (highest = cleanest), then by mono ratio, then by fewer components - fact_scores.sort( - key=lambda x: ( 
- x["score"], - x["details"].get("mono_ratio", 0), - -x["details"].get("n_components", 999), - ), - reverse=True, - ) - - # Print top cleanest facts - print("\nTOP 50 CLEANEST FACTS (highest monosemanticity score):") - print("-" * 80) - - for i, fs in enumerate(fact_scores[:50]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"\n{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Score: {fs['score']:.3f}, mono_ratio: {fs['details'].get('mono_ratio', 0):.2f}") - print(f" Up_proj ({len(up_comps)}): {up_str if up_str else 'none'}") - print(f" Down_proj ({len(down_comps)}): {down_str if down_str else 'none'}") - - # Find facts where ALL components are monosemantic - print("\n" + "=" * 80) - print("FACTS WHERE ALL COMPONENTS ARE MONOSEMANTIC") - print("=" * 80) - - all_mono_facts = [ - fs - for fs in fact_scores - if fs["details"].get("n_components", 0) > 0 and fs["details"].get("mono_ratio", 0) == 1.0 - ] - - print(f"\nFound {len(all_mono_facts)} facts where ALL components are monosemantic:\n") - - for i, fs in enumerate(all_mono_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - up_str = ", ".join([f"C{c[0]}({c[4]})" for c in up_comps]) - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print(f" Up_proj: {up_str if up_str else 'none'}") - print(f" Down_proj: {down_str if down_str else 'none'}") - print() - - if len(all_mono_facts) > 30: - print(f" ... 
and {len(all_mono_facts) - 30} more") - - # Also show facts with only 1 component firing in up_proj - print("\n" + "=" * 80) - print("FACTS WITH ONLY 1 UP_PROJ COMPONENT FIRING") - print("=" * 80) - - single_comp_facts = [ - fs for fs in fact_scores if len(fs["details"].get("up_proj_components", [])) == 1 - ] - single_comp_facts.sort(key=lambda x: x["score"], reverse=True) - - print(f"\nFound {len(single_comp_facts)} facts with only 1 up_proj component:\n") - - for i, fs in enumerate(single_comp_facts[:30]): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - comp_id = up_comps[0][0] - comp_stats = up_proj_stats.get(comp_id) - is_mono, reason = ( - is_component_monosemantic(comp_stats, threshold=0.9) - if comp_stats - else (False, "unknown") - ) - - down_str = ", ".join([f"C{c[0]}({c[4]})" for c in down_comps]) - - print(f"{i + 1}. Fact {fs['fact_idx']}: input={fs['inputs']} → label={fs['label']}") - print( - f" Up_proj C{comp_id}: mono_score={fs['score']:.3f}, is_mono={is_mono}, reason={reason}" - ) - print(f" Down_proj: {down_str if down_str else 'none'}") - if comp_stats: - print( - f" Component stats: dominant_label={comp_stats['dominant_label']} ({comp_stats['dominant_label_ratio']:.1%})" - ) - print() - - # Print summary - print("\n" + "=" * 80) - print("SUMMARY") - print("=" * 80) - - # Count facts with at least one component - facts_with_components = [fs for fs in fact_scores if fs["details"].get("n_components", 0) > 0] - print(f"\nTotal facts with at least one component: {len(facts_with_components)}") - - score_thresholds = [1.0, 0.95, 0.9, 0.8, 0.5, 0.0] - for thresh in score_thresholds: - count = sum(1 for fs in facts_with_components if fs["score"] >= thresh) - print(f" Facts with monosemanticity score >= {thresh}: {count}") - - # Save results to a file - print("\n\nSaving detailed results to clean_facts_ranking.txt...") - with open("clean_facts_ranking.txt", "w") as f: - 
f.write("FACTS RANKED BY MONOSEMANTICITY SCORE\n") - f.write("=" * 80 + "\n\n") - f.write("A fact is 'clean' if all components that fire on it are monosemantic.\n") - f.write( - "Monosemantic = responds to a single label or single input position value (>= 90%).\n\n" - ) - - f.write(f"Total facts with at least one component: {len(facts_with_components)}\n") - f.write(f"Facts where ALL components are monosemantic: {len(all_mono_facts)}\n\n") - - f.write("=" * 80 + "\n") - f.write("CLEANEST FACTS (all components monosemantic)\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(all_mono_facts): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Monosemanticity Score: {fs['score']:.4f}\n") - f.write(f" Up_proj components ({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in up_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, _is_mono, reason in down_comps: - f.write( - f" C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - f.write("\n" + "=" * 80 + "\n") - f.write("ALL FACTS RANKED\n") - f.write("=" * 80 + "\n\n") - - for i, fs in enumerate(facts_with_components): - up_comps = fs["details"].get("up_proj_components", []) - down_comps = fs["details"].get("down_proj_components", []) - - f.write(f"Rank {i + 1}: Fact {fs['fact_idx']}\n") - f.write(f" Input: {fs['inputs']} → Label: {fs['label']}\n") - f.write(f" Min Monosemanticity Score: {fs['score']:.4f}\n") - f.write( - f" Mono ratio: {fs['details'].get('mono_ratio', 0):.2f} ({fs['details'].get('n_mono_components', 0)}/{fs['details'].get('n_components', 0)})\n" - ) - f.write(f" Up_proj components 
({len(up_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in up_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write(f" Down_proj components ({len(down_comps)}):\n") - for comp_id, mono_score, ci_score, is_mono, reason in down_comps: - mono_marker = "✓" if is_mono else "✗" - f.write( - f" {mono_marker} C{comp_id}: CI={ci_score:.3f}, mono={mono_score:.3f}, reason={reason}\n" - ) - f.write("\n") - - print("Done!") - - -if __name__ == "__main__": - main() diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index ca5f8195a..000000000 --- a/package-lock.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "name": "spd", - "lockfileVersion": 3, - "requires": true, - "packages": {} -} diff --git a/pyproject.toml b/pyproject.toml index 88c3405a8..0a5608ce8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,8 +33,9 @@ dependencies = [ "orjson", "aiolimiter>=1.2", "openrouter>=0.1.1", - "httpx>=0.28.0", - "zstandard" # For streaming datasets + "httpx>=0.28.0", # For streaming datasets + "zstandard", + "kaleido==0.2.1", ] [dependency-groups] @@ -56,7 +57,9 @@ spd-clustering = "spd.clustering.scripts.run_pipeline:cli" spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" +spd-investigate = "spd.investigate.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" +spd-graph-interp = "spd.graph_interp.scripts.run_slurm_cli:cli" [build-system] requires = ["setuptools", "wheel"] @@ -69,7 +72,7 @@ include = ["spd*"] [tool.ruff] line-length = 100 fix = true -extend-exclude = ["spd/app/frontend"] +extend-exclude = ["spd/app/frontend", ".circuits-ref"] [tool.ruff.lint] ignore = [ diff --git a/scripts/export_circuit_json.py b/scripts/export_circuit_json.py new file mode 100644 index 
000000000..61eeabdac --- /dev/null +++ b/scripts/export_circuit_json.py @@ -0,0 +1,199 @@ +"""Export an OptimizedPromptAttributionResult to JSON for the circuit graph renderer. + +Usage from a notebook / script: + + from scripts.export_circuit_json import export_circuit_json + + export_circuit_json( + circuit=circuit, + token_ids=tokens.tolist(), + token_strings=tok.get_spans(tokens.tolist()), + output_path=Path("king_circuit.json"), + interp=interp, # optional InterpRepo + graph_interp=gi_repo, # optional GraphInterpRepo + ) +""" + +import json +from pathlib import Path + +from spd.app.backend.compute import Edge, OptimizedPromptAttributionResult +from spd.autointerp.repo import InterpRepo +from spd.graph_interp.repo import GraphInterpRepo + + +def _parse_node_key(key: str) -> tuple[str, int, int]: + """Parse 'h.3.attn.o_proj:11:361' -> ('h.3.attn.o_proj', 11, 361).""" + parts = key.split(":") + layer = ":".join(parts[:-2]) + seq = int(parts[-2]) + cidx = int(parts[-1]) + return layer, seq, cidx + + +def _concrete_to_canonical(layer: str) -> str: + """Map concrete model path to canonical address used by graph layout. 
+ + h.3.attn.o_proj -> 3.attn.o + h.3.attn.v_proj -> 3.attn.v + h.3.attn.q_proj -> 3.attn.q + h.3.attn.k_proj -> 3.attn.k + h.3.mlp.c_fc -> 3.mlp.up + h.3.mlp.down_proj -> 3.mlp.down + h.3.mlp.gate_proj -> 3.glu.gate + h.3.mlp.up_proj -> 3.glu.up + wte -> embed + lm_head -> output + """ + if layer == "wte": + return "embed" + if layer == "lm_head": + return "output" + + PROJ_MAP = { + "q_proj": "q", + "k_proj": "k", + "v_proj": "v", + "o_proj": "o", + "c_fc": "up", + "down_proj": "down", + "gate_proj": "gate", + "up_proj": "up", + } + + # h.{block}.{sublayer}.{proj_name} + parts = layer.split(".") + assert len(parts) == 4, f"Expected h.N.sublayer.proj, got {layer!r}" + block_idx = parts[1] + sublayer = parts[2] # "attn" or "mlp" + proj_name = parts[3] + + canonical_proj = PROJ_MAP.get(proj_name) + assert canonical_proj is not None, f"Unknown projection: {proj_name!r} in {layer!r}" + + # Determine canonical sublayer + if sublayer == "attn": + canonical_sublayer = "attn" + elif sublayer == "mlp": + canonical_sublayer = "glu" if proj_name in ("gate_proj", "up_proj") else "mlp" + else: + canonical_sublayer = sublayer + + return f"{block_idx}.{canonical_sublayer}.{canonical_proj}" + + +def _edge_to_dict(e: Edge) -> dict[str, object]: + return { + "source": str(e.source), + "target": str(e.target), + "attribution": e.strength, + "is_cross_seq": e.is_cross_seq, + } + + +def _get_label( + layer: str, + cidx: int, + interp: InterpRepo | None, + graph_interp: GraphInterpRepo | None, +) -> str | None: + component_key = f"{layer}:{cidx}" + if graph_interp is not None: + unified = graph_interp.get_unified_label(component_key) + if unified is not None: + return unified.label + output = graph_interp.get_output_label(component_key) + if output is not None: + return output.label + if interp is not None: + ir = interp.get_interpretation(component_key) + if ir is not None: + return ir.label + return None + + +def export_circuit_json( + circuit: OptimizedPromptAttributionResult, + 
token_ids: list[int], + token_strings: list[str], + output_path: Path, + min_ci: float = 0.3, + min_edge_attr: float = 0.1, + interp: InterpRepo | None = None, + graph_interp: GraphInterpRepo | None = None, +) -> None: + """Export circuit to JSON for the standalone HTML renderer. + + Args: + circuit: The optimized circuit result from EditableModel.optimize_circuit. + token_ids: Raw token IDs for each position. + token_strings: Display strings for each position (from tok.get_spans). + output_path: Where to write the JSON. + min_ci: Minimum CI to include a node. + min_edge_attr: Minimum |attribution| to include an edge. + interp: Optional InterpRepo for autointerp labels. + graph_interp: Optional GraphInterpRepo for graph-interp labels. + """ + # Build tokens list + tokens = [ + {"pos": i, "id": tid, "string": tstr} + for i, (tid, tstr) in enumerate(zip(token_ids, token_strings, strict=True)) + ] + + # Build nodes from ci_vals, filtering by min_ci + nodes = [] + node_keys_kept = set() + for key, ci in circuit.node_ci_vals.items(): + if ci < min_ci: + continue + layer, seq, cidx = _parse_node_key(key) + canonical = _concrete_to_canonical(layer) + act = circuit.node_subcomp_acts.get(key, 0.0) + + label = _get_label(layer, cidx, interp, graph_interp) + + nodes.append( + { + "key": key, + "graph_key": f"{layer}:{cidx}", + "layer": layer, + "canonical": canonical, + "seq": seq, + "cidx": cidx, + "ci": round(ci, 4), + "activation": round(act, 4), + "token": token_strings[seq] if seq < len(token_strings) else "?", + "label": label, + } + ) + node_keys_kept.add(key) + + # Build edges, filtering by min_edge_attr and requiring both endpoints in kept nodes + edges = [] + for e in circuit.edges: + if abs(e.strength) < min_edge_attr: + continue + src_key = str(e.source) + tgt_key = str(e.target) + if src_key not in node_keys_kept or tgt_key not in node_keys_kept: + continue + edges.append(_edge_to_dict(e)) + + metrics: dict[str, float] = {"l0_total": circuit.metrics.l0_total} + 
@dataclass
class RunDiagnostic:
    """Pre-migration report for one legacy run, as filled in by `diagnose_run`."""

    run_id: str
    # Presence flags for the legacy artefacts.
    harvest_db: bool = False  # activation_contexts/harvest.db exists
    token_stats: bool = False  # correlations/token_stats.pt exists
    correlations: bool = False  # correlations/component_correlations.pt exists
    # Sizes read out of the legacy artefacts.
    n_components: int = 0  # row count of the legacy `components` table
    n_tokens: int = 0  # "n_tokens" entry of token_stats.pt
    ci_threshold: str = ""  # legacy config value, or a "(default)" placeholder
    intruder_scores: int = 0  # line count of the largest intruder *.jsonl
    # Autointerp side.
    interp_db: bool = False  # top-level interp.db exists
    n_interpretations: int = 0  # row count of `interpretations`
    interp_score_types: list[str] = field(default_factory=list)  # distinct score_type values
    # Outcome.
    already_migrated: bool = False  # target h-<timestamp> directory already exists
    problems: list[str] = field(default_factory=list)  # human-readable blockers

    @property
    def ready(self) -> bool:
        """True when the run has no recorded problems and has not been migrated yet."""
        return not (self.problems or self.already_migrated)
def _create_harvest_schema(new_conn: sqlite3.Connection) -> None:
    """Create the new-layout tables (components, config, scores) if they are absent."""
    new_conn.executescript("""\
        CREATE TABLE IF NOT EXISTS components (
            component_key TEXT PRIMARY KEY,
            layer TEXT NOT NULL,
            component_idx INTEGER NOT NULL,
            firing_density REAL NOT NULL,
            mean_activations TEXT NOT NULL,
            activation_examples TEXT NOT NULL,
            input_token_pmi TEXT NOT NULL,
            output_token_pmi TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS config (
            key TEXT PRIMARY KEY,
            value TEXT NOT NULL
        );
        CREATE TABLE IF NOT EXISTS scores (
            component_key TEXT NOT NULL,
            score_type TEXT NOT NULL,
            score REAL NOT NULL,
            details TEXT NOT NULL,
            PRIMARY KEY (component_key, score_type)
        );
    """)


def _migrate_harvest_config(old_conn: sqlite3.Connection, new_conn: sqlite3.Connection) -> float:
    """Copy config rows, renaming 'ci_threshold' to 'activation_threshold'.

    Returns the legacy threshold as a float (0.0 when the legacy DB has none).
    """
    threshold_row = old_conn.execute(
        "SELECT value FROM config WHERE key = 'ci_threshold'"
    ).fetchone()
    activation_threshold = float(threshold_row["value"]) if threshold_row else 0.0

    for row in old_conn.execute("SELECT key, value FROM config").fetchall():
        key = row["key"]
        if key == "ci_threshold":
            key = "activation_threshold"
        new_conn.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", (key, row["value"]))
    return activation_threshold


def _migrate_harvest_components(
    old_conn: sqlite3.Connection,
    new_conn: sqlite3.Connection,
    firing_density_map: dict[str, float],
    activation_threshold: float,
) -> int:
    """Rewrite every legacy component row into the new schema; return the row count."""
    n = 0
    for row in old_conn.execute("SELECT * FROM components").fetchall():
        new_examples = []
        for ex in orjson.loads(row["activation_examples"]):
            ci_values = ex["ci_values"]
            new_examples.append(
                {
                    "token_ids": ex["token_ids"],
                    # A position "fires" when its CI is strictly above the legacy threshold.
                    "firings": [v > activation_threshold for v in ci_values],
                    "activations": {
                        "causal_importance": ci_values,
                        "component_activation": ex["component_acts"],
                    },
                }
            )

        component_key = row["component_key"]
        assert component_key in firing_density_map, f"{component_key} missing from token_stats.pt"
        mean_activations = {"causal_importance": row["mean_ci"]}

        new_conn.execute(
            "INSERT OR REPLACE INTO components VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            (
                component_key,
                row["layer"],
                row["component_idx"],
                firing_density_map[component_key],
                orjson.dumps(mean_activations).decode(),
                orjson.dumps(new_examples).decode(),
                row["input_token_pmi"],
                row["output_token_pmi"],
            ),
        )
        n += 1
    return n


def _migrate_intruder_scores(old_db_path: Path, new_conn: sqlite3.Connection) -> None:
    """Load the largest legacy intruder JSONL (if any) into the new `scores` table."""
    old_intruder_dir = old_db_path.parent.parent / "eval" / "intruder"
    if not old_intruder_dir.exists():
        return
    # Use the largest file (most complete run)
    jsonl_files = sorted(old_intruder_dir.glob("*.jsonl"), key=lambda f: f.stat().st_size)
    if not jsonl_files:
        return

    intruder_file = jsonl_files[-1]
    n_scores = 0
    with open(intruder_file, "rb") as f:
        for line in f:
            record = orjson.loads(line)
            new_conn.execute(
                "INSERT OR REPLACE INTO scores VALUES (?, ?, ?, ?)",
                (
                    record["component_key"],
                    "intruder",
                    record["score"],
                    orjson.dumps(record.get("trials", [])).decode(),
                ),
            )
            n_scores += 1
    new_conn.commit()
    print(f" Migrated {n_scores} intruder scores from {intruder_file.name}")


def migrate_harvest_db(old_db_path: Path, new_db_path: Path, token_stats_path: Path) -> int:
    """Copy harvest DB, transforming schema from legacy to new format.

    Old schema: mean_ci REAL, activation_examples with {token_ids, ci_values, component_acts}
    New schema: firing_density REAL, mean_activations TEXT,
                activation_examples with {token_ids, firings, activations}

    Args:
        old_db_path: Legacy harvest.db; opened read-only (immutable URI).
        new_db_path: Destination DB; created if missing.
        token_stats_path: token_stats.pt providing per-component firing counts.

    Returns:
        Number of component rows written to the new DB.

    Raises:
        AssertionError: if token_stats.pt is missing or lacks a component present
            in the legacy DB.

    Both connections are closed on every exit path, including failures part-way
    through the migration.
    """
    old_conn = sqlite3.connect(f"file:{old_db_path}?immutable=1", uri=True)
    try:
        old_conn.row_factory = sqlite3.Row

        new_conn = sqlite3.connect(str(new_db_path))
        try:
            new_conn.execute("PRAGMA journal_mode=WAL")
            _create_harvest_schema(new_conn)

            # Load firing densities from token_stats.pt
            assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}"
            ts = TokenStatsStorage.load(token_stats_path)
            firing_density_map: dict[str, float] = {
                key: ts.firing_counts[i].item() / ts.n_tokens
                for i, key in enumerate(ts.component_keys)
            }

            activation_threshold = _migrate_harvest_config(old_conn, new_conn)
            n = _migrate_harvest_components(
                old_conn, new_conn, firing_density_map, activation_threshold
            )
            new_conn.commit()

            # Migrate intruder scores from JSONL into harvest DB scores table
            _migrate_intruder_scores(old_db_path, new_conn)
            return n
        finally:
            new_conn.close()
    finally:
        old_conn.close()
def migrate_run(run_id: str, timestamp: str = "20260218_000000") -> None:
    """Migrate a single run's harvest + autointerp data to the new layout.

    Args:
        run_id: Name of the run directory under SPD_OUT_DIR/harvest (and autointerp).
        timestamp: Suffix used for the new `h-<timestamp>` / `a-<timestamp>` sub-runs.

    Raises:
        AssertionError: if the legacy harvest DB or token_stats.pt is missing, or a
            target sub-run directory already exists.

    All preconditions are checked before anything is created, and directories
    created by this call are removed again if the migration fails. A leftover
    `h-<timestamp>` directory would otherwise make `diagnose_run` report the run
    as already migrated and it would be skipped on every later attempt. Legacy
    data is never modified or removed.
    """
    harvest_root = SPD_OUT_DIR / "harvest" / run_id
    autointerp_root = SPD_OUT_DIR / "autointerp" / run_id

    print(f"Migrating {run_id}...")

    old_harvest_db = harvest_root / "activation_contexts" / "harvest.db"
    old_corr_dir = harvest_root / "correlations"
    token_stats_path = old_corr_dir / "token_stats.pt"
    new_subrun = harvest_root / f"h-{timestamp}"

    old_interp_db = autointerp_root / "interp.db"
    new_autointerp_subrun = autointerp_root / f"a-{timestamp}"
    has_interp_db = old_interp_db.exists()

    # Validate everything up front so a failed check never leaves a half-made target.
    assert old_harvest_db.exists(), f"No legacy harvest DB at {old_harvest_db}"
    assert not new_subrun.exists(), f"Target already exists: {new_subrun}"
    assert token_stats_path.exists(), f"No token_stats.pt at {token_stats_path}"
    if has_interp_db:
        assert not new_autointerp_subrun.exists(), f"Target already exists: {new_autointerp_subrun}"

    created: list[Path] = []
    try:
        # --- Harvest ---
        new_subrun.mkdir(parents=True)
        created.append(new_subrun)

        print(f" Harvest DB: {old_harvest_db} -> {new_subrun / 'harvest.db'}")
        n_components = migrate_harvest_db(
            old_harvest_db, new_subrun / "harvest.db", token_stats_path
        )
        print(f" Migrated {n_components} components")

        # Copy .pt files
        for pt_file in ["component_correlations.pt", "token_stats.pt"]:
            src = old_corr_dir / pt_file
            if src.exists():
                shutil.copy2(src, new_subrun / pt_file)
                print(f" Copied {pt_file}")

        # --- Autointerp ---
        if has_interp_db:
            new_autointerp_subrun.mkdir(parents=True)
            created.append(new_autointerp_subrun)

            print(f" Interp DB: {old_interp_db} -> {new_autointerp_subrun / 'interp.db'}")
            n_interps = migrate_autointerp_db(old_interp_db, new_autointerp_subrun / "interp.db")
            print(f" Migrated {n_interps} interpretations (detection + fuzzing scores preserved)")
        else:
            print(" No top-level interp.db found, skipping autointerp")
    except BaseException:
        # Includes KeyboardInterrupt: only remove the fresh targets made above.
        for directory in created:
            shutil.rmtree(directory, ignore_errors=True)
        raise

    print("Done!")
def should_insert_space(left: str, right: str) -> bool:
    """Decide whether a single space belongs between two adjacent inline fragments.

    Both fragments are compared with surrounding whitespace ignored. A space is
    inserted only when the left fragment ends with an ASCII letter, digit, ']' or
    ')' and the right fragment starts with an ASCII letter, digit, '[' or '('.
    It is never inserted after an opening bracket, slash, hyphen or quote, nor
    before closing punctuation, a hyphen or a quote.

    Args:
        left: Text already emitted.
        right: Text about to be appended.

    Returns:
        True if a space should be inserted between `left` and `right`.
    """
    # Stripping first also covers empty and whitespace-only fragments, so no
    # separate whitespace checks are needed further down.
    left = left.rstrip()
    right = right.lstrip()
    if not left or not right:
        return False
    if left.endswith(("/", "(", "[", "{", "-", "“", '"', "'")):
        return False
    if right.startswith((".", ",", ":", ";", "!", "?", ")", "]", "}", "-", "”", '"', "'")):
        return False
    return bool(
        re.match(r"[A-Za-z0-9\]\)]", left[-1]) and re.match(r"[A-Za-z0-9\[\(]", right[0])
    )
def download_asset(self, src: str) -> str:
    """Resolve an asset reference to the path or URL to use in the Markdown.

    Args:
        src: The raw `src` attribute; resolved against `self.base_url`.

    Returns:
        "" for an empty `src`; the URL unchanged for `data:` URLs; the absolute
        remote URL when `self.download_assets` is false; otherwise the local
        path as produced by `self._relative_asset_path`.

    Raises:
        Whatever `self.session.get(...)` / `raise_for_status()` raise on a
        failed download.

    Results are memoised per resolved URL in `self.asset_map`. A file already on
    disk under the URL's basename is reused without downloading, but only if no
    other URL in this run has been mapped to it. Previously two different URLs
    sharing a basename (e.g. a/fig.png and b/fig.png) both resolved to the first
    file, so the second image was silently wrong.
    """
    if not src:
        return ""
    remote_url = urljoin(self.base_url, src)
    if remote_url in self.asset_map:
        return self.asset_map[remote_url]
    if remote_url.startswith("data:"):
        return remote_url

    local_rel = remote_url
    if self.download_assets:
        self.assets_dir.mkdir(parents=True, exist_ok=True)
        preferred = Path(urlparse(remote_url).path).name
        cached = self.assets_dir / preferred if preferred and "." in preferred else None
        # Local paths already handed out to some other URL during this run.
        claimed = set(self.asset_map.values())
        if (
            cached is not None
            and cached.exists()
            and self._relative_asset_path(cached) not in claimed
        ):
            local_path = cached
        else:
            local_path = self.assets_dir / self._unique_asset_filename(remote_url)
            response = self.session.get(remote_url, timeout=self.timeout)
            response.raise_for_status()
            local_path.write_bytes(response.content)
        local_rel = self._relative_asset_path(local_path)

    self.asset_map[remote_url] = local_rel
    return local_rel
def render_table(self, table: Tag) -> list[str]:
    """Render an HTML table as a Markdown pipe table.

    The first row found becomes the header row, followed by a `---` separator
    row; shorter rows are padded with empty cells to the widest row. Rows
    without any th/td cell are ignored.

    Literal '|' characters in cell text are escaped as '\\|' (unless already
    escaped); an unescaped pipe would be read as a column delimiter and shift
    every following cell.

    Returns:
        The table lines followed by one blank line, or [] if there are no rows.
    """
    rows: list[list[str]] = []
    for tr in table.find_all("tr"):
        row = [
            re.sub(
                r"(?<!\\)\|",
                r"\\|",
                normalize_whitespace(self.render_children_inline(cell.children)),
            )
            for cell in tr.find_all(["th", "td"])
        ]
        if row:
            rows.append(row)
    if not rows:
        return []

    width = max(len(row) for row in rows)
    lines: list[str] = []
    for index, row in enumerate(rows):
        padded = row + [""] * (width - len(row))
        lines.append("| " + " | ".join(padded) + " |")
        if index == 0:
            # Separator goes directly under the header row.
            lines.append("| " + " | ".join(["---"] * width) + " |")
    lines.append("")
    return lines
normalize_whitespace(self.render_children_inline(caption.children)) + if caption_text: + lines.append(f"_Figure: {caption_text}_") + if lines: + lines.append("") + return lines + + def render_block(self, node: Tag | NavigableString) -> list[str]: + if isinstance(node, NavigableString): + text = normalize_whitespace(str(node)) + return [text, ""] if text else [] + if not isinstance(node, Tag): + return [] + name = node.name.lower() + if name in {"style", "script"}: + return [] + if name in {"h1", "h2", "h3", "h4", "h5", "h6"}: + level = int(name[1]) + text = normalize_whitespace(self.render_children_inline(node.children)) + if not text: + return [] + anchor = node.get("id") + anchor_suffix = f" " if anchor else "" + return [f"{'#' * level} {text}{anchor_suffix}", ""] + if name == "p": + text = normalize_whitespace(self.render_children_inline(node.children)) + return [text, ""] if text else [] + if name in {"ul", "ol"}: + return self.render_list(node) + if name == "figure": + return self.render_figure(node) + if name == "table": + return self.render_table(node) + if name == "hr": + return ["---", ""] + if name == "br": + return [""] + if name == "d-contents": + return [] + if name in {"div", "section", "nav", "d-appendix", "d-article"}: + lines: list[str] = [] + for child in node.children: + lines.extend(self.render_block(child)) + return lines + text = normalize_whitespace(self.render_children_inline(node.children)) + return [text, ""] if text else [] + + def export(self, include_appendix: bool = True) -> tuple[str, int]: + html = self.fetch_html() + soup = BeautifulSoup(html, "html.parser") + title = normalize_whitespace(soup.title.string if soup.title else "") or "Untitled" + article = soup.find("d-article") + if article is None: + raise RuntimeError("Could not find in the page") + lines: list[str] = [ + f"# {title}", + "", + f"Source: {self.base_url}", + "", + "> Auto-generated by scripts/parse_transformer_circuits_post.py", + "", + ] + for child in 
article.children: + lines.extend(self.render_block(child)) + if include_appendix: + appendix = soup.find("d-appendix") + if appendix is not None: + appendix_lines: list[str] = [] + for child in appendix.children: + appendix_lines.extend(self.render_block(child)) + appendix_content = [line for line in appendix_lines if line.strip()] + if appendix_content: + lines.extend(["## Appendix", ""]) + lines.extend(appendix_lines) + # Strip trailing whitespace and collapse excessive blank lines. + cleaned: list[str] = [] + blank_run = 0 + for line in lines: + stripped = line.rstrip() + if not stripped: + blank_run += 1 + if blank_run <= 1: + cleaned.append("") + else: + blank_run = 0 + cleaned.append(stripped) + markdown = "\n".join(cleaned).strip() + "\n" + self.output_md.parent.mkdir(parents=True, exist_ok=True) + self.output_md.write_text(markdown, encoding="utf-8") + return markdown, len(self.asset_map) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--url", default=DEFAULT_URL, help="Post URL to parse") + parser.add_argument( + "--output-md", + default="papers/biology_source/biology.md", + help="Path for generated markdown", + ) + parser.add_argument( + "--assets-dir", + default="papers/biology_source/assets", + help="Directory for downloaded assets", + ) + parser.add_argument( + "--skip-appendix", + action="store_true", + help="Do not include d-appendix content", + ) + parser.add_argument( + "--skip-download-assets", + action="store_true", + help="Keep remote asset links instead of downloading files", + ) + parser.add_argument("--timeout", type=float, default=30.0, help="HTTP timeout in seconds") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + output_md = Path(args.output_md).resolve() + assets_dir = Path(args.assets_dir).resolve() + exporter = Exporter( + base_url=args.url, + output_md=output_md, + assets_dir=assets_dir, + download_assets=not 
args.skip_download_assets, + timeout=args.timeout, + ) + _, asset_count = exporter.export(include_appendix=not args.skip_appendix) + print(f"Wrote markdown: {output_md}") + if args.skip_download_assets: + print("Assets were not downloaded (--skip-download-assets set).") + else: + print(f"Downloaded/linked assets: {asset_count}") + print(f"Assets dir: {assets_dir}") + + +if __name__ == "__main__": + main() diff --git a/scripts/render_circuit_html.py b/scripts/render_circuit_html.py new file mode 100644 index 000000000..677c9cb57 --- /dev/null +++ b/scripts/render_circuit_html.py @@ -0,0 +1,69 @@ +"""Render a circuit JSON file as a self-contained HTML page. + +This is a thin wrapper that copies the circuit.html template and embeds the JSON +data inline, so the result is a single self-contained HTML file (no separate data file needed). + +Usage: + python scripts/render_circuit_html.py data/king_circuit.json -o circuit_standalone.html + +Or from Python: + from scripts.render_circuit_html import render_circuit_html + render_circuit_html(Path("data/king_circuit.json"), Path("circuit_standalone.html")) +""" + +import argparse +import json +from pathlib import Path + + +def render_circuit_html(json_path: Path, output_path: Path, title: str = "Circuit Graph") -> None: + """Render a circuit JSON as a self-contained HTML file. + + Takes the interactive circuit.html template and replaces the fetch() call + with inline data, producing a single portable HTML file. 
+ """ + with open(json_path) as f: + data = json.load(f) + + # Read the template + template_path = Path(__file__).parent.parent / "scripts" / "_circuit_template.html" + if not template_path.exists(): + # Fallback: read from www + from spd.settings import SPD_OUT_DIR + + template_path = SPD_OUT_DIR / "www" / "pile-editing" / "circuit.html" + + assert template_path.exists(), f"Template not found: {template_path}" + template = template_path.read_text() + + # Replace the fetch with inline data + inline_js = f"data = {json.dumps(data)}; init();" + template = template.replace( + "fetch(DATA_URL)\n" + " .then(r => { if (!r.ok) throw new Error(`HTTP ${r.status}`); return r.json(); })\n" + " .then(d => { data = d; init(); })\n" + " .catch(e => { document.getElementById('stats').textContent = `Error: ${e.message}`; });", + inline_js, + ) + + # Update title + template = template.replace( + 'Circuit Graph — King → "he"', f"{title}" + ) + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, "w") as f: + f.write(template) + + print(f"Wrote self-contained HTML to {output_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Render circuit JSON as self-contained HTML") + parser.add_argument("json_path", type=Path, help="Path to circuit JSON file") + parser.add_argument( + "-o", "--output", type=Path, default=Path("circuit.html"), help="Output HTML path" + ) + parser.add_argument("--title", default="Circuit Graph", help="Page title") + args = parser.parse_args() + render_circuit_html(args.json_path, args.output, args.title) diff --git a/scripts/test_abs_grad_trick.py b/scripts/test_abs_grad_trick.py new file mode 100644 index 000000000..b3d15d296 --- /dev/null +++ b/scripts/test_abs_grad_trick.py @@ -0,0 +1,150 @@ +"""Verify that ∂|y|/∂x = sign(y) · ∂y/∂x for a scalar y, even through nonlinearities. 
+ +The chain rule: ∂|y|/∂x = (d|y|/dy) · (∂y/∂x) = sign(y) · ∂y/∂x + +This holds regardless of what nonlinear computation sits between x and y, +because ∂y/∂x already accounts for all intermediate nonlinearities. +The sign(y) factor is just the outermost link in the chain. +""" + +import torch +from torch import nn + + +def test_simple_linear(): + """Linear: y = Wx, trivial case.""" + x = torch.randn(5, requires_grad=True) + W = torch.randn(3, 5) + y_vec = W @ x + y = y_vec[1] # pick one scalar + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-7), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print(f" linear: max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓") + + +def test_deep_nonlinear(): + """Deep net with ReLU, tanh, and GELU — representative of a transformer.""" + torch.manual_seed(42) + net = nn.Sequential( + nn.Linear(8, 16), + nn.ReLU(), + nn.Linear(16, 16), + nn.Tanh(), + nn.Linear(16, 16), + nn.GELU(), + nn.Linear(16, 4), + ) + x = torch.randn(8, requires_grad=True) + y_vec = net(x) + y = y_vec[2] # scalar output + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print( + f" deep nonlinear (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" + ) + + +def test_negative_target(): + """Ensure it works when y < 0 (sign flips the gradient).""" + torch.manual_seed(99) + net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1)) + # Find an input that gives negative output + for _seed in range(200): + x = torch.randn(4, requires_grad=True) + y = net(x).squeeze() + if y.item() < -0.1: + break + assert y.item() < 0, "Couldn't find negative 
output" + + grad = torch.autograd.grad(y, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(y.abs(), x, retain_graph=True)[0] + grad_trick = y.sign() * grad + + assert torch.allclose(grad_abs, grad_trick, atol=1e-6), ( + f"FAIL: {(grad_abs - grad_trick).abs().max()}" + ) + print( + f" negative target (y={y.item():.4f}): max diff = {(grad_abs - grad_trick).abs().max():.2e} ✓" + ) + + +def test_multiple_inputs(): + """Multiple input tensors (mirrors the app's in_post_detaches list).""" + torch.manual_seed(7) + x1 = torch.randn(3, 4, requires_grad=True) + x2 = torch.randn(3, 4, requires_grad=True) + + # Nonlinear function of both inputs + h = torch.relu(x1) + torch.tanh(x2) + y = (h @ torch.randn(4, 1)).sum() # scalar + + grads = torch.autograd.grad(y, [x1, x2], retain_graph=True) + grads_abs = torch.autograd.grad(y.abs(), [x1, x2], retain_graph=True) + + for i, (g, g_abs) in enumerate(zip(grads, grads_abs, strict=True)): + g_trick = y.sign() * g + assert torch.allclose(g_abs, g_trick, atol=1e-6), ( + f"FAIL input {i}: {(g_abs - g_trick).abs().max()}" + ) + + print(f" multiple inputs (y={y.item():.4f}): all match ✓") + + +def test_sum_of_abs_DOES_NOT_work(): + """Show that the trick FAILS for sum-of-abs (dataset attributions case). + + ∂(Σ|y_i|)/∂x ≠ sign(Σy_i) · ∂(Σy_i)/∂x + because each y_i has a different sign. 
+ """ + torch.manual_seed(42) + x = torch.randn(4, requires_grad=True) + W = torch.randn(3, 4) + y_vec = W @ x # [3] + + target_signed = y_vec.sum() + target_abs = y_vec.abs().sum() + + grad_signed = torch.autograd.grad(target_signed, x, retain_graph=True)[0] + grad_abs = torch.autograd.grad(target_abs, x, retain_graph=True)[0] + + # The WRONG trick: use sign of the sum + grad_wrong = target_signed.sign() * grad_signed + + # The correct per-element version + grad_correct = sum( + y_vec[i].sign() * torch.autograd.grad(y_vec[i], x, retain_graph=True)[0] + for i in range(len(y_vec)) + ) + + wrong_diff = (grad_abs - grad_wrong).abs().max() + correct_diff = (grad_abs - grad_correct).abs().max() + print( + f" sum-of-abs: wrong trick diff = {wrong_diff:.4f}, correct per-element diff = {correct_diff:.2e}" + ) + assert wrong_diff > 0.01, "Expected the wrong trick to fail for sum-of-abs" + assert correct_diff < 1e-6, "Per-element version should match" + print(" → confirms: trick works for scalar y, NOT for sum-of-abs ✓") + + +if __name__ == "__main__": + print("Testing ∂|y|/∂x = sign(y) · ∂y/∂x for scalar y:\n") + test_simple_linear() + test_deep_nonlinear() + test_negative_target() + test_multiple_inputs() + print() + print("Testing that the trick does NOT work for sum-of-abs:\n") + test_sum_of_abs_DOES_NOT_work() + print("\nAll tests passed.") diff --git a/spd/app/CLAUDE.md b/spd/app/CLAUDE.md index 0f95d54a8..42f745de6 100644 --- a/spd/app/CLAUDE.md +++ b/spd/app/CLAUDE.md @@ -4,17 +4,18 @@ Web-based visualization and analysis tool for exploring neural network component - **Backend**: Python FastAPI (`backend/`) - **Frontend**: Svelte 5 + TypeScript (`frontend/`) -- **Database**: SQLite at `.data/app/prompt_attr.db` (relative to repo root) +- **Database**: SQLite at `SPD_OUT_DIR/app/prompt_attr.db` (shared across team via NFS) - **TODOs**: See `TODO.md` for open work items ## Project Context This is a **rapidly iterated research tool**. 
Key implications: -- **Please do not code for backwards compatibility**: Schema changes don't need migrations, expect state can be deleted, etc. -- **Database is disposable**: Delete `.data/app/prompt_attr.db` if schema changes break things +- **Please do not code for backwards compatibility**: Schema changes don't need migrations +- **Database is shared state**: Lives at `SPD_OUT_DIR/app/prompt_attr.db` on NFS, accessible by multiple backends. Do not delete without checking with the team. Uses DELETE journal mode (NFS-safe) with `fcntl.flock` write locking for concurrent access - **Prefer simplicity**: Avoid over-engineering for hypothetical future needs - **Fail loud and fast**: The users are a small team of highly technical people. Errors are good. We want to know immediately if something is wrong. No soft failing, assert, assert, assert +- **Token display**: Always ship token strings rendered server-side via `AppTokenizer`, never raw token IDs. For embed/output layers, `component_idx` is a token ID — resolve it to a display string in the backend response. ## Running the App @@ -34,14 +35,14 @@ This launches both backend (FastAPI/uvicorn) and frontend (Vite) dev servers. 
backend/ ├── server.py # FastAPI app, CORS, routers ├── state.py # Singleton StateManager + HarvestRepo (lazy-loaded harvest data) -├── compute.py # Core attribution computation +├── compute.py # Core attribution computation + intervention evaluation ├── app_tokenizer.py # AppTokenizer: wraps HF tokenizers for display/encoding ├── (topology lives at spd/topology.py — TransformerTopology) ├── schemas.py # Pydantic API models ├── dependencies.py # FastAPI dependency injection ├── utils.py # Logging/timing utilities ├── database.py # SQLite interface -├── optim_cis.py # Sparse CI optimization +├── optim_cis.py # Sparse CI optimization, loss configs, PGD └── routers/ ├── runs.py # Load W&B runs + GET /api/model_info ├── graphs.py # Compute attribution graphs @@ -50,6 +51,9 @@ backend/ ├── intervention.py # Selective component activation ├── correlations.py # Component correlations + token stats + interpretations ├── clusters.py # Component clustering + ├── dataset_search.py # SimpleStories dataset search + ├── agents.py # Various useful endpoints that AI agents should look at when helping + ├── mcp.py # MCP (Model Context Protocol) endpoint for Claude Code ├── dataset_search.py # Dataset search (reads dataset from run config) └── agents.py # Various useful endpoints that AI agents should look at when helping ``` @@ -90,7 +94,7 @@ frontend/src/ ├── ActivationContextsPagedTable.svelte ├── DatasetSearchTab.svelte # Dataset search UI ├── DatasetSearchResults.svelte - ├── ClusterPathInput.svelte # Cluster path selector + ├── ClusterPathInput.svelte # Cluster path selector (dropdown populated from registry.ts) ├── ComponentProbeInput.svelte # Component probe UI ├── TokenHighlights.svelte # Token highlighting ├── prompt-attr/ @@ -154,8 +158,13 @@ Edge(source: Node, target: Node, strength: float, is_cross_seq: bool) # strength = gradient * activation # is_cross_seq = True for k/v → o_proj (attention pattern) -PromptAttributionResult(edges: list[Edge], output_probs: Tensor[seq, 
vocab], node_ci_vals: dict[str, float]) -# node_ci_vals maps "layer:seq:c_idx" → CI value +PromptAttributionResult(edges, ci_masked_out_logits, target_out_logits, node_ci_vals, node_subcomp_acts) + +TokenPrediction(token, token_id, prob, logit, target_prob, target_logit) + +InterventionResult(input_tokens, ci, stochastic, adversarial, ci_loss, stochastic_loss, adversarial_loss) +# ci/stochastic/adversarial are list[list[TokenPrediction]] (per-position top-k) +# losses are evaluated using the graph's implied loss context ``` ### Frontend Types (`promptAttributionsTypes.ts`) @@ -211,13 +220,31 @@ Finds sparse CI mask that: - Minimizes L0 (active component count) - Uses importance minimality + CE loss (or KL loss) -### Intervention Forward +### Interventions (`compute.py → compute_intervention`) + +A single unified function evaluates a node selection under three masking regimes: + +- **CI**: mask = selection (binary on/off) +- **Stochastic**: mask = selection + (1-selection) × Uniform(0,1) +- **Adversarial**: PGD optimizes alive-but-unselected components to maximize loss; non-alive get Uniform(0,1) + +Returns `InterventionResult` with top-k `TokenPrediction`s per position for each regime, plus per-regime loss values. + +**Loss context**: Every graph has an implied loss that interventions evaluate against: -`compute_intervention_forward()`: +- **Standard/manual graphs** → `MeanKLLossConfig` (mean KL divergence from target across all positions) +- **Optimized graphs** → the graph's optimization loss (CE for a specific token at a position, or KL at a position) -1. Build component masks (all zeros) -2. Set mask=1.0 for selected nodes -3. Forward pass → top-k predictions per position +This loss is used for two things: (1) what PGD maximizes during adversarial evaluation, and (2) the `ci_loss`/`stochastic_loss`/`adversarial_loss` metrics reported in `InterventionResult`. 
+ +**Alive masks**: `compute_intervention` recomputes the model's natural CI (one forward pass + `calc_causal_importances`) and binarizes at 0 to get alive masks. This ensures the alive set is always the full model's CI — not the graph's potentially sparse optimized CI. PGD can only manipulate alive-but-unselected components. + +**Training PGD vs Eval PGD**: The PGD settings in the graph optimization config (`adv_pgd_n_steps`, +`adv_pgd_step_size`) are a _training_ regularizer — they make CI optimization robust. The PGD in +`compute_intervention` is an _eval_ metric — it measures worst-case performance for a given node +selection. Eval PGD defaults are in `compute.py` (`DEFAULT_EVAL_PGD_CONFIG`). + +**Base intervention run**: Created automatically during graph computation. Uses all interventable nodes with CI > 0. Persisted as an `intervention_run` so predictions are available synchronously. --- @@ -245,9 +272,14 @@ POST /api/graphs ### Intervention ``` -POST /api/intervention {text, nodes: ["h.0.attn.q_proj:3:5", ...]} - → compute_intervention_forward() - ← InterventionResponse with top-k predictions +POST /api/intervention/run {graph_id, selected_nodes, top_k, adv_pgd} + → compute_intervention(active_nodes, graph_alive_masks, loss_config) + ← InterventionRunSummary {id, selected_nodes, result: InterventionResult} + +InterventionResult = { + input_tokens, ci, stochastic, adversarial, // TokenPrediction[][] per regime + ci_loss, stochastic_loss, adversarial_loss // loss under each regime +} ``` ### Component Correlations & Interpretations @@ -281,14 +313,14 @@ GET /api/dataset/results?page=1&page_size=20 ## Database Schema -Located at `.data/app/prompt_attr.db`. Delete this file if schema changes cause issues. +Located at `SPD_OUT_DIR/app/prompt_attr.db` (shared via NFS). Uses DELETE journal mode with `fcntl.flock` write locking for safe concurrent access from multiple backends. 
-| Table | Key | Purpose | -| ------------------ | ---------------------------------- | ------------------------------------------------- | -| `runs` | `wandb_path` | W&B run references | -| `prompts` | `(run_id, context_length)` | Token sequences | -| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + output probs + node CI values | -| `intervention_runs`| `graph_id` | Saved intervention results | +| Table | Key | Purpose | +| ------------------- | ---------------------------------- | -------------------------------------------------------- | +| `runs` | `wandb_path` | W&B run references | +| `prompts` | `(run_id, context_length)` | Token sequences | +| `graphs` | `(prompt_id, optimization_params)` | Attribution edges + CI/target logits + node CI values | +| `intervention_runs` | `graph_id` | Saved `InterventionResult` JSON (single `result` column) | Note: Activation contexts, correlations, token stats, and interpretations are loaded from pre-harvested data at `SPD_OUT_DIR/{harvest,autointerp}/` (see `spd/harvest/` and `spd/autointerp/`). diff --git a/spd/app/TODO.md b/spd/app/TODO.md index a3c7eb1aa..21a8ba1fd 100644 --- a/spd/app/TODO.md +++ b/spd/app/TODO.md @@ -1,3 +1,172 @@ -# App TODOs +# App Backend Review & Action Items -- Audit SQLite access pragma stuff — `immutable=1` in `HarvestDB` causes "database disk image is malformed" errors when the app reads a harvest DB mid-write (WAL not yet checkpointed). Investigate whether to check for WAL file existence, use normal locking mode, or add another safeguard. See `spd/harvest/db.py:79`. +Review date: 2026-03-04. Scope: `spd/app/backend/` — all Python files. + +Context: the app is a **researcher-first local tool** (frontend + backend launched together, opened in browser). Errors should be loud, silent failures absent, the prompt DB is deletable short-term state, no backwards compatibility needed. + +## Overview + +The backend is ~6,500 lines across 18 Python files. 
The core architecture (FastAPI + SQLite + singleton state + SSE streaming) is sound. The main concerns are: a few real bugs, accumulated dead code, some silent failures that violate the "loud errors" principle, and a few design seams where complexity hides. + +### File size inventory + +| File | Lines | Risk | +|---|---|---| +| `routers/mcp.py` | 1637 | High — mixed concerns, largest file | +| `routers/graphs.py` | 1036 | Medium — streaming complexity | +| `compute.py` | 920 | Low — core algorithm, well-structured | +| `database.py` | 827 | Medium — manual serialization | +| `optim_cis.py` | 504 | Low | +| `routers/dataset_search.py` | 473 | Medium — hardcoded dataset names | +| `routers/correlations.py` | 386 | Low | +| `routers/graph_interp.py` | 373 | Low | +| `routers/investigations.py` | 317 | Low | +| `routers/pretrain_info.py` | 246 | Low | +| `server.py` | 212 | Low — clean | +| `routers/activation_contexts.py` | 207 | Low | +| `routers/runs.py` | 191 | Low | +| `routers/dataset_attributions.py` | 170 | Low | +| `routers/intervention.py` | 169 | Low — clean | +| `state.py` | 132 | Low — clean | +| `app_tokenizer.py` | 119 | Low | +| `routers/prompts.py` | 115 | Low | + +--- + +## Bugs + +### 1. `dataset_search.py:262` — KeyError on tokenized results + +`get_tokenized_results` accesses `result["story"]` but `search_dataset` stores results with key `"text"` (line 137: `results.append({"text": text, ...})`). This will crash with `KeyError: 'story'` whenever tokenized results are requested. + +**Fix:** Change line 262 from `result["story"]` to `result["text"]`. Also line 287: the metadata exclusion list references `"story"` — should be `"text"`. + +### 2. `dataset_search.py` — random endpoints hardcode SimpleStories + +`get_random_samples` (line 336) and `get_random_samples_with_loss` (line 415) both hardcode `load_dataset("lennart-finke/SimpleStories", ...)` and access `item_dict["story"]`. 
Since primary models are now Pile-trained, these endpoints are broken for current research. They also don't use `DepLoadedRun` to get the dataset name from the run config like `search_dataset` does. + +**Fix:** Make them take `DepLoadedRun`, read `task_config.dataset_name` and `task_config.column_name`, and use those instead of hardcoded values. Or, if the random endpoints aren't used with Pile models, consider deleting them. + +--- + +## Dead code to delete + +### 3. `ForkedInterventionRunRecord` + `forked_intervention_runs` table + +`database.py:117-125` defines `ForkedInterventionRunRecord`. Lines 256-265 create the `forked_intervention_runs` table. Lines 744-827 implement 3 CRUD methods (`save_forked_intervention_run`, `get_forked_intervention_runs`, `delete_forked_intervention_run`). No router references any of these — the fork endpoints were removed. Delete all of it. + +Files: `database.py` + +### 4. `optim_cis.py:500-504` — `get_out_dir()` never called + +Dead utility function that creates a local `out/` directory. Nothing references it. + +Files: `optim_cis.py` + +### 5. Unused schemas in `graphs.py:188-209` + +`ComponentStats`, `PromptSearchQuery`, and `PromptSearchResponse` are defined but no endpoint uses them. They appear to be leftovers from a removed prompt search feature. The `PromptPreview` in `graphs.py:114` also duplicates the one in `prompts.py:25`. + +Files: `routers/graphs.py` + +### 6. `spd/app/TODO.md` was empty + +(This file — now repurposed for this review.) + +--- + +## Design issues + +### 7. `OptimizationParams` mixes config inputs with computed outputs + +`database.py:69-82` — Fields like `imp_min_coeff`, `steps`, `pnorm`, `beta` are optimization *inputs*. Fields like `ci_masked_label_prob`, `stoch_masked_label_prob`, `adv_pgd_label_prob` are computed *outputs*. These metrics are mutated in-place after construction in `graphs.py:759-761`. + +This makes the object's contract unclear — is it immutable config or mutable state? 
+ +**Suggestion:** Either nest the metrics in a sub-model (`metrics: OptimMetrics | None`), or at minimum stop mutating after construction (compute the metrics before constructing `OptimizationParams`). + +### 8. `StoredGraph.id = -1` sentinel value + +`database.py:90` uses `-1` as "unsaved graph". If a graph is accidentally used before being saved, that `-1` leaks into API responses or DB queries. `id: int | None = None` is more honest and lets the type system catch misuse. + +### 9. GPU lock accessed two different ways + +- `graphs.py:603,844` — `stream_computation(work, manager._gpu_lock)` reaches into the private lock directly +- `intervention.py:86` — `with manager.gpu_lock():` uses the context manager + +The stream pattern is inherently different (hold lock across SSE generator lifetime), but accessing `_gpu_lock` directly breaks encapsulation. + +**Suggestion:** Add a `stream_with_gpu_lock(work)` method on `StateManager` that encapsulates the lock acquisition + SSE streaming pattern. Then `graphs.py` calls `manager.stream_with_gpu_lock(work)` instead of reaching into privates. + +### 10. `load_run` returns untyped dicts + +`runs.py:96,139` returns `{"status": "loaded", "run_id": ...}` and `{"status": "already_loaded", ...}`. No response model, so the frontend has no type-safe contract for this endpoint. + +**Fix:** Define a `LoadRunResponse(BaseModel)` with `status`, `run_id`, `wandb_path`. + +### 11. Edge truncation is invisible to the user + +`graphs.py:903` logs a warning when edges exceed `GLOBAL_EDGE_LIMIT = 50_000` and are truncated, but this info only goes to server logs. The researcher never sees it. + +**Fix:** Add `edges_truncated: bool` (or `total_edge_count: int`) to `GraphData` so the frontend can show a notice. + +### 12. Module-level `DEVICE = get_device()` in multiple files + +`graphs.py:266`, `intervention.py:48`, `dataset_search.py`, `prompts.py:18` all call `get_device()` at import time. 
Fine in practice but makes testing and non-GPU imports impossible. + +**Suggestion:** Move to a function call or lazily-evaluated property when/if this becomes a testing bottleneck. Low priority. + +### 13. `_GRAPH_INTERP_MOCK_MODE` cross-router import + +`runs.py:13` imports `MOCK_MODE` from `routers/graph_interp.py` and uses it in the status endpoint (line 174). The TODO comment says to remove it. This cross-router dependency for a mock flag should be cleaned up — the mock mode should either be a config flag on `StateManager` or deleted entirely. + +--- + +## Silent failure patterns (violate "loud errors" principle) + +### 14. `compute.py:79-86` — output node capping is silent + +`compute_layer_alive_info` caps output nodes to `MAX_OUTPUT_NODES_PER_POS = 15` per position without any logging or indication. If a researcher has >15 high-probability output tokens at a position, they silently lose some. + +At minimum, log when capping occurs. + +### 15. `correlations.py:291,302` — token stats returns `None` silently + +`get_component_token_stats` returns `None` when token stats haven't been harvested. This means the endpoint returns a `200 null` response, which the frontend has to special-case. An explicit 404 with a message is more honest. + +### 16. `correlations.py:112,260` — interpretations/intruder scores return `{}` silently + +`get_all_interpretations` and `get_intruder_scores` return empty dicts when data isn't available. This is defensible for bulk endpoints (the frontend can check emptiness), but it means the researcher has no way to distinguish "no interpretations exist" from "interpretations not yet generated." Consider logging or adding a `has_interpretations` flag to `LoadedRun`. + +Note: `LoadedRun.autointerp_available` already partially addresses this. But the endpoints themselves don't use it — they independently check `loaded.interp is None`. + +--- + +## Lower priority / nice-to-haves + +### 17. 
`extract_node_ci_vals` Python double loop + +`compute.py:640-648` iterates every `(seq_pos, component_idx)` pair in Python. For large models (39K components × 512 seq), this is a lot of Python overhead. Could be vectorized to only extract non-zero entries. + +### 18. `database.py` manual graph get-or-create race + +Lines 528-539: catches `IntegrityError` on manual graph save, then re-queries. There's a small race window between the failed insert and the re-query. Acceptable for a single-user local app but worth noting. + +### 19. `mcp.py` is 1637 lines + +The MCP router is the largest file, mixing tool definitions, implementation logic, and JSON-RPC handling. It has module-level global state (`_investigation_config`). This file would benefit from being split, but it's also likely to be rewritten when MCP tooling matures, so the ROI of refactoring now is debatable. + +--- + +## Suggested priority order for implementation + +1. Fix `result["story"]` KeyError (bug #1) — 2 min +2. Delete dead code (items #3-5) — 10 min +3. Fix random dataset endpoints or delete if unused (#2) — 15 min +4. Add `edges_truncated` to GraphData (#11) — 10 min +5. Type the `load_run` response (#10) — 5 min +6. Clean up `_GRAPH_INTERP_MOCK_MODE` (#13) — 5 min +7. Deduplicate `MAX_OUTPUT_NODES_PER_POS` (#5 partial) — 2 min +8. `StoredGraph.id` sentinel → `None` (#8) — 10 min +9. Split `OptimizationParams` (#7) — 20 min +10. 
GPU lock encapsulation (#9) — 15 min diff --git a/spd/app/backend/app_tokenizer.py b/spd/app/backend/app_tokenizer.py index 0d79cd9ba..acfa4d7eb 100644 --- a/spd/app/backend/app_tokenizer.py +++ b/spd/app/backend/app_tokenizer.py @@ -53,6 +53,12 @@ def vocab_size(self) -> int: assert isinstance(size, int) return size + @property + def eos_token_id(self) -> int: + eos = self._tok.eos_token_id + assert isinstance(eos, int) + return eos + def encode(self, text: str) -> list[int]: return self._tok.encode(text, add_special_tokens=False) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 8992e0e06..c99de5a02 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -12,12 +12,26 @@ import torch from jaxtyping import Bool, Float +from pydantic import BaseModel from torch import Tensor, nn from spd.app.backend.app_tokenizer import AppTokenizer -from spd.app.backend.optim_cis import OptimCIConfig, OptimizationMetrics, optimize_ci_values +from spd.app.backend.optim_cis import ( + AdvPGDConfig, + CELossConfig, + CISnapshotCallback, + LogitLossConfig, + LossConfig, + OptimCIConfig, + OptimizationMetrics, + compute_recon_loss, + optimize_ci_values, + optimize_ci_values_batched, + run_adv_pgd, +) from spd.configs import SamplingType from spd.log import logger +from spd.metrics.pgd_utils import interpolate_pgd_mask from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.topology import TransformerTopology @@ -115,6 +129,7 @@ class PromptAttributionResult: """Result of computing prompt attributions for a prompt.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -128,6 +143,7 
@@ class OptimizedPromptAttributionResult: """Result of computing prompt attributions with optimized CI values.""" edges: list[Edge] + edges_abs: list[Edge] # absolute-target variant: ∂|y|/∂x · x ci_masked_out_probs: Float[Tensor, "seq vocab"] # CI-masked (SPD model) softmax probabilities ci_masked_out_logits: Float[Tensor, "seq vocab"] # CI-masked (SPD model) raw logits target_out_probs: Float[Tensor, "seq vocab"] # Target model softmax probabilities @@ -135,7 +151,6 @@ class OptimizedPromptAttributionResult: node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val node_subcomp_acts: dict[str, float] # layer:seq:c_idx -> subcomponent activation (v_i^T @ a) metrics: OptimizationMetrics # Final loss metrics from optimization - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None # Adversarial PGD output logits ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) @@ -168,17 +183,22 @@ def _compute_edges_for_target( cache: dict[str, Tensor], loss_seq_pos: int, topology: TransformerTopology, -) -> list[Edge]: +) -> tuple[list[Edge], list[Edge]]: """Compute all edges flowing into a single target layer. For each alive (s_out, c_out) in the target layer, computes gradient-based - attribution strengths from all alive source components. + attribution strengths from all alive source components. Computes both signed + (∂y/∂x · x) and absolute-target (∂|y|/∂x · x) variants. Args: loss_seq_pos: Maximum sequence position to include (inclusive). Only compute edges for target positions <= loss_seq_pos. + + Returns: + (edges, edges_abs): Signed and absolute-target edge lists. 
""" edges: list[Edge] = [] + edges_abs: list[Edge] = [] out_pre_detach: Float[Tensor, "1 s C"] = cache[f"{target}_pre_detach"] in_post_detaches: list[Float[Tensor, "1 s C"]] = [ cache[f"{source}_post_detach"] for source in sources @@ -190,11 +210,19 @@ def _compute_edges_for_target( continue for c_out in s_out_alive_c: + target_val = out_pre_detach[0, s_out, c_out] grads = torch.autograd.grad( - outputs=out_pre_detach[0, s_out, c_out], + outputs=target_val, inputs=in_post_detaches, retain_graph=True, ) + # ∂|y|/∂x = sign(y) · ∂y/∂x — avoids a second backward pass. + # This works because target_val is a single scalar. In dataset_attributions/ + # harvester.py, the target is sum(|y_i|) over batch+seq — there each y_i has a + # different sign, so you can't factor out one scalar. The issue isn't the chain + # rule (sign·grad is always valid per-element), it's that abs breaks the + # grad(sum)=sum(grad) trick that makes the batch reduction a single backward pass. + target_sign = target_val.sign() with torch.no_grad(): canonical_target = topology.target_to_canon(target) for source, source_info, grad, in_post_detach in zip( @@ -203,27 +231,35 @@ def _compute_edges_for_target( canonical_source = topology.target_to_canon(source) is_cross_seq = topology.is_cross_seq_pair(canonical_source, canonical_target) weighted: Float[Tensor, "s C"] = (grad * in_post_detach)[0] + weighted_abs: Float[Tensor, "s C"] = weighted * target_sign if canonical_source == "embed": weighted = weighted.sum(dim=1, keepdim=True) + weighted_abs = weighted_abs.sum(dim=1, keepdim=True) s_in_range = range(s_out + 1) if is_cross_seq else [s_out] for s_in in s_in_range: for c_in in source_info.alive_c_idxs: if not source_info.alive_mask[s_in, c_in]: continue + src = Node(layer=canonical_source, seq_pos=s_in, component_idx=c_in) + tgt = Node(layer=canonical_target, seq_pos=s_out, component_idx=c_out) edges.append( Edge( - source=Node( - layer=canonical_source, seq_pos=s_in, component_idx=c_in - ), - 
target=Node( - layer=canonical_target, seq_pos=s_out, component_idx=c_out - ), + source=src, + target=tgt, strength=weighted[s_in, c_in].item(), is_cross_seq=is_cross_seq, ) ) - return edges + edges_abs.append( + Edge( + source=src, + target=tgt, + strength=weighted_abs[s_in, c_in].item(), + is_cross_seq=is_cross_seq, + ) + ) + return edges, edges_abs def compute_edges_from_ci( @@ -330,12 +366,13 @@ def compute_edges_from_ci( # Compute edges for each target layer t0 = time.perf_counter() edges: list[Edge] = [] + edges_abs: list[Edge] = [] total_source_layers = sum(len(sources) for sources in sources_by_target.values()) progress_count = 0 for target, sources in sources_by_target.items(): t_target = time.perf_counter() - target_edges = _compute_edges_for_target( + target_edges, target_edges_abs = _compute_edges_for_target( target=target, sources=sources, target_info=alive_info[target], @@ -345,6 +382,7 @@ def compute_edges_from_ci( topology=topology, ) edges.extend(target_edges) + edges_abs.extend(target_edges_abs) canonical_target = topology.target_to_canon(target) logger.info( f"[perf] {canonical_target}: {time.perf_counter() - t_target:.2f}s, " @@ -375,6 +413,7 @@ def compute_edges_from_ci( return PromptAttributionResult( edges=edges, + edges_abs=edges_abs, ci_masked_out_probs=ci_masked_out_probs[0, : loss_seq_pos + 1], ci_masked_out_logits=ci_masked_logits[0, : loss_seq_pos + 1], target_out_probs=target_out_probs[0, : loss_seq_pos + 1], @@ -508,6 +547,7 @@ def compute_prompt_attributions_optimized( output_prob_threshold: float, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizedPromptAttributionResult: """Compute prompt attributions using optimized sparse CI values. 
@@ -528,6 +568,7 @@ def compute_prompt_attributions_optimized( config=optim_config, device=device, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_outputs = optim_result.params.create_ci_outputs(model, device) @@ -557,13 +598,9 @@ def compute_prompt_attributions_optimized( loss_seq_pos=loss_seq_pos, ) - # Slice adversarial logits to match the loss_seq_pos range - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None - if optim_result.adv_pgd_out_logits is not None: - adv_pgd_out_logits = optim_result.adv_pgd_out_logits[: loss_seq_pos + 1] - return OptimizedPromptAttributionResult( edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_probs=result.ci_masked_out_probs, ci_masked_out_logits=result.ci_masked_out_logits, target_out_probs=result.target_out_probs, @@ -571,10 +608,81 @@ def compute_prompt_attributions_optimized( node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, metrics=optim_result.metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_prompt_attributions_optimized_batched( + model: ComponentModel, + topology: TransformerTopology, + tokens: Float[Tensor, "1 seq"], + sources_by_target: dict[str, list[str]], + configs: list[OptimCIConfig], + output_prob_threshold: float, + device: str, + on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizedPromptAttributionResult]: + """Compute prompt attributions for multiple sparsity coefficients in one batched optimization.""" + with torch.no_grad(), bf16_autocast(): + target_logits = model(tokens) + target_out_probs = torch.softmax(target_logits, dim=-1) + + optim_results = optimize_ci_values_batched( + model=model, + tokens=tokens, + configs=configs, + device=device, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + if on_progress is not None: + on_progress(0, len(optim_results), "graph") + + with torch.no_grad(), bf16_autocast(): + pre_weight_acts = model(tokens, 
cache_type="input").cache + + loss_seq_pos = configs[0].loss_config.position + + results: list[OptimizedPromptAttributionResult] = [] + for i, optim_result in enumerate(optim_results): + ci_outputs = optim_result.params.create_ci_outputs(model, device) + + result = compute_edges_from_ci( + model=model, + topology=topology, + tokens=tokens, + ci_lower_leaky=ci_outputs.lower_leaky, + pre_weight_acts=pre_weight_acts, + sources_by_target=sources_by_target, + target_out_probs=target_out_probs, + target_out_logits=target_logits, + output_prob_threshold=output_prob_threshold, + device=device, + on_progress=on_progress, + loss_seq_pos=loss_seq_pos, + ) + + results.append( + OptimizedPromptAttributionResult( + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_probs=result.ci_masked_out_probs, + ci_masked_out_logits=result.ci_masked_out_logits, + target_out_probs=result.target_out_probs, + target_out_logits=result.target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + metrics=optim_result.metrics, + ) + ) + + if on_progress is not None: + on_progress(i + 1, len(optim_results), "graph") + + return results + + @dataclass class CIOnlyResult: """Result of computing CI values only (no attribution graph).""" @@ -666,94 +774,248 @@ def extract_node_subcomp_acts( return node_subcomp_acts -@dataclass -class InterventionResult: - """Result of intervention forward pass.""" +class TokenPrediction(BaseModel): + """A single token prediction with probability.""" + + token: str + token_id: int + prob: float + logit: float + target_prob: float + target_logit: float + + +class LabelPredictions(BaseModel): + """Prediction stats for the CE label token at the optimized position, per masking regime.""" + + position: int + ci: TokenPrediction + stochastic: TokenPrediction + adversarial: TokenPrediction + ablated: TokenPrediction | None + + +class InterventionResult(BaseModel): + """Unified result of an intervention evaluation under 
multiple masking regimes.""" input_tokens: list[str] - predictions_per_position: list[ - list[tuple[str, int, float, float, float, float]] - ] # [(token, id, spd_prob, logit, target_prob, target_logit)] + ci: list[list[TokenPrediction]] + stochastic: list[list[TokenPrediction]] + adversarial: list[list[TokenPrediction]] + ablated: list[list[TokenPrediction]] | None + ci_loss: float + stochastic_loss: float + adversarial_loss: float + ablated_loss: float | None + label: LabelPredictions | None + + +# Default eval PGD settings (distinct from optimization PGD which is a training regularizer) +DEFAULT_EVAL_PGD_CONFIG = AdvPGDConfig(n_steps=4, step_size=1.0, init="random") + + +def _extract_topk_predictions( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + top_k: int, +) -> list[list[TokenPrediction]]: + """Extract top-k token predictions per position, paired with target probs.""" + probs = torch.softmax(logits, dim=-1) + target_probs = torch.softmax(target_logits, dim=-1) + result: list[list[TokenPrediction]] = [] + for pos in range(probs.shape[1]): + top_vals, top_ids = torch.topk(probs[0, pos], top_k) + pos_preds: list[TokenPrediction] = [] + for p, tid_t in zip(top_vals, top_ids, strict=True): + tid = int(tid_t.item()) + pos_preds.append( + TokenPrediction( + token=tokenizer.get_tok_display(tid), + token_id=tid, + prob=float(p.item()), + logit=float(logits[0, pos, tid].item()), + target_prob=float(target_probs[0, pos, tid].item()), + target_logit=float(target_logits[0, pos, tid].item()), + ) + ) + result.append(pos_preds) + return result + + +def _extract_label_prediction( + logits: Float[Tensor, "1 seq vocab"], + target_logits: Float[Tensor, "1 seq vocab"], + tokenizer: AppTokenizer, + position: int, + label_token: int, +) -> TokenPrediction: + """Extract the prediction for a specific token at a specific position.""" + probs = torch.softmax(logits[0, position], dim=-1) + target_probs = 
torch.softmax(target_logits[0, position], dim=-1) + return TokenPrediction( + token=tokenizer.get_tok_display(label_token), + token_id=label_token, + prob=float(probs[label_token].item()), + logit=float(logits[0, position, label_token].item()), + target_prob=float(target_probs[label_token].item()), + target_logit=float(target_logits[0, position, label_token].item()), + ) -def compute_intervention_forward( +def compute_intervention( model: ComponentModel, tokens: Float[Tensor, "1 seq"], - active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] - top_k: int, + active_nodes: list[tuple[str, int, int]], + nodes_to_ablate: list[tuple[str, int, int]] | None, tokenizer: AppTokenizer, + adv_pgd_config: AdvPGDConfig, + loss_config: LossConfig, + sampling: SamplingType, + top_k: int, ) -> InterventionResult: - """Forward pass with only specified nodes active. + """Unified intervention evaluation: CI, stochastic, adversarial, and optionally ablated. Args: - model: ComponentModel to run intervention on. - tokens: Input tokens of shape [1, seq]. - active_nodes: List of (layer, seq_pos, component_idx) tuples specifying which nodes to activate. + active_nodes: (concrete_path, seq_pos, component_idx) tuples for selected nodes. + Used for CI, stochastic, and adversarial masking. + nodes_to_ablate: If provided, nodes to ablate in ablated (full target model minus these). + The frontend computes this as all_graph_nodes - selected_nodes. + If None, ablated is skipped. + loss_config: Loss for PGD adversary to maximize and for reporting metrics. + sampling: Sampling type for CI computation. top_k: Number of top predictions to return per position. - tokenizer: Tokenizer for decoding tokens. - - Returns: - InterventionResult with input tokens and top-k predictions per position. 
""" - seq_len = tokens.shape[1] device = tokens.device - # Build component masks: all zeros, then set 1s for active nodes - component_masks: dict[str, Float[Tensor, "1 seq C"]] = {} - for layer_name, C in model.module_to_c.items(): - component_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) + # Compute natural CI alive masks (the model's own binarized CI, independent of graph) + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=sampling, + detach_inputs=False, + ) + alive_masks: dict[str, Bool[Tensor, "1 seq C"]] = { + k: v > 0 for k, v in ci_outputs.lower_leaky.items() + } + # Build binary CI masks from active nodes (selected = 1, rest = 0) + ci_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name, C in model.module_to_c.items(): + ci_masks[layer_name] = torch.zeros(1, seq_len, C, device=device) for layer, seq_pos, c_idx in active_nodes: - assert layer in component_masks, f"Layer {layer} not in model" - assert 0 <= seq_pos < seq_len, f"seq_pos {seq_pos} out of bounds [0, {seq_len})" - assert 0 <= c_idx < model.module_to_c[layer], ( - f"component_idx {c_idx} out of bounds [0, {model.module_to_c[layer]})" + ci_masks[layer][0, seq_pos, c_idx] = 1.0 + assert alive_masks[layer][0, seq_pos, c_idx], ( + f"Selected node {layer}:{seq_pos}:{c_idx} is not alive (CI=0)" ) - component_masks[layer][0, seq_pos, c_idx] = 1.0 - mask_infos = make_mask_infos(component_masks, routing_masks="all") + with torch.no_grad(), bf16_autocast(): + # Target forward (unmasked) + target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) + + # CI forward (binary mask) + ci_mask_infos = make_mask_infos(ci_masks, routing_masks="all") + ci_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=ci_mask_infos) + + # Stochastic forward: ci + (1-ci) * uniform + stoch_masks = { + layer: ci_masks[layer] + (1 - 
ci_masks[layer]) * torch.rand_like(ci_masks[layer]) + for layer in ci_masks + } + stoch_mask_infos = make_mask_infos(stoch_masks, routing_masks="all") + stoch_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=stoch_mask_infos) + + # Target-sans forward (only if nodes_to_ablate provided) + ts_logits: Float[Tensor, "1 seq vocab"] | None = None + if nodes_to_ablate is not None: + ts_masks: dict[str, Float[Tensor, "1 seq C"]] = {} + for layer_name in ci_masks: + ts_masks[layer_name] = torch.ones_like(ci_masks[layer_name]) + for layer, seq_pos, c_idx in nodes_to_ablate: + ts_masks[layer][0, seq_pos, c_idx] = 0.0 + weight_deltas = model.calc_weight_deltas() + ts_wd = { + k: (v, torch.ones(tokens.shape, device=device)) for k, v in weight_deltas.items() + } + ts_mask_infos = make_mask_infos( + ts_masks, routing_masks="all", weight_deltas_and_masks=ts_wd + ) + ts_logits = model(tokens, mask_infos=ts_mask_infos) + # Adversarial: PGD optimizes alive-but-unselected components + adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=ci_masks, + alive_masks=alive_masks, + adv_config=adv_pgd_config, + target_out=target_logits, + loss_config=loss_config, + ) + # Non-alive positions get uniform random fill + adv_masks = interpolate_pgd_mask(ci_masks, adv_sources) + with torch.no_grad(): + for layer in adv_masks: + non_alive = ~alive_masks[layer] + adv_masks[layer][non_alive] = torch.rand(int(non_alive.sum().item()), device=device) with torch.no_grad(), bf16_autocast(): - # SPD model forward pass (with component masks) - spd_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=mask_infos) - spd_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(spd_logits, dim=-1) + adv_mask_infos = make_mask_infos(adv_masks, routing_masks="all") + adv_logits: Float[Tensor, "1 seq vocab"] = model(tokens, mask_infos=adv_mask_infos) + + # Extract predictions and loss metrics + device_str = str(device) + with torch.no_grad(): + ci_preds = 
_extract_topk_predictions(ci_logits, target_logits, tokenizer, top_k) + stoch_preds = _extract_topk_predictions(stoch_logits, target_logits, tokenizer, top_k) + adv_preds = _extract_topk_predictions(adv_logits, target_logits, tokenizer, top_k) + + ci_loss = float( + compute_recon_loss(ci_logits, loss_config, target_logits, device_str).item() + ) + stoch_loss = float( + compute_recon_loss(stoch_logits, loss_config, target_logits, device_str).item() + ) + adv_loss = float( + compute_recon_loss(adv_logits, loss_config, target_logits, device_str).item() + ) - # Target model forward pass (no masks) - target_logits: Float[Tensor, "1 seq vocab"] = model(tokens) - target_out_probs: Float[Tensor, "1 seq vocab"] = torch.softmax(target_logits, dim=-1) - - # Get top-k predictions per position (based on SPD model's top-k) - predictions_per_position: list[list[tuple[str, int, float, float, float, float]]] = [] - for pos in range(seq_len): - pos_spd_probs = spd_probs[0, pos] - pos_spd_logits = spd_logits[0, pos] - pos_target_out_probs = target_out_probs[0, pos] - pos_target_logits = target_logits[0, pos] - top_probs, top_ids = torch.topk(pos_spd_probs, top_k) - - pos_predictions: list[tuple[str, int, float, float, float, float]] = [] - for spd_prob, token_id in zip(top_probs, top_ids, strict=True): - tid = int(token_id.item()) - token_str = tokenizer.get_tok_display(tid) - target_prob = float(pos_target_out_probs[tid].item()) - target_logit = float(pos_target_logits[tid].item()) - pos_predictions.append( - ( - token_str, - tid, - float(spd_prob.item()), - float(pos_spd_logits[tid].item()), - target_prob, - target_logit, - ) + ts_preds: list[list[TokenPrediction]] | None = None + ts_loss: float | None = None + if ts_logits is not None: + ts_preds = _extract_topk_predictions(ts_logits, target_logits, tokenizer, top_k) + ts_loss = float( + compute_recon_loss(ts_logits, loss_config, target_logits, device_str).item() ) - predictions_per_position.append(pos_predictions) - # Decode 
input tokens + label: LabelPredictions | None = None + if isinstance(loss_config, CELossConfig | LogitLossConfig): + pos, tid = loss_config.position, loss_config.label_token + ts_label = ( + _extract_label_prediction(ts_logits, target_logits, tokenizer, pos, tid) + if ts_logits is not None + else None + ) + label = LabelPredictions( + position=pos, + ci=_extract_label_prediction(ci_logits, target_logits, tokenizer, pos, tid), + stochastic=_extract_label_prediction(stoch_logits, target_logits, tokenizer, pos, tid), + adversarial=_extract_label_prediction(adv_logits, target_logits, tokenizer, pos, tid), + ablated=ts_label, + ) + input_tokens = tokenizer.get_spans([int(t.item()) for t in tokens[0]]) return InterventionResult( input_tokens=input_tokens, - predictions_per_position=predictions_per_position, + ci=ci_preds, + stochastic=stoch_preds, + adversarial=adv_preds, + ablated=ts_preds, + ci_loss=ci_loss, + stochastic_loss=stoch_loss, + adversarial_loss=adv_loss, + ablated_loss=ts_loss, + label=label, ) diff --git a/spd/app/backend/database.py b/spd/app/backend/database.py index 6b2b09552..61be17ad5 100644 --- a/spd/app/backend/database.py +++ b/spd/app/backend/database.py @@ -6,10 +6,13 @@ Interpretations are stored separately at SPD_OUT_DIR/autointerp//. 
""" +import fcntl import hashlib import io import json +import os import sqlite3 +from contextlib import contextmanager from pathlib import Path from typing import Literal @@ -17,14 +20,35 @@ from pydantic import BaseModel from spd.app.backend.compute import Edge, Node -from spd.app.backend.optim_cis import CELossConfig, KLLossConfig, LossConfig, MaskType -from spd.settings import REPO_ROOT +from spd.app.backend.optim_cis import ( + CELossConfig, + KLLossConfig, + LogitLossConfig, + MaskType, + PositionalLossConfig, +) +from spd.settings import SPD_OUT_DIR GraphType = Literal["standard", "optimized", "manual"] -# Persistent data directories -_APP_DATA_DIR = REPO_ROOT / ".data" / "app" -DEFAULT_DB_PATH = _APP_DATA_DIR / "prompt_attr.db" +_DEFAULT_DB_PATH = SPD_OUT_DIR / "app" / "prompt_attr.db" + + +def get_default_db_path() -> Path: + """Get the default database path. + + Checks env vars in order: + 1. SPD_INVESTIGATION_DIR - investigation mode, db at dir/app.db + 2. SPD_APP_DB_PATH - explicit override + 3. 
Default: SPD_OUT_DIR/app/prompt_attr.db + """ + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + return Path(investigation_dir) / "app.db" + env_path = os.environ.get("SPD_APP_DB_PATH") + if env_path: + return Path(env_path) + return _DEFAULT_DB_PATH class Run(BaseModel): @@ -43,6 +67,11 @@ class PromptRecord(BaseModel): is_custom: bool = False +class PgdConfig(BaseModel): + n_steps: int + step_size: float + + class OptimizationParams(BaseModel): """Optimization parameters that affect graph computation.""" @@ -51,9 +80,12 @@ class OptimizationParams(BaseModel): pnorm: float beta: float mask_type: MaskType - loss: LossConfig - adv_pgd_n_steps: int | None = None - adv_pgd_step_size: float | None = None + loss: PositionalLossConfig + pgd: PgdConfig | None = None + # Computed metrics (persisted for display on reload) + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + adv_pgd_label_prob: float | None = None class StoredGraph(BaseModel): @@ -66,9 +98,11 @@ class StoredGraph(BaseModel): # Core graph data (all types) edges: list[Edge] + edges_abs: list[Edge] | None = ( + None # absolute-target variant (∂|y|/∂x · x), None for old graphs + ) ci_masked_out_logits: torch.Tensor # [seq, vocab] target_out_logits: torch.Tensor # [seq, vocab] - adv_pgd_out_logits: torch.Tensor | None = None # [seq, vocab] adversarial PGD logits node_ci_vals: dict[str, float] # layer:seq:c_idx -> ci_val (required for all graphs) node_subcomp_acts: dict[str, float] = {} # layer:seq:c_idx -> subcomp act (v_i^T @ a) @@ -85,17 +119,17 @@ class InterventionRunRecord(BaseModel): id: int graph_id: int selected_nodes: list[str] # node keys that were selected - result_json: str # JSON-encoded InterventionResponse + result_json: str # JSON-encoded InterventionResult created_at: str class ForkedInterventionRunRecord(BaseModel): - """A forked intervention run with modified tokens.""" + """A forked intervention run with modified 
tokens (currently unused).""" id: int intervention_run_id: int token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result_json: str # JSON-encoded InterventionResponse + result_json: str created_at: str @@ -111,7 +145,8 @@ class PromptAttrDB: """ def __init__(self, db_path: Path | None = None, check_same_thread: bool = True): - self.db_path = db_path or DEFAULT_DB_PATH + self.db_path = db_path or get_default_db_path() + self._lock_path = self.db_path.with_suffix(".db.lock") self._check_same_thread = check_same_thread self._conn: sqlite3.Connection | None = None @@ -135,6 +170,16 @@ def __enter__(self) -> "PromptAttrDB": def __exit__(self, *args: object) -> None: self.close() + @contextmanager + def _write_lock(self): + """Acquire an exclusive file lock for write operations (NFS-safe).""" + with open(self._lock_path, "w") as lock_fd: + try: + fcntl.flock(lock_fd, fcntl.LOCK_EX) + yield + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + # ------------------------------------------------------------------------- # Schema initialization # ------------------------------------------------------------------------- @@ -142,7 +187,7 @@ def __exit__(self, *args: object) -> None: def init_schema(self) -> None: """Initialize the database schema. 
Safe to call multiple times.""" conn = self._get_conn() - conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA journal_mode=DELETE") conn.execute("PRAGMA foreign_keys=ON") conn.executescript(""" CREATE TABLE IF NOT EXISTS runs ( @@ -178,12 +223,19 @@ def init_schema(self) -> None: adv_pgd_n_steps INTEGER, adv_pgd_step_size REAL, + -- Optimization metrics (NULL for non-optimized graphs) + ci_masked_label_prob REAL, + stoch_masked_label_prob REAL, + adv_pgd_label_prob REAL, + -- Manual graph params (NULL for non-manual graphs) included_nodes TEXT, -- JSON array of node keys in this graph included_nodes_hash TEXT, -- SHA256 hash of sorted JSON for uniqueness -- The actual graph data (JSON) edges_data TEXT NOT NULL, + -- Absolute-target edges (∂|y|/∂x · x), NULL for old graphs + edges_data_abs TEXT, -- Node CI values: "layer:seq:c_idx" -> ci_val (required for all graphs) node_ci_vals TEXT NOT NULL, -- Node subcomponent activations: "layer:seq:c_idx" -> v_i^T @ a @@ -216,7 +268,7 @@ def init_schema(self) -> None: id INTEGER PRIMARY KEY AUTOINCREMENT, graph_id INTEGER NOT NULL REFERENCES graphs(id), selected_nodes TEXT NOT NULL, -- JSON array of node keys - result TEXT NOT NULL, -- JSON InterventionResponse + result TEXT NOT NULL, -- JSON InterventionResult created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ); @@ -234,6 +286,12 @@ def init_schema(self) -> None: CREATE INDEX IF NOT EXISTS idx_forked_intervention_runs_parent ON forked_intervention_runs(intervention_run_id); """) + + # Migration: add edges_data_abs column if missing (backwards compat with existing DBs) + columns = {row[1] for row in conn.execute("PRAGMA table_info(graphs)").fetchall()} + if "edges_data_abs" not in columns: + conn.execute("ALTER TABLE graphs ADD COLUMN edges_data_abs TEXT") + conn.commit() # ------------------------------------------------------------------------- @@ -242,15 +300,16 @@ def init_schema(self) -> None: def create_run(self, wandb_path: str) -> int: """Create a new run. 
Returns the run ID.""" - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO runs (wandb_path) VALUES (?)", - (wandb_path,), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO runs (wandb_path) VALUES (?)", + (wandb_path,), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_run_by_wandb_path(self, wandb_path: str) -> Run | None: """Get a run by its wandb path.""" @@ -308,19 +367,20 @@ def add_custom_prompt( Returns: The prompt ID (existing or newly created). """ - existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) - if existing_id is not None: - return existing_id + with self._write_lock(): + existing_id = self.find_prompt_by_token_ids(run_id, token_ids, context_length) + if existing_id is not None: + return existing_id - conn = self._get_conn() - cursor = conn.execute( - "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", - (run_id, json.dumps(token_ids), context_length), - ) - prompt_id = cursor.lastrowid - assert prompt_id is not None - conn.commit() - return prompt_id + conn = self._get_conn() + cursor = conn.execute( + "INSERT INTO prompts (run_id, token_ids, context_length, is_custom) VALUES (?, ?, ?, 1)", + (run_id, json.dumps(token_ids), context_length), + ) + prompt_id = cursor.lastrowid + assert prompt_id is not None + conn.commit() + return prompt_id def get_prompt(self, prompt_id: int) -> PromptRecord | None: """Get a prompt by ID.""" @@ -384,24 +444,26 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: "component_idx": n.component_idx, } - edges_json = json.dumps( - [ - { - "source": _node_to_dict(e.source), - "target": _node_to_dict(e.target), - "strength": e.strength, - "is_cross_seq": e.is_cross_seq, - } - for e in graph.edges - ] - ) + def _edges_to_json(edges: list[Edge]) -> str: + 
return json.dumps( + [ + { + "source": _node_to_dict(e.source), + "target": _node_to_dict(e.target), + "strength": e.strength, + "is_cross_seq": e.is_cross_seq, + } + for e in edges + ] + ) + + edges_json = _edges_to_json(graph.edges) + edges_abs_json = _edges_to_json(graph.edges_abs) if graph.edges_abs is not None else None buf = io.BytesIO() logits_dict: dict[str, torch.Tensor] = { "ci_masked": graph.ci_masked_out_logits, "target": graph.target_out_logits, } - if graph.adv_pgd_out_logits is not None: - logits_dict["adv_pgd"] = graph.adv_pgd_out_logits torch.save(logits_dict, buf) output_logits_blob = buf.getvalue() node_ci_vals_json = json.dumps(graph.node_ci_vals) @@ -417,6 +479,9 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: loss_config_hash: str | None = None adv_pgd_n_steps = None adv_pgd_step_size = None + ci_masked_label_prob = None + stoch_masked_label_prob = None + adv_pgd_label_prob = None if graph.optimization_params: imp_min_coeff = graph.optimization_params.imp_min_coeff @@ -426,8 +491,15 @@ def _node_to_dict(n: Node) -> dict[str, str | int]: mask_type = graph.optimization_params.mask_type loss_config_json = graph.optimization_params.loss.model_dump_json() loss_config_hash = hashlib.sha256(loss_config_json.encode()).hexdigest() - adv_pgd_n_steps = graph.optimization_params.adv_pgd_n_steps - adv_pgd_step_size = graph.optimization_params.adv_pgd_step_size + adv_pgd_n_steps = ( + graph.optimization_params.pgd.n_steps if graph.optimization_params.pgd else None + ) + adv_pgd_step_size = ( + graph.optimization_params.pgd.step_size if graph.optimization_params.pgd else None + ) + ci_masked_label_prob = graph.optimization_params.ci_masked_label_prob + stoch_masked_label_prob = graph.optimization_params.stoch_masked_label_prob + adv_pgd_label_prob = graph.optimization_params.adv_pgd_label_prob # Extract manual-specific values (NULL for non-manual graphs) # Sort included_nodes and compute hash for reliable uniqueness @@ -437,64 +509,70 @@ def 
_node_to_dict(n: Node) -> dict[str, str | int]: included_nodes_json = json.dumps(sorted(graph.included_nodes)) included_nodes_hash = hashlib.sha256(included_nodes_json.encode()).hexdigest() - try: - cursor = conn.execute( - """INSERT INTO graphs - (prompt_id, graph_type, - imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, loss_config_hash, - adv_pgd_n_steps, adv_pgd_step_size, - included_nodes, included_nodes_hash, - edges_data, output_logits, node_ci_vals, node_subcomp_acts) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", - ( - prompt_id, - graph.graph_type, - imp_min_coeff, - steps, - pnorm, - beta, - mask_type, - loss_config_json, - loss_config_hash, - adv_pgd_n_steps, - adv_pgd_step_size, - included_nodes_json, - included_nodes_hash, - edges_json, - output_logits_blob, - node_ci_vals_json, - node_subcomp_acts_json, - ), - ) - conn.commit() - graph_id = cursor.lastrowid - assert graph_id is not None - return graph_id - except sqlite3.IntegrityError as e: - match graph.graph_type: - case "standard": - raise ValueError( - f"Standard graph already exists for prompt_id={prompt_id}. " - "Use get_graphs() to retrieve existing graph or delete it first." - ) from e - case "optimized": - raise ValueError( - f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." - ) from e - case "manual": - # Get-or-create semantics: return existing graph ID - conn.rollback() - row = conn.execute( - """SELECT id FROM graphs - WHERE prompt_id = ? 
AND graph_type = 'manual' - AND included_nodes_hash = ?""", - (prompt_id, included_nodes_hash), - ).fetchone() - if row: - return row["id"] - # Should not happen if constraint triggered - raise ValueError("A manual graph with the same nodes already exists.") from e + with self._write_lock(): + try: + cursor = conn.execute( + """INSERT INTO graphs + (prompt_id, graph_type, + imp_min_coeff, steps, pnorm, beta, mask_type, + loss_config, loss_config_hash, + adv_pgd_n_steps, adv_pgd_step_size, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob, + included_nodes, included_nodes_hash, + edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", + ( + prompt_id, + graph.graph_type, + imp_min_coeff, + steps, + pnorm, + beta, + mask_type, + loss_config_json, + loss_config_hash, + adv_pgd_n_steps, + adv_pgd_step_size, + ci_masked_label_prob, + stoch_masked_label_prob, + adv_pgd_label_prob, + included_nodes_json, + included_nodes_hash, + edges_json, + edges_abs_json, + output_logits_blob, + node_ci_vals_json, + node_subcomp_acts_json, + ), + ) + conn.commit() + graph_id = cursor.lastrowid + assert graph_id is not None + return graph_id + except sqlite3.IntegrityError as e: + match graph.graph_type: + case "standard": + raise ValueError( + f"Standard graph already exists for prompt_id={prompt_id}. " + "Use get_graphs() to retrieve existing graph or delete it first." + ) from e + case "optimized": + raise ValueError( + f"Optimized graph with same parameters already exists for prompt_id={prompt_id}." + ) from e + case "manual": + conn.rollback() + row = conn.execute( + """SELECT id FROM graphs + WHERE prompt_id = ? AND graph_type = 'manual' + AND included_nodes_hash = ?""", + (prompt_id, included_nodes_hash), + ).fetchone() + if row: + return row["id"] + raise ValueError( + "A manual graph with the same nodes already exists." 
+ ) from e def _row_to_stored_graph(self, row: sqlite3.Row) -> StoredGraph: """Convert a database row to a StoredGraph.""" @@ -506,19 +584,22 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: component_idx=int(d["component_idx"]), ) - edges = [ - Edge( - source=_node_from_dict(e["source"]), - target=_node_from_dict(e["target"]), - strength=float(e["strength"]), - is_cross_seq=bool(e["is_cross_seq"]), - ) - for e in json.loads(row["edges_data"]) - ] + def _parse_edges(data: str) -> list[Edge]: + return [ + Edge( + source=_node_from_dict(e["source"]), + target=_node_from_dict(e["target"]), + strength=float(e["strength"]), + is_cross_seq=bool(e["is_cross_seq"]), + ) + for e in json.loads(data) + ] + + edges = _parse_edges(row["edges_data"]) + edges_abs = _parse_edges(row["edges_data_abs"]) if row["edges_data_abs"] else None logits_data = torch.load(io.BytesIO(row["output_logits"]), weights_only=True) ci_masked_out_logits: torch.Tensor = logits_data["ci_masked"] target_out_logits: torch.Tensor = logits_data["target"] - adv_pgd_out_logits: torch.Tensor | None = logits_data.get("adv_pgd") node_ci_vals: dict[str, float] = json.loads(row["node_ci_vals"]) node_subcomp_acts: dict[str, float] = json.loads(row["node_subcomp_acts"] or "{}") @@ -526,12 +607,18 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: if row["graph_type"] == "optimized": loss_config_data = json.loads(row["loss_config"]) loss_type = loss_config_data["type"] - assert loss_type in ("ce", "kl"), f"Unknown loss type: {loss_type}" - loss_config: LossConfig - if loss_type == "ce": - loss_config = CELossConfig(**loss_config_data) - else: - loss_config = KLLossConfig(**loss_config_data) + assert loss_type in ("ce", "kl", "logit"), f"Unknown loss type: {loss_type}" + loss_config: PositionalLossConfig + match loss_type: + case "ce": + loss_config = CELossConfig(**loss_config_data) + case "kl": + loss_config = KLLossConfig(**loss_config_data) + case "logit": + loss_config = 
LogitLossConfig(**loss_config_data) + pgd = None + if row["adv_pgd_n_steps"] is not None: + pgd = PgdConfig(n_steps=row["adv_pgd_n_steps"], step_size=row["adv_pgd_step_size"]) opt_params = OptimizationParams( imp_min_coeff=row["imp_min_coeff"], steps=row["steps"], @@ -539,8 +626,10 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: beta=row["beta"], mask_type=row["mask_type"], loss=loss_config, - adv_pgd_n_steps=row["adv_pgd_n_steps"], - adv_pgd_step_size=row["adv_pgd_step_size"], + pgd=pgd, + ci_masked_label_prob=row["ci_masked_label_prob"], + stoch_masked_label_prob=row["stoch_masked_label_prob"], + adv_pgd_label_prob=row["adv_pgd_label_prob"], ) # Parse manual-specific fields @@ -552,9 +641,9 @@ def _node_from_dict(d: dict[str, str | int]) -> Node: id=row["id"], graph_type=row["graph_type"], edges=edges, + edges_abs=edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=node_ci_vals, node_subcomp_acts=node_subcomp_acts, optimization_params=opt_params, @@ -572,9 +661,10 @@ def get_graphs(self, prompt_id: int) -> list[StoredGraph]: """ conn = self._get_conn() rows = conn.execute( - """SELECT id, graph_type, edges_data, output_logits, node_ci_vals, + """SELECT id, graph_type, edges_data, edges_data_abs, output_logits, node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes, + ci_masked_label_prob, stoch_masked_label_prob, adv_pgd_label_prob FROM graphs WHERE prompt_id = ? ORDER BY @@ -588,9 +678,11 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: """Retrieve a single graph by its ID. 
Returns (graph, prompt_id) or None.""" conn = self._get_conn() row = conn.execute( - """SELECT id, prompt_id, graph_type, edges_data, output_logits, node_ci_vals, - node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, mask_type, - loss_config, adv_pgd_n_steps, adv_pgd_step_size, included_nodes + """SELECT id, prompt_id, graph_type, edges_data, edges_data_abs, output_logits, + node_ci_vals, node_subcomp_acts, imp_min_coeff, steps, pnorm, beta, + mask_type, loss_config, adv_pgd_n_steps, adv_pgd_step_size, + included_nodes, ci_masked_label_prob, stoch_masked_label_prob, + adv_pgd_label_prob FROM graphs WHERE id = ?""", (graph_id,), @@ -599,23 +691,45 @@ def get_graph(self, graph_id: int) -> tuple[StoredGraph, int] | None: return None return (self._row_to_stored_graph(row), row["prompt_id"]) + def delete_prompt(self, prompt_id: int) -> None: + """Delete a prompt and all its graphs, intervention runs, and forked runs.""" + with self._write_lock(): + conn = self._get_conn() + graph_ids_query = "SELECT id FROM graphs WHERE prompt_id = ?" + intervention_ids_query = ( + f"SELECT id FROM intervention_runs WHERE graph_id IN ({graph_ids_query})" + ) + conn.execute( + f"DELETE FROM forked_intervention_runs WHERE intervention_run_id IN ({intervention_ids_query})", + (prompt_id,), + ) + conn.execute( + f"DELETE FROM intervention_runs WHERE graph_id IN ({graph_ids_query})", + (prompt_id,), + ) + conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.execute("DELETE FROM prompts WHERE id = ?", (prompt_id,)) + conn.commit() + def delete_graphs_for_prompt(self, prompt_id: int) -> int: """Delete all graphs for a prompt. 
Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM graphs WHERE prompt_id = ?", (prompt_id,)) + conn.commit() + return cursor.rowcount def delete_graphs_for_run(self, run_id: int) -> int: """Delete all graphs for all prompts in a run. Returns the number of deleted rows.""" - conn = self._get_conn() - cursor = conn.execute( - """DELETE FROM graphs - WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", - (run_id,), - ) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """DELETE FROM graphs + WHERE prompt_id IN (SELECT id FROM prompts WHERE run_id = ?)""", + (run_id,), + ) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Intervention run operations @@ -632,21 +746,22 @@ def save_intervention_run( Args: graph_id: The graph ID this run belongs to. selected_nodes: List of node keys that were selected. - result_json: JSON-encoded InterventionResponse. + result_json: JSON-encoded InterventionResult. Returns: The intervention run ID. 
""" - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO intervention_runs (graph_id, selected_nodes, result) - VALUES (?, ?, ?)""", - (graph_id, json.dumps(selected_nodes), result_json), - ) - conn.commit() - run_id = cursor.lastrowid - assert run_id is not None - return run_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO intervention_runs (graph_id, selected_nodes, result) + VALUES (?, ?, ?)""", + (graph_id, json.dumps(selected_nodes), result_json), + ) + conn.commit() + run_id = cursor.lastrowid + assert run_id is not None + return run_id def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: """Get all intervention runs for a graph. @@ -679,16 +794,18 @@ def get_intervention_runs(self, graph_id: int) -> list[InterventionRunRecord]: def delete_intervention_run(self, run_id: int) -> None: """Delete an intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM intervention_runs WHERE id = ?", (run_id,)) + conn.commit() def delete_intervention_runs_for_graph(self, graph_id: int) -> int: """Delete all intervention runs for a graph. Returns count deleted.""" - conn = self._get_conn() - cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) - conn.commit() - return cursor.rowcount + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute("DELETE FROM intervention_runs WHERE graph_id = ?", (graph_id,)) + conn.commit() + return cursor.rowcount # ------------------------------------------------------------------------- # Forked intervention run operations @@ -710,16 +827,17 @@ def save_forked_intervention_run( Returns: The forked intervention run ID. 
""" - conn = self._get_conn() - cursor = conn.execute( - """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) - VALUES (?, ?, ?)""", - (intervention_run_id, json.dumps(token_replacements), result_json), - ) - conn.commit() - fork_id = cursor.lastrowid - assert fork_id is not None - return fork_id + with self._write_lock(): + conn = self._get_conn() + cursor = conn.execute( + """INSERT INTO forked_intervention_runs (intervention_run_id, token_replacements, result) + VALUES (?, ?, ?)""", + (intervention_run_id, json.dumps(token_replacements), result_json), + ) + conn.commit() + fork_id = cursor.lastrowid + assert fork_id is not None + return fork_id def get_forked_intervention_runs( self, intervention_run_id: int @@ -775,6 +893,7 @@ def get_intervention_run(self, run_id: int) -> InterventionRunRecord | None: def delete_forked_intervention_run(self, fork_id: int) -> None: """Delete a forked intervention run.""" - conn = self._get_conn() - conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) - conn.commit() + with self._write_lock(): + conn = self._get_conn() + conn.execute("DELETE FROM forked_intervention_runs WHERE id = ?", (fork_id,)) + conn.commit() diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 48403feb6..4c7e42fbb 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -13,6 +13,7 @@ from spd.configs import ImportanceMinimalityLossConfig, PGDInitStrategy, SamplingType from spd.metrics import importance_minimality_loss +from spd.metrics.pgd_utils import get_pgd_init_tensor, interpolate_pgd_mask from spd.models.component_model import CIOutputs, ComponentModel, OutputWithCache from spd.models.components import make_mask_infos from spd.routing import AllLayersRouter @@ -48,16 +49,33 @@ class KLLossConfig(BaseModel): position: int -LossConfig = CELossConfig | KLLossConfig +class LogitLossConfig(BaseModel): + """Logit loss: maximize the pre-softmax 
logit for a specific token at a position.""" + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + + +class MeanKLLossConfig(BaseModel): + """Mean KL divergence loss: match target model distribution across all positions.""" + + type: Literal["mean_kl"] = "mean_kl" + coeff: float = 1.0 -def _compute_recon_loss( + +PositionalLossConfig = CELossConfig | KLLossConfig | LogitLossConfig +LossConfig = CELossConfig | KLLossConfig | LogitLossConfig | MeanKLLossConfig + + +def compute_recon_loss( logits: Tensor, loss_config: LossConfig, target_out: Tensor, device: str, ) -> Tensor: - """Compute recon loss (CE or KL) from model output logits at the configured position.""" + """Compute recon loss (CE, KL, or mean KL) from model output logits.""" match loss_config: case CELossConfig(position=pos, label_token=label_token): return F.cross_entropy( @@ -68,14 +86,12 @@ def _compute_recon_loss( target_probs = F.softmax(target_out[0, pos, :], dim=-1) pred_log_probs = F.log_softmax(logits[0, pos, :], dim=-1) return F.kl_div(pred_log_probs, target_probs, reduction="sum") - - -def _interpolate_masks( - ci: dict[str, Tensor], - sources: dict[str, Tensor], -) -> dict[str, Tensor]: - """Compute PGD component masks: ci + (1 - ci) * source.""" - return {layer: ci[layer] + (1 - ci[layer]) * sources[layer] for layer in ci} + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[0, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="batchmean") @dataclass @@ -188,94 +204,6 @@ def create_optimizable_ci_params( ) -def compute_l0_stats( - ci_outputs: CIOutputs, - ci_alive_threshold: float, -) -> dict[str, float]: - """Compute L0 statistics for each layer.""" - stats: dict[str, float] = {} - for layer_name, layer_ci in ci_outputs.lower_leaky.items(): - l0_val = calc_ci_l_zero(layer_ci, 
ci_alive_threshold) - stats[f"l0/{layer_name}"] = l0_val - stats["l0/total"] = sum(stats.values()) - return stats - - -def compute_specific_pos_ce_kl( - model: ComponentModel, - batch: Tensor, - target_out: Tensor, - ci: dict[str, Tensor], - rounding_threshold: float, - loss_seq_pos: int, -) -> dict[str, float]: - """Compute CE and KL metrics for a specific sequence position. - - Args: - model: The ComponentModel. - batch: Input tokens of shape [1, seq_len]. - target_out: Target model output logits of shape [1, seq_len, vocab]. - ci: Causal importance values (lower_leaky) per layer. - rounding_threshold: Threshold for rounding CI values to binary masks. - loss_seq_pos: Sequence position to compute metrics for. - - Returns: - Dict with kl and ce_difference metrics for ci_masked, unmasked, and rounded_masked. - """ - assert batch.ndim == 2 and batch.shape[0] == 1, "Expected batch shape [1, seq_len]" - - # Get target logits at the specified position - target_logits = target_out[0, loss_seq_pos, :] # [vocab] - - def kl_vs_target(logits: Tensor) -> float: - """KL divergence between predicted and target logits at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_probs = F.softmax(target_logits, dim=-1) - pred_log_probs = F.log_softmax(pos_logits, dim=-1) - return F.kl_div(pred_log_probs, target_probs, reduction="sum").item() - - def ce_vs_target(logits: Tensor) -> float: - """CE between predicted logits and target's argmax at target position.""" - pos_logits = logits[0, loss_seq_pos, :] # [vocab] - target_token = target_logits.argmax() - return F.cross_entropy(pos_logits.unsqueeze(0), target_token.unsqueeze(0)).item() - - # Target model CE (baseline) - target_ce = ce_vs_target(target_out) - - # CI masked - ci_mask_infos = make_mask_infos(ci) - with bf16_autocast(): - ci_masked_logits = model(batch, mask_infos=ci_mask_infos) - ci_masked_kl = kl_vs_target(ci_masked_logits) - ci_masked_ce = ce_vs_target(ci_masked_logits) - - # Unmasked (all 
components active) - unmasked_infos = make_mask_infos({k: torch.ones_like(v) for k, v in ci.items()}) - with bf16_autocast(): - unmasked_logits = model(batch, mask_infos=unmasked_infos) - unmasked_kl = kl_vs_target(unmasked_logits) - unmasked_ce = ce_vs_target(unmasked_logits) - - # Rounded masked (binary masks based on threshold) - rounded_mask_infos = make_mask_infos( - {k: (v > rounding_threshold).float() for k, v in ci.items()} - ) - with bf16_autocast(): - rounded_masked_logits = model(batch, mask_infos=rounded_mask_infos) - rounded_masked_kl = kl_vs_target(rounded_masked_logits) - rounded_masked_ce = ce_vs_target(rounded_masked_logits) - - return { - "kl_ci_masked": ci_masked_kl, - "kl_unmasked": unmasked_kl, - "kl_rounded_masked": rounded_masked_kl, - "ce_difference_ci_masked": ci_masked_ce - target_ce, - "ce_difference_unmasked": unmasked_ce - target_ce, - "ce_difference_rounded_masked": rounded_masked_ce - target_ce, - } - - @dataclass class OptimCIConfig: """Configuration for optimizing CI values on a single prompt.""" @@ -292,9 +220,9 @@ class OptimCIConfig: log_freq: int - # Loss config (exactly one of CE or KL) + # Loss config (CE or KL — must target a specific position) imp_min_config: ImportanceMinimalityLossConfig - loss_config: LossConfig + loss_config: PositionalLossConfig sampling: SamplingType @@ -306,43 +234,51 @@ class OptimCIConfig: ProgressCallback = Callable[[int, int, str], None] # (current, total, stage) +class CISnapshot(BaseModel): + """Snapshot of alive component counts during CI optimization for visualization.""" + + step: int + total_steps: int + layers: list[str] + seq_len: int + initial_alive: list[list[int]] # layers × seq + current_alive: list[list[int]] # layers × seq + l0_total: float + loss: float + + +CISnapshotCallback = Callable[[CISnapshot], None] + + @dataclass class OptimizeCIResult: """Result from CI optimization including params and final metrics.""" params: OptimizableCIParams metrics: OptimizationMetrics - 
adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None -def _run_adv_pgd( +def run_adv_pgd( model: ComponentModel, tokens: Tensor, - ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], + ci: dict[str, Float[Tensor, "1 seq C"]], alive_masks: dict[str, Bool[Tensor, "1 seq C"]], adv_config: AdvPGDConfig, - loss_config: LossConfig, target_out: Tensor, - device: str, + loss_config: LossConfig, ) -> dict[str, Float[Tensor, "1 seq C"]]: - """Run PGD to find adversarial sources maximizing reconstruction loss. + """Run PGD to find adversarial sources maximizing loss. Sources are optimized via signed gradient ascent. Only alive positions are optimized. Masks are computed as ci + (1 - ci) * source (same interpolation as training PGD). Returns detached adversarial source tensors. """ - ci_detached = {k: v.detach() for k, v in ci_lower_leaky.items()} + ci_detached = {k: v.detach() for k, v in ci.items()} adv_sources: dict[str, Tensor] = {} - for layer_name, ci in ci_detached.items(): - match adv_config.init: - case "random": - source = torch.rand_like(ci) - case "ones": - source = torch.ones_like(ci) - case "zeroes": - source = torch.zeros_like(ci) + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) source[~alive_masks[layer_name]] = 0.0 source.requires_grad_(True) adv_sources[layer_name] = source @@ -350,12 +286,13 @@ def _run_adv_pgd( source_list = list(adv_sources.values()) for _ in range(adv_config.n_steps): - mask_infos = make_mask_infos(_interpolate_masks(ci_detached, adv_sources)) + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) with bf16_autocast(): out = model(tokens, mask_infos=mask_infos) - loss = _compute_recon_loss(out, loss_config, target_out, device) + loss = compute_recon_loss(out, loss_config, target_out, str(tokens.device)) + grads = torch.autograd.grad(loss, source_list) with torch.no_grad(): for (layer_name, source), grad in 
zip(adv_sources.items(), grads, strict=True): @@ -372,6 +309,7 @@ def optimize_ci_values( config: OptimCIConfig, device: str, on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, ) -> OptimizeCIResult: """Optimize CI values for a single prompt. @@ -406,13 +344,40 @@ def optimize_ci_values( weight_deltas = model.calc_weight_deltas() + # Precompute snapshot metadata for CI visualization + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + params = ci_params.get_parameters() optimizer = optim.AdamW(params, lr=config.lr, weight_decay=config.weight_decay) progress_interval = max(1, config.steps // 20) # Report ~20 times during optimization + latest_loss: float = 0.0 for step in tqdm(range(config.steps), desc="Optimizing CI values"): - if on_progress is not None and step % progress_interval == 0: - on_progress(step, config.steps, "optimizing") + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + snap_ci = ci_params.create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) optimizer.zero_grad() @@ -444,82 +409,46 @@ def optimize_ci_values( p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, ) - recon_loss = _compute_recon_loss(recon_out, config.loss_config, target_out, device) + recon_loss = compute_recon_loss(recon_out, config.loss_config, target_out, device) total_loss = config.loss_config.coeff * recon_loss + 
imp_min_coeff * imp_min_loss + latest_loss = total_loss.item() # PGD adversarial loss (runs in tandem with recon) if config.adv_pgd is not None: - adv_sources = _run_adv_pgd( + adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=ci_outputs.lower_leaky, + ci=ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, loss_config=config.loss_config, target_out=target_out, - device=device, ) pgd_mask_infos = make_mask_infos( - _interpolate_masks(ci_outputs.lower_leaky, adv_sources) + interpolate_pgd_mask(ci_outputs.lower_leaky, adv_sources) ) with bf16_autocast(): pgd_out = model(tokens, mask_infos=pgd_mask_infos) - pgd_loss = _compute_recon_loss(pgd_out, config.loss_config, target_out, device) + pgd_loss = compute_recon_loss(pgd_out, config.loss_config, target_out, device) total_loss = total_loss + config.loss_config.coeff * pgd_loss - if step % config.log_freq == 0 or step == config.steps - 1: - l0_stats = compute_l0_stats(ci_outputs, ci_alive_threshold=0.0) - - with torch.no_grad(): - ce_kl_stats = compute_specific_pos_ce_kl( - model=model, - batch=tokens, - target_out=target_out, - ci=ci_outputs.lower_leaky, - rounding_threshold=config.ce_kl_rounding_threshold, - loss_seq_pos=config.loss_config.position, - ) - - log_terms: dict[str, float] = { - "imp_min_loss": imp_min_loss.item(), - "total_loss": total_loss.item(), - "recon_loss": recon_loss.item(), - } - - if isinstance(config.loss_config, CELossConfig): - pos = config.loss_config.position - label_token = config.loss_config.label_token - recon_label_prob = F.softmax(recon_out[0, pos, :], dim=-1)[label_token] - log_terms["recon_masked_label_prob"] = recon_label_prob.item() - - with torch.no_grad(): - mask_infos = make_mask_infos(ci_outputs.lower_leaky, routing_masks="all") - logits = model(tokens, mask_infos=mask_infos) - probs = F.softmax(logits[0, pos, :], dim=-1) - log_terms["ci_masked_label_prob"] = float(probs[label_token].item()) - - tqdm.write(f"\n--- Step 
{step} ---") - for name, value in log_terms.items(): - tqdm.write(f" {name}: {value:.6f}") - for name, value in l0_stats.items(): - tqdm.write(f" {name}: {value:.2f}") - for name, value in ce_kl_stats.items(): - tqdm.write(f" {name}: {value:.6f}") - total_loss.backward() optimizer.step() # Compute final metrics after optimization with torch.no_grad(): final_ci_outputs = ci_params.create_ci_outputs(model, device) - final_l0_stats = compute_l0_stats(final_ci_outputs, ci_alive_threshold=0.0) + + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci_outputs.lower_leaky.values() + ) final_ci_masked_label_prob: float | None = None final_stoch_masked_label_prob: float | None = None - if isinstance(config.loss_config, CELossConfig): + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token @@ -541,29 +470,26 @@ def optimize_ci_values( final_stoch_masked_label_prob = float(stoch_probs[label_token].item()) # Adversarial PGD final evaluation (needs gradients for PGD, so outside no_grad block) - adv_pgd_out_logits: Float[Tensor, "seq vocab"] | None = None final_adv_pgd_label_prob: float | None = None if config.adv_pgd is not None: - final_adv_sources = _run_adv_pgd( + final_adv_sources = run_adv_pgd( model=model, tokens=tokens, - ci_lower_leaky=final_ci_outputs.lower_leaky, + ci=final_ci_outputs.lower_leaky, alive_masks=alive_info.alive_masks, adv_config=config.adv_pgd, - loss_config=config.loss_config, target_out=target_out, - device=device, + loss_config=config.loss_config, ) with torch.no_grad(): adv_pgd_masks = make_mask_infos( - _interpolate_masks(final_ci_outputs.lower_leaky, final_adv_sources) + interpolate_pgd_mask(final_ci_outputs.lower_leaky, final_adv_sources) ) with bf16_autocast(): adv_logits = model(tokens, mask_infos=adv_pgd_masks) - adv_pgd_out_logits = adv_logits[0].detach() # [seq, vocab] - if isinstance(config.loss_config, CELossConfig): + if 
isinstance(config.loss_config, CELossConfig | LogitLossConfig): pos = config.loss_config.position label_token = config.loss_config.label_token adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) @@ -573,16 +499,335 @@ def optimize_ci_values( ci_masked_label_prob=final_ci_masked_label_prob, stoch_masked_label_prob=final_stoch_masked_label_prob, adv_pgd_label_prob=final_adv_pgd_label_prob, - l0_total=final_l0_stats["l0/total"], + l0_total=total_l0, ) return OptimizeCIResult( params=ci_params, metrics=metrics, - adv_pgd_out_logits=adv_pgd_out_logits, ) +def compute_recon_loss_batched( + logits: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, + target_out: Float[Tensor, "N seq vocab"], + device: str, +) -> Float[Tensor, " N"]: + """Compute per-element reconstruction loss for batched logits.""" + match loss_config: + case CELossConfig(position=pos, label_token=label_token): + labels = torch.full((logits.shape[0],), label_token, device=device) + return F.cross_entropy(logits[:, pos, :], labels, reduction="none") + case KLLossConfig(position=pos): + target_probs = F.softmax(target_out[:, pos, :], dim=-1) + pred_log_probs = F.log_softmax(logits[:, pos, :], dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1) + case LogitLossConfig(position=pos, label_token=label_token): + return -logits[:, pos, label_token] + case MeanKLLossConfig(): + target_probs = F.softmax(target_out, dim=-1) + pred_log_probs = F.log_softmax(logits, dim=-1) + return F.kl_div(pred_log_probs, target_probs, reduction="none").sum(dim=-1).mean(dim=-1) + + +def importance_minimality_loss_per_element( + ci_upper_leaky_batched: dict[str, Float[Tensor, "N seq C"]], + n_batch: int, + current_frac_of_training: float, + pnorm: float, + beta: float, + eps: float, + p_anneal_start_frac: float, + p_anneal_final_p: float | None, + p_anneal_end_frac: float, +) -> Float[Tensor, " N"]: + """Compute importance minimality loss independently for each batch element.""" + losses = [] 
+ for i in range(n_batch): + element_ci = {k: v[i : i + 1] for k, v in ci_upper_leaky_batched.items()} + losses.append( + importance_minimality_loss( + ci_upper_leaky=element_ci, + current_frac_of_training=current_frac_of_training, + pnorm=pnorm, + beta=beta, + eps=eps, + p_anneal_start_frac=p_anneal_start_frac, + p_anneal_final_p=p_anneal_final_p, + p_anneal_end_frac=p_anneal_end_frac, + ) + ) + return torch.stack(losses) + + +def run_adv_pgd_batched( + model: ComponentModel, + tokens: Float[Tensor, "N seq"], + ci: dict[str, Float[Tensor, "N seq C"]], + alive_masks: dict[str, Bool[Tensor, "N seq C"]], + adv_config: AdvPGDConfig, + target_out: Float[Tensor, "N seq vocab"], + loss_config: LossConfig, +) -> dict[str, Float[Tensor, "N seq C"]]: + """Run PGD adversary with batched tensors. Returns detached adversarial sources.""" + ci_detached = {k: v.detach() for k, v in ci.items()} + + adv_sources: dict[str, Tensor] = {} + for layer_name, ci_val in ci_detached.items(): + source = get_pgd_init_tensor(adv_config.init, tuple(ci_val.shape), str(ci_val.device)) + source[~alive_masks[layer_name]] = 0.0 + source.requires_grad_(True) + adv_sources[layer_name] = source + + source_list = list(adv_sources.values()) + + for _ in range(adv_config.n_steps): + mask_infos = make_mask_infos(interpolate_pgd_mask(ci_detached, adv_sources)) + + with bf16_autocast(): + out = model(tokens, mask_infos=mask_infos) + + losses = compute_recon_loss_batched(out, loss_config, target_out, str(tokens.device)) + loss = losses.sum() + + grads = torch.autograd.grad(loss, source_list) + with torch.no_grad(): + for (layer_name, source), grad in zip(adv_sources.items(), grads, strict=True): + source.add_(adv_config.step_size * grad.sign()) + source.clamp_(0.0, 1.0) + source[~alive_masks[layer_name]] = 0.0 + + return {k: v.detach() for k, v in adv_sources.items()} + + +def optimize_ci_values_batched( + model: ComponentModel, + tokens: Float[Tensor, "1 seq"], + configs: list[OptimCIConfig], + device: str, 
+ on_progress: ProgressCallback | None = None, + on_ci_snapshot: CISnapshotCallback | None = None, +) -> list[OptimizeCIResult]: + """Optimize CI values for N sparsity coefficients in a single batched loop. + + All configs must share the same loss_config, steps, mask_type, adv_pgd settings — + only imp_min_config.coeff varies between them. + """ + N = len(configs) + assert N > 0 + + config = configs[0] + imp_min_coeffs = torch.tensor([c.imp_min_config.coeff for c in configs], device=device) + for c in configs: + assert c.imp_min_config.coeff is not None + + model.requires_grad_(False) + + with torch.no_grad(), bf16_autocast(): + output_with_cache: OutputWithCache = model(tokens, cache_type="input") + initial_ci_outputs = model.calc_causal_importances( + pre_weight_acts=output_with_cache.cache, + sampling=config.sampling, + detach_inputs=False, + ) + target_out = output_with_cache.output.detach() + + alive_info = compute_alive_info(initial_ci_outputs.lower_leaky) + + ci_params_list = [ + create_optimizable_ci_params( + alive_info=alive_info, + initial_pre_sigmoid=initial_ci_outputs.pre_sigmoid, + ) + for _ in range(N) + ] + + weight_deltas = model.calc_weight_deltas() + + all_params: list[Tensor] = [] + for ci_params in ci_params_list: + all_params.extend(ci_params.get_parameters()) + + optimizer = optim.AdamW(all_params, lr=config.lr, weight_decay=config.weight_decay) + + tokens_batched = tokens.expand(N, -1) + target_out_batched = target_out.expand(N, -1, -1) + + snapshot_layers = list(alive_info.alive_counts.keys()) + snapshot_initial_alive = [alive_info.alive_counts[layer] for layer in snapshot_layers] + snapshot_seq_len = tokens.shape[1] + + progress_interval = max(1, config.steps // 20) + latest_loss = 0.0 + + for step in tqdm(range(config.steps), desc="Optimizing CI values (batched)"): + if step % progress_interval == 0: + if on_progress is not None: + on_progress(step, config.steps, "optimizing") + + if on_ci_snapshot is not None: + with torch.no_grad(): + 
snap_ci = ci_params_list[0].create_ci_outputs(model, device) + current_alive = [ + (snap_ci.lower_leaky[layer][0] > 0.0).sum(dim=-1).tolist() + for layer in snapshot_layers + ] + on_ci_snapshot( + CISnapshot( + step=step, + total_steps=config.steps, + layers=snapshot_layers, + seq_len=snapshot_seq_len, + initial_alive=snapshot_initial_alive, + current_alive=current_alive, + l0_total=sum(sum(row) for row in current_alive), + loss=latest_loss, + ) + ) + + optimizer.zero_grad() + + ci_outputs_list = [cp.create_ci_outputs(model, device) for cp in ci_params_list] + + layers = list(ci_outputs_list[0].lower_leaky.keys()) + batched_ci_lower_leaky: dict[str, Tensor] = { + layer: torch.cat([co.lower_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + batched_ci_upper_leaky: dict[str, Tensor] = { + layer: torch.cat([co.upper_leaky[layer] for co in ci_outputs_list], dim=0) + for layer in layers + } + + match config.mask_type: + case "stochastic": + recon_mask_infos = calc_stochastic_component_mask_info( + causal_importances=batched_ci_lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + case "ci": + recon_mask_infos = make_mask_infos(component_masks=batched_ci_lower_leaky) + + with bf16_autocast(): + recon_out = model(tokens_batched, mask_infos=recon_mask_infos) + + imp_min_losses = importance_minimality_loss_per_element( + ci_upper_leaky_batched=batched_ci_upper_leaky, + n_batch=N, + current_frac_of_training=step / config.steps, + pnorm=config.imp_min_config.pnorm, + beta=config.imp_min_config.beta, + eps=config.imp_min_config.eps, + p_anneal_start_frac=config.imp_min_config.p_anneal_start_frac, + p_anneal_final_p=config.imp_min_config.p_anneal_final_p, + p_anneal_end_frac=config.imp_min_config.p_anneal_end_frac, + ) + + recon_losses = compute_recon_loss_batched( + recon_out, config.loss_config, target_out_batched, device + ) + + loss_coeff = config.loss_config.coeff + total_loss = 
(loss_coeff * recon_losses + imp_min_coeffs * imp_min_losses).sum() + latest_loss = total_loss.item() + + if config.adv_pgd is not None: + batched_alive_masks = { + k: v.expand(N, -1, -1) for k, v in alive_info.alive_masks.items() + } + adv_sources = run_adv_pgd_batched( + model=model, + tokens=tokens_batched, + ci=batched_ci_lower_leaky, + alive_masks=batched_alive_masks, + adv_config=config.adv_pgd, + target_out=target_out_batched, + loss_config=config.loss_config, + ) + pgd_masks = interpolate_pgd_mask(batched_ci_lower_leaky, adv_sources) + pgd_mask_infos = make_mask_infos(pgd_masks) + with bf16_autocast(): + pgd_out = model(tokens_batched, mask_infos=pgd_mask_infos) + pgd_losses = compute_recon_loss_batched( + pgd_out, config.loss_config, target_out_batched, device + ) + total_loss = total_loss + (loss_coeff * pgd_losses).sum() + + total_loss.backward() + optimizer.step() + + # Compute final metrics per element + results: list[OptimizeCIResult] = [] + for ci_params in ci_params_list: + with torch.no_grad(): + final_ci = ci_params.create_ci_outputs(model, device) + total_l0 = sum( + calc_ci_l_zero(layer_ci, 0.0) for layer_ci in final_ci.lower_leaky.values() + ) + + ci_masked_label_prob: float | None = None + stoch_masked_label_prob: float | None = None + + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + + ci_mask_infos = make_mask_infos(final_ci.lower_leaky, routing_masks="all") + ci_logits = model(tokens, mask_infos=ci_mask_infos) + ci_probs = F.softmax(ci_logits[0, pos, :], dim=-1) + ci_masked_label_prob = float(ci_probs[label_token].item()) + + stoch_mask_infos = calc_stochastic_component_mask_info( + causal_importances=final_ci.lower_leaky, + component_mask_sampling=config.sampling, + weight_deltas=weight_deltas, + router=AllLayersRouter(), + ) + stoch_logits = model(tokens, mask_infos=stoch_mask_infos) + stoch_probs = F.softmax(stoch_logits[0, pos, :], 
dim=-1) + stoch_masked_label_prob = float(stoch_probs[label_token].item()) + + adv_pgd_label_prob: float | None = None + if config.adv_pgd is not None: + final_adv_sources = run_adv_pgd( + model=model, + tokens=tokens, + ci=final_ci.lower_leaky, + alive_masks=alive_info.alive_masks, + adv_config=config.adv_pgd, + target_out=target_out, + loss_config=config.loss_config, + ) + with torch.no_grad(): + adv_masks = make_mask_infos( + interpolate_pgd_mask(final_ci.lower_leaky, final_adv_sources) + ) + with bf16_autocast(): + adv_logits = model(tokens, mask_infos=adv_masks) + if isinstance(config.loss_config, CELossConfig | LogitLossConfig): + pos = config.loss_config.position + label_token = config.loss_config.label_token + adv_probs = F.softmax(adv_logits[0, pos, :], dim=-1) + adv_pgd_label_prob = float(adv_probs[label_token].item()) + + results.append( + OptimizeCIResult( + params=ci_params, + metrics=OptimizationMetrics( + ci_masked_label_prob=ci_masked_label_prob, + stoch_masked_label_prob=stoch_masked_label_prob, + adv_pgd_label_prob=adv_pgd_label_prob, + l0_total=total_l0, + ), + ) + ) + + return results + + def get_out_dir() -> Path: """Get the output directory for optimization results.""" out_dir = Path(__file__).parent / "out" diff --git a/spd/app/backend/routers/__init__.py b/spd/app/backend/routers/__init__.py index b7a6f8ed3..7b1729fbb 100644 --- a/spd/app/backend/routers/__init__.py +++ b/spd/app/backend/routers/__init__.py @@ -7,10 +7,14 @@ from spd.app.backend.routers.data_sources import router as data_sources_router from spd.app.backend.routers.dataset_attributions import router as dataset_attributions_router from spd.app.backend.routers.dataset_search import router as dataset_search_router +from spd.app.backend.routers.graph_interp import router as graph_interp_router from spd.app.backend.routers.graphs import router as graphs_router from spd.app.backend.routers.intervention import router as intervention_router +from 
spd.app.backend.routers.investigations import router as investigations_router +from spd.app.backend.routers.mcp import router as mcp_router from spd.app.backend.routers.pretrain_info import router as pretrain_info_router from spd.app.backend.routers.prompts import router as prompts_router +from spd.app.backend.routers.run_registry import router as run_registry_router from spd.app.backend.routers.runs import router as runs_router __all__ = [ @@ -20,10 +24,14 @@ "correlations_router", "data_sources_router", "dataset_attributions_router", + "graph_interp_router", "dataset_search_router", "graphs_router", "intervention_router", + "investigations_router", + "mcp_router", "pretrain_info_router", "prompts_router", + "run_registry_router", "runs_router", ] diff --git a/spd/app/backend/routers/clusters.py b/spd/app/backend/routers/clusters.py index e2dbae37a..b2dc1d5b9 100644 --- a/spd/app/backend/routers/clusters.py +++ b/spd/app/backend/routers/clusters.py @@ -10,6 +10,7 @@ from spd.app.backend.utils import log_errors from spd.base_config import BaseConfig from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology router = APIRouter(prefix="/api/clusters", tags=["clusters"]) @@ -86,4 +87,17 @@ def load_cluster_mapping(file_path: str) -> ClusterMapping: f"but loaded run is '{run_state.run.wandb_path}'", ) - return ClusterMapping(mapping=parsed.clusters) + canonical_clusters = _to_canonical_keys(parsed.clusters, run_state.topology) + return ClusterMapping(mapping=canonical_clusters) + + +def _to_canonical_keys( + clusters: dict[str, int | None], topology: TransformerTopology +) -> dict[str, int | None]: + """Convert concrete component keys (e.g. 'h.3.mlp.down_proj:5') to canonical (e.g. 
'3.mlp.down:5').""" + result: dict[str, int | None] = {} + for key, cluster_id in clusters.items(): + layer, idx = key.rsplit(":", 1) + canonical_layer = topology.target_to_canon(layer) + result[f"{canonical_layer}:{idx}"] = cluster_id + return result diff --git a/spd/app/backend/routers/data_sources.py b/spd/app/backend/routers/data_sources.py index 5287d91bd..6888b339f 100644 --- a/spd/app/backend/routers/data_sources.py +++ b/spd/app/backend/routers/data_sources.py @@ -28,15 +28,21 @@ class AutointerpInfo(BaseModel): class AttributionsInfo(BaseModel): subrun_id: str - n_batches_processed: int n_tokens_processed: int ci_threshold: float +class GraphInterpInfo(BaseModel): + subrun_id: str + config: dict[str, Any] | None + label_counts: dict[str, int] + + class DataSourcesResponse(BaseModel): harvest: HarvestInfo | None autointerp: AutointerpInfo | None attributions: AttributionsInfo | None + graph_interp: GraphInterpInfo | None router = APIRouter(prefix="/api/data_sources", tags=["data_sources"]) @@ -70,13 +76,21 @@ def get_data_sources(loaded: DepLoadedRun) -> DataSourcesResponse: storage = loaded.attributions.get_attributions() attributions_info = AttributionsInfo( subrun_id=loaded.attributions.subrun_id, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, ci_threshold=storage.ci_threshold, ) + graph_interp_info: GraphInterpInfo | None = None + if loaded.graph_interp is not None: + graph_interp_info = GraphInterpInfo( + subrun_id=loaded.graph_interp.subrun_id, + config=loaded.graph_interp.get_config(), + label_counts=loaded.graph_interp.get_label_counts(), + ) + return DataSourcesResponse( harvest=harvest_info, autointerp=autointerp_info, attributions=attributions_info, + graph_interp=graph_interp_info, ) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..178eefc72 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ 
b/spd/app/backend/routers/dataset_attributions.py @@ -7,46 +7,43 @@ from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query -from jaxtyping import Float from pydantic import BaseModel -from torch import Tensor from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool - n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -54,91 +51,67 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. 
- "output" is a pseudo-layer used as-is in storage. - """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" - - def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() -def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" - if not storage.has_source(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as source in attributions", - ) - - -def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" - if not storage.has_target(component_key): - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found as target in attributions", - ) - - -def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" - return loaded.topology.get_unembed_weight() - - def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] + entries: list[StorageEntry], loaded: DepLoadedRun ) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + 
else None, ) for e in entries ] +def _get_component_attributions_for_metric( + storage: DatasetAttributionStorage, + loaded: DepLoadedRun, + component_key: str, + k: int, + metric: AttrMetric, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric), loaded + ), + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric), loaded + ), + positive_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "positive", metric), loaded + ), + negative_targets=_to_api_entries( + storage.get_top_targets(component_key, k, "negative", metric), loaded + ), + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, - n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() return DatasetAttributionMetadata( available=True, - n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -150,58 +123,18 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all metrics.""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) - - # Component can be both a source and a target, so we need to check both - is_source 
= storage.has_source(component_key) - is_target = storage.has_target(component_key) - - if not is_source and not is_target: - raise HTTPException( - status_code=404, - detail=f"Component {component_key} not found in attributions", - ) - - w_unembed = _get_w_unembed(loaded) if is_source else None - - return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) - if is_target - else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) - if is_target - else [], - positive_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - loaded, - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + component_key = f"{layer}:{component_idx}" + + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, loaded, component_key, k, metric + ) + for metric in ATTR_METRICS + } ) @@ -213,16 +146,11 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) + storage.get_top_sources(f"{layer}:{component_idx}", k, sign, metric), loaded ) @@ -234,35 +162,9 @@ def get_attribution_targets( loaded: DepLoadedRun, 
k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) - _require_source(storage, source_key) - - w_unembed = _get_w_unembed(loaded) - return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) + storage.get_top_targets(f"{layer}:{component_idx}", k, sign, metric), loaded ) - - -@router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") -@log_errors -def get_attribution_between( - source_layer: str, - source_idx: int, - target_layer: str, - target_idx: int, - loaded: DepLoadedRun, -) -> float: - """Get attribution strength from source component to target component.""" - storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) - _require_source(storage, source_key) - _require_target(storage, target_key) - - w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) diff --git a/spd/app/backend/routers/graph_interp.py b/spd/app/backend/routers/graph_interp.py new file mode 100644 index 000000000..272bd4d80 --- /dev/null +++ b/spd/app/backend/routers/graph_interp.py @@ -0,0 +1,373 @@ +"""Graph interpretation endpoints. + +Serves context-aware component labels (output/input/unified) and the +prompt-edge graph produced by the graph_interp pipeline. 
+""" + +import random + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.app.backend.utils import log_errors +from spd.graph_interp.schemas import LabelResult +from spd.topology import TransformerTopology + +# TODO(oli): Remove MOCK_MODE after real data is available +MOCK_MODE = False + +MAX_GRAPH_NODES = 500 + + +_ALREADY_CANONICAL = {"embed", "output"} + + +def _concrete_to_canonical_key(concrete_key: str, topology: TransformerTopology) -> str: + layer, idx = concrete_key.rsplit(":", 1) + if layer in _ALREADY_CANONICAL: + return concrete_key + canonical = topology.target_to_canon(layer) + return f"{canonical}:{idx}" + + +def _canonical_to_concrete_key( + canonical_layer: str, component_idx: int, topology: TransformerTopology +) -> str: + concrete = topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +# -- Schemas ------------------------------------------------------------------- + + +class GraphInterpHeadline(BaseModel): + label: str + confidence: str + output_label: str | None + input_label: str | None + + +class LabelDetail(BaseModel): + label: str + confidence: str + reasoning: str + prompt: str + + +class GraphInterpDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + + +class PromptEdgeResponse(BaseModel): + related_key: str + pass_name: str + attribution: float + related_label: str | None + related_confidence: str | None + token_str: str | None + + +class GraphInterpComponentDetail(BaseModel): + output: LabelDetail | None + input: LabelDetail | None + unified: LabelDetail | None + edges: list[PromptEdgeResponse] + + +class GraphNode(BaseModel): + component_key: str + label: str + confidence: str + + +class GraphEdge(BaseModel): + source: str + target: str + attribution: float + pass_name: str + + +class ModelGraphResponse(BaseModel): + nodes: list[GraphNode] + edges: 
list[GraphEdge] + + +# -- Router -------------------------------------------------------------------- + +router = APIRouter(prefix="/api/graph_interp", tags=["graph_interp"]) + + +@router.get("/labels") +@log_errors +def get_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + if MOCK_MODE: + return _mock_all_labels(loaded) + + repo = loaded.graph_interp + if repo is None: + return {} + + topology = loaded.topology + unified = repo.get_all_unified_labels() + output = repo.get_all_output_labels() + input_ = repo.get_all_input_labels() + + all_keys = set(unified) | set(output) | set(input_) + result: dict[str, GraphInterpHeadline] = {} + + for concrete_key in all_keys: + u = unified.get(concrete_key) + o = output.get(concrete_key) + i = input_.get(concrete_key) + + label = u or o or i + assert label is not None + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + + result[canonical_key] = GraphInterpHeadline( + label=label.label, + confidence=label.confidence, + output_label=o.label if o else None, + input_label=i.label if i else None, + ) + + return result + + +def _to_detail(label: LabelResult | None) -> LabelDetail | None: + if label is None: + return None + return LabelDetail( + label=label.label, + confidence=label.confidence, + reasoning=label.reasoning, + prompt=label.prompt, + ) + + +@router.get("/labels/{layer}/{c_idx}") +@log_errors +def get_label_detail(layer: str, c_idx: int, loaded: DepLoadedRun) -> GraphInterpDetail: + if MOCK_MODE: + return _mock_label_detail(layer, c_idx) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + concrete_key = _canonical_to_concrete_key(layer, c_idx, loaded.topology) + + return GraphInterpDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + ) + + 
+@router.get("/detail/{layer}/{c_idx}") +@log_errors +def get_component_detail( + layer: str, c_idx: int, loaded: DepLoadedRun +) -> GraphInterpComponentDetail: + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + concrete_key = _canonical_to_concrete_key(layer, c_idx, topology) + + raw_edges = repo.get_prompt_edges(concrete_key) + tokenizer = loaded.tokenizer + edges = [] + for e in raw_edges: + rel_layer, rel_idx = e.related_key.rsplit(":", 1) + token_str = tokenizer.decode([int(rel_idx)]) if rel_layer in ("embed", "output") else None + edges.append( + PromptEdgeResponse( + related_key=_concrete_to_canonical_key(e.related_key, topology), + pass_name=e.pass_name, + attribution=e.attribution, + related_label=e.related_label, + related_confidence=e.related_confidence, + token_str=token_str, + ) + ) + + return GraphInterpComponentDetail( + output=_to_detail(repo.get_output_label(concrete_key)), + input=_to_detail(repo.get_input_label(concrete_key)), + unified=_to_detail(repo.get_unified_label(concrete_key)), + edges=edges, + ) + + +@router.get("/graph") +@log_errors +def get_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + if MOCK_MODE: + return _mock_model_graph(loaded) + + repo = loaded.graph_interp + if repo is None: + raise HTTPException(status_code=404, detail="Graph interp data not available") + + topology = loaded.topology + + unified = repo.get_all_unified_labels() + nodes = [] + for concrete_key, label in unified.items(): + canonical_key = _concrete_to_canonical_key(concrete_key, topology) + nodes.append( + GraphNode( + component_key=canonical_key, + label=label.label, + confidence=label.confidence, + ) + ) + + nodes = nodes[:MAX_GRAPH_NODES] + node_keys = {n.component_key for n in nodes} + + raw_edges = repo.get_all_prompt_edges() + edges = [] + for e in raw_edges: + comp_canon = _concrete_to_canonical_key(e.component_key, topology) + 
rel_canon = _concrete_to_canonical_key(e.related_key, topology) + + match e.pass_name: + case "output": + source, target = comp_canon, rel_canon + case "input": + source, target = rel_canon, comp_canon + + if source not in node_keys or target not in node_keys: + continue + + edges.append( + GraphEdge( + source=source, + target=target, + attribution=e.attribution, + pass_name=e.pass_name, + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) + + +# -- Mock data (TODO: remove after real data available) ------------------------ + +_MOCK_LABELS = [ + "sentence-final punctuation", + "proper noun completion", + "emotional adjective selection", + "temporal adverb prediction", + "morphological suffix (-ing/-ed)", + "determiner before noun", + "dialogue quotation marks", + "plural noun suffix", + "clause boundary detection", + "verb tense agreement", + "spatial preposition", + "possessive pronoun", + "narrative action verb", + "abstract emotion noun", + "comparative adjective form", + "subject-verb agreement", + "article selection (a/the)", + "comma splice detection", + "pronoun resolution", + "negation scope", +] + +_MOCK_INPUT_LABELS = [ + "sentence-initial capitals", + "mid-sentence verb position", + "adjective-noun boundary", + "clause-final position", + "article-noun sequence", + "subject pronoun at boundary", + "preposition-object pair", + "verb stem before suffix", + "quotation boundary", + "comma-separated items", +] + + +def _mock_all_labels(loaded: DepLoadedRun) -> dict[str, GraphInterpHeadline]: + rng = random.Random(42) + topology = loaded.topology + confidences = ["high", "high", "high", "medium", "medium", "low"] + + result: dict[str, GraphInterpHeadline] = {} + for target_path, components in loaded.model.components.items(): + canon = topology.target_to_canon(target_path) + n_components = components.C + n_mock = min(n_components, rng.randint(5, 20)) + indices = sorted(rng.sample(range(n_components), n_mock)) + for idx in indices: + key = 
f"{canon}:{idx}" + result[key] = GraphInterpHeadline( + label=rng.choice(_MOCK_LABELS), + confidence=rng.choice(confidences), + output_label=rng.choice(_MOCK_LABELS), + input_label=rng.choice(_MOCK_INPUT_LABELS), + ) + return result + + +def _mock_label_detail(layer: str, c_idx: int) -> GraphInterpDetail: + rng = random.Random(hash((layer, c_idx))) + conf = rng.choice(["high", "medium", "low"]) + return GraphInterpDetail( + output=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Output: Component {layer}:{c_idx} writes {rng.choice(_MOCK_LABELS).lower()} tokens to the residual stream.", + prompt="(mock prompt)", + ), + input=LabelDetail( + label=rng.choice(_MOCK_INPUT_LABELS), + confidence=conf, + reasoning=f"Input: Component {layer}:{c_idx} fires on {rng.choice(_MOCK_INPUT_LABELS).lower()} patterns.", + prompt="(mock prompt)", + ), + unified=LabelDetail( + label=rng.choice(_MOCK_LABELS), + confidence=conf, + reasoning=f"Unified: Combines output ({rng.choice(_MOCK_LABELS).lower()}) and input ({rng.choice(_MOCK_INPUT_LABELS).lower()}) functions.", + prompt="(mock prompt)", + ), + ) + + +def _mock_model_graph(loaded: DepLoadedRun) -> ModelGraphResponse: + labels = _mock_all_labels(loaded) + + nodes = [ + GraphNode(component_key=key, label=h.label, confidence=h.confidence) + for key, h in labels.items() + ] + + rng = random.Random(42) + keys = list(labels.keys()) + edges: list[GraphEdge] = [] + + for key in keys: + layer = key.rsplit(":", 1)[0] + later_keys = [k for k in keys if k.rsplit(":", 1)[0] != layer] + n_edges = rng.randint(1, 4) + for target in rng.sample(later_keys, min(n_edges, len(later_keys))): + edges.append( + GraphEdge( + source=key, + target=target, + attribution=rng.uniform(-1.0, 1.0), + pass_name=rng.choice(["output", "input"]), + ) + ) + + return ModelGraphResponse(nodes=nodes, edges=edges) diff --git a/spd/app/backend/routers/graphs.py b/spd/app/backend/routers/graphs.py index a51b1649c..e2b439da1 100644 --- 
a/spd/app/backend/routers/graphs.py +++ b/spd/app/backend/routers/graphs.py @@ -19,26 +19,90 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, Edge, + compute_intervention, compute_prompt_attributions, compute_prompt_attributions_optimized, + compute_prompt_attributions_optimized_batched, +) +from spd.app.backend.database import ( + GraphType, + OptimizationParams, + PgdConfig, + PromptAttrDB, + StoredGraph, ) -from spd.app.backend.database import GraphType, OptimizationParams, StoredGraph from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.optim_cis import ( AdvPGDConfig, CELossConfig, + CISnapshot, KLLossConfig, + LogitLossConfig, LossConfig, MaskType, + MeanKLLossConfig, OptimCIConfig, ) from spd.app.backend.schemas import OutputProbability from spd.app.backend.utils import log_errors -from spd.configs import ImportanceMinimalityLossConfig +from spd.configs import ImportanceMinimalityLossConfig, SamplingType from spd.log import logger +from spd.models.component_model import ComponentModel +from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device +NON_INTERVENTABLE_LAYERS = {"embed", "output"} + + +def _save_base_intervention_run( + graph_id: int, + model: ComponentModel, + tokens: torch.Tensor, + node_ci_vals: dict[str, float], + tokenizer: AppTokenizer, + topology: TransformerTopology, + db: PromptAttrDB, + sampling: SamplingType, + loss_config: LossConfig | None = None, +) -> None: + """Compute intervention for all interventable nodes and save as an intervention run.""" + interventable_keys = [ + k + for k, ci in node_ci_vals.items() + if k.split(":")[0] not in NON_INTERVENTABLE_LAYERS and ci > 0 + ] + assert len(interventable_keys) > 0, "No interventable nodes with CI > 0" + + active_nodes: list[tuple[str, int, int]] = [] + for key in interventable_keys: + canon_layer, seq_str, cidx_str = key.split(":") + 
concrete_path = topology.canon_to_target(canon_layer) + active_nodes.append((concrete_path, int(seq_str), int(cidx_str))) + + effective_loss_config: LossConfig = ( + loss_config if loss_config is not None else MeanKLLossConfig() + ) + + result = compute_intervention( + model=model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=effective_loss_config, + sampling=sampling, + top_k=10, + ) + + db.save_intervention_run( + graph_id=graph_id, + selected_nodes=interventable_keys, + result_json=result.model_dump_json(), + ) + class EdgeData(BaseModel): """Edge in the attribution graph.""" @@ -65,12 +129,14 @@ class GraphData(BaseModel): graphType: GraphType tokens: list[str] edges: list[EdgeData] + edgesAbs: list[EdgeData] | None = None # absolute-target variant, None for old graphs outputProbs: dict[str, OutputProbability] nodeCiVals: dict[ str, float ] # node key -> CI value (or output prob for output nodes or 1 for embed node) nodeSubcompActs: dict[str, float] # node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: float # max absolute edge value + maxAbsAttrAbs: float | None = None # max absolute edge value for abs-target variant maxAbsSubcompAct: float # max absolute subcomponent activation for normalization l0_total: int # total active components at current CI threshold @@ -93,6 +159,16 @@ class KLLossResult(BaseModel): position: int +class LogitLossResult(BaseModel): + """Logit loss result (maximize pre-softmax logit).""" + + type: Literal["logit"] = "logit" + coeff: float + position: int + label_token: int + label_str: str + + class OptimizationMetricsResult(BaseModel): """Final loss metrics from CI optimization.""" @@ -112,10 +188,9 @@ class OptimizationResult(BaseModel): pnorm: float beta: float mask_type: MaskType - loss: CELossResult | KLLossResult + loss: CELossResult | KLLossResult | LogitLossResult metrics: OptimizationMetricsResult - adv_pgd_n_steps: 
int | None = None - adv_pgd_step_size: float | None = None + pgd: PgdConfig | None = None class GraphDataWithOptimization(GraphData): @@ -156,19 +231,6 @@ class TokenizeResponse(BaseModel): next_token_probs: list[float | None] # Probability of next token (last token is None) -class TokenInfo(BaseModel): - """A single token from the tokenizer vocabulary.""" - - id: int - string: str - - -class TokensResponse(BaseModel): - """Response containing all tokens in the vocabulary.""" - - tokens: list[TokenInfo] - - # SSE streaming message types class ProgressMessage(BaseModel): """Progress update during streaming computation.""" @@ -200,6 +262,12 @@ class CompleteMessageWithOptimization(BaseModel): data: GraphDataWithOptimization +class BatchGraphResult(BaseModel): + """Batch optimization result containing multiple graphs.""" + + graphs: list[GraphDataWithOptimization] + + router = APIRouter(prefix="/api/graphs", tags=["graphs"]) DEVICE = get_device() @@ -218,7 +286,6 @@ def _build_out_probs( ci_masked_out_logits: torch.Tensor, target_out_logits: torch.Tensor, tok_display: Callable[[int], str], - adv_pgd_out_logits: torch.Tensor | None = None, ) -> dict[str, OutputProbability]: """Build output probs dict from logit tensors. 
@@ -226,9 +293,6 @@ def _build_out_probs( """ ci_masked_out_probs = torch.softmax(ci_masked_out_logits, dim=-1) target_out_probs = torch.softmax(target_out_logits, dim=-1) - adv_pgd_out_probs = ( - torch.softmax(adv_pgd_out_logits, dim=-1) if adv_pgd_out_logits is not None else None - ) out_probs: dict[str, OutputProbability] = {} for s in range(ci_masked_out_probs.shape[0]): @@ -243,65 +307,78 @@ def _build_out_probs( target_prob = float(target_out_probs[s, c_idx].item()) target_logit = float(target_out_logits[s, c_idx].item()) - adv_pgd_prob: float | None = None - adv_pgd_logit: float | None = None - if adv_pgd_out_probs is not None and adv_pgd_out_logits is not None: - adv_pgd_prob = round(float(adv_pgd_out_probs[s, c_idx].item()), 6) - adv_pgd_logit = round(float(adv_pgd_out_logits[s, c_idx].item()), 4) - key = f"{s}:{c_idx}" out_probs[key] = OutputProbability( prob=round(prob, 6), logit=round(logit, 4), target_prob=round(target_prob, 6), target_logit=round(target_logit, 4), - adv_pgd_prob=adv_pgd_prob, - adv_pgd_logit=adv_pgd_logit, token=tok_display(c_idx), ) return out_probs +CISnapshotCallback = Callable[[CISnapshot], None] + + def stream_computation( - work: Callable[[ProgressCallback], GraphData | GraphDataWithOptimization], + work: Callable[[ProgressCallback, CISnapshotCallback | None], BaseModel], + gpu_lock: threading.Lock, ) -> StreamingResponse: - """Run graph computation in a thread with SSE streaming for progress updates.""" + """Run graph computation in a thread with SSE streaming for progress updates. + + Acquires gpu_lock before starting and holds it until computation completes. + Raises 503 if the lock is already held by another operation. + """ + # Try to acquire lock non-blocking - fail fast if GPU is busy + if not gpu_lock.acquire(blocking=False): + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() def on_progress(current: int, total: int, stage: str) -> None: progress_queue.put({"type": "progress", "current": current, "total": total, "stage": stage}) + def on_ci_snapshot(snapshot: CISnapshot) -> None: + progress_queue.put({"type": "ci_snapshot", **snapshot.model_dump()}) + def compute_thread() -> None: try: - result = work(on_progress) + result = work(on_progress, on_ci_snapshot) progress_queue.put({"type": "result", "result": result}) except Exception as e: traceback.print_exc(file=sys.stderr) progress_queue.put({"type": "error", "error": str(e)}) def generate() -> Generator[str]: - thread = threading.Thread(target=compute_thread) - thread.start() - - while True: - try: - msg = progress_queue.get(timeout=0.1) - except queue.Empty: - if not thread.is_alive(): + try: + thread = threading.Thread(target=compute_thread) + thread.start() + + while True: + try: + msg = progress_queue.get(timeout=0.1) + except queue.Empty: + if not thread.is_alive(): + break + continue + + if msg["type"] in ("progress", "ci_snapshot"): + yield f"data: {json.dumps(msg)}\n\n" + elif msg["type"] == "error": + yield f"data: {json.dumps(msg)}\n\n" + break + elif msg["type"] == "result": + complete_data = {"type": "complete", "data": msg["result"].model_dump()} + yield f"data: {json.dumps(complete_data)}\n\n" break - continue - - if msg["type"] == "progress": - yield f"data: {json.dumps(msg)}\n\n" - elif msg["type"] == "error": - yield f"data: {json.dumps(msg)}\n\n" - break - elif msg["type"] == "result": - complete_data = {"type": "complete", "data": msg["result"].model_dump()} - yield f"data: {json.dumps(complete_data)}\n\n" - break - thread.join() + thread.join() + finally: + gpu_lock.release() return StreamingResponse(generate(), media_type="text/event-stream") @@ -343,40 +420,55 @@ def tokenize_text(text: str, loaded: DepLoadedRun) -> TokenizeResponse: ) -@router.get("/tokens") -@log_errors 
-def get_all_tokens(loaded: DepLoadedRun) -> TokensResponse: - """Get all tokens in the tokenizer vocabulary for client-side search.""" - tokens = [ - TokenInfo(id=tid, string=loaded.tokenizer.get_tok_display(tid)) - for tid in range(loaded.tokenizer.vocab_size) - ] - return TokensResponse(tokens=tokens) +class TokenSearchResult(BaseModel): + """A token search result with model probability at the queried position.""" + + id: int + string: str + prob: float class TokenSearchResponse(BaseModel): """Response from token search endpoint.""" - tokens: list[TokenInfo] + tokens: list[TokenSearchResult] @router.get("/tokens/search") @log_errors def search_tokens( q: Annotated[str, Query(min_length=1)], + prompt_id: Annotated[int, Query()], + position: Annotated[int, Query()], loaded: DepLoadedRun, - limit: Annotated[int, Query(ge=1, le=50)] = 10, + manager: DepStateManager, + limit: Annotated[int, Query(ge=1, le=50)] = 20, ) -> TokenSearchResponse: - """Search tokens by substring match. Returns up to `limit` results.""" + """Search tokens by substring match, sorted by target model probability at position.""" + prompt = manager.state.db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail=f"prompt {prompt_id} not found") + if not (0 <= position < len(prompt.token_ids)): + raise HTTPException( + status_code=422, + detail=f"position {position} out of range for prompt with {len(prompt.token_ids)} tokens", + ) + + device = next(loaded.model.parameters()).device + tokens_tensor = torch.tensor([prompt.token_ids], device=device) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits[0, position], dim=-1) + query = q.lower() - matches: list[TokenInfo] = [] + matches: list[TokenSearchResult] = [] for tid in range(loaded.tokenizer.vocab_size): string = loaded.tokenizer.get_tok_display(tid) if query in string.lower(): - matches.append(TokenInfo(id=tid, string=string)) - if len(matches) >= limit: - break - return 
TokenSearchResponse(tokens=matches) + matches.append(TokenSearchResult(id=tid, string=string, prob=probs[tid].item())) + + matches.sort(key=lambda m: m.prob, reverse=True) + return TokenSearchResponse(tokens=matches[:limit]) NormalizeType = Literal["none", "target", "layer"] @@ -450,7 +542,9 @@ def compute_graph_stream( spans = loaded.tokenizer.get_spans(token_ids) tokens_tensor = torch.tensor([token_ids], device=DEVICE) - def work(on_progress: ProgressCallback) -> GraphData: + def work( + on_progress: ProgressCallback, _on_ci_snapshot: CISnapshotCallback | None + ) -> GraphData: t_total = time.perf_counter() result = compute_prompt_attributions( @@ -474,6 +568,7 @@ def work(on_progress: ProgressCallback) -> GraphData: graph=StoredGraph( graph_type=graph_type, edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, node_ci_vals=result.node_ci_vals, @@ -483,6 +578,19 @@ def work(on_progress: ProgressCallback) -> GraphData: ) logger.info(f"[perf] save_graph: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + ) + logger.info(f"[perf] base intervention run: {time.perf_counter() - t0:.2f}s") + t0 = time.perf_counter() fg = filter_graph_for_display( raw_edges=result.edges, @@ -494,6 +602,7 @@ def work(on_progress: ProgressCallback) -> GraphData: num_tokens=len(token_ids), ci_threshold=ci_threshold, normalize=normalize, + raw_edges_abs=result.edges_abs, ) logger.info( f"[perf] filter_graph: {time.perf_counter() - t0:.2f}s ({len(fg.edges)} edges after filter)" @@ -505,15 +614,17 @@ def work(on_progress: ProgressCallback) -> GraphData: graphType=graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, 
nodeCiVals=fg.node_ci_vals, nodeSubcompActs=result.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, ) - return stream_computation(work) + return stream_computation(work, manager._gpu_lock) def _edge_to_edge_data(edge: Edge) -> EdgeData: @@ -557,7 +668,7 @@ def get_group_key(edge: Edge) -> str: return out_edges -LossType = Literal["ce", "kl"] +LossType = Literal["ce", "kl", "logit"] @router.post("/optimized/stream") @@ -597,6 +708,14 @@ def compute_graph_optimized_stream( ) case "kl": loss_config = KLLossConfig(coeff=loss_coeff, position=loss_position) + case "logit": + if label_token is None: + raise HTTPException( + status_code=400, detail="label_token is required for logit loss" + ) + loss_config = LogitLossConfig( + coeff=loss_coeff, position=loss_position, label_token=label_token + ) lr = 1e-2 @@ -627,8 +746,9 @@ def compute_graph_optimized_stream( beta=beta, mask_type=mask_type, loss=loss_config, - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, ) optim_config = OptimCIConfig( @@ -650,7 +770,9 @@ def compute_graph_optimized_stream( else None, ) - def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> GraphDataWithOptimization: result = compute_prompt_attributions_optimized( model=loaded.model, topology=loaded.topology, @@ -660,28 +782,42 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: output_prob_threshold=0.01, device=DEVICE, on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, ) ci_masked_out_logits = result.ci_masked_out_logits.cpu() target_out_logits = result.target_out_logits.cpu() - adv_pgd_out_logits = ( - result.adv_pgd_out_logits.cpu() if 
result.adv_pgd_out_logits is not None else None - ) + + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob graph_id = db.save_graph( prompt_id=prompt_id, graph=StoredGraph( graph_type="optimized", edges=result.edges, + edges_abs=result.edges_abs, ci_masked_out_logits=ci_masked_out_logits, target_out_logits=target_out_logits, - adv_pgd_out_logits=adv_pgd_out_logits, node_ci_vals=result.node_ci_vals, node_subcomp_acts=result.node_subcomp_acts, optimization_params=opt_params, ), ) + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + fg = filter_graph_for_display( raw_edges=result.edges, node_ci_vals=result.node_ci_vals, @@ -692,11 +828,11 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=adv_pgd_out_logits, + raw_edges_abs=result.edges_abs, ) # Build loss result based on config type - loss_result: CELossResult | KLLossResult + loss_result: CELossResult | KLLossResult | LogitLossResult match loss_config: case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): assert label_str is not None @@ -708,16 +844,26 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: ) case KLLossConfig(coeff=coeff, position=pos): loss_result = KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = LogitLossResult( + coeff=coeff, + position=pos, + label_token=label_tok, + label_str=label_str, + ) return GraphDataWithOptimization( id=graph_id, graphType="optimized", 
tokens=spans_sliced, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=result.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -733,12 +879,240 @@ def work(on_progress: ProgressCallback) -> GraphDataWithOptimization: adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, l0_total=result.metrics.l0_total, ), - adv_pgd_n_steps=adv_pgd_n_steps, - adv_pgd_step_size=adv_pgd_step_size, + pgd=PgdConfig(n_steps=adv_pgd_n_steps, step_size=adv_pgd_step_size) + if adv_pgd_n_steps is not None and adv_pgd_step_size is not None + else None, + ), + ) + + return stream_computation(work, manager._gpu_lock) + + +class BatchOptimizedRequest(BaseModel): + """Request body for batch optimized graph computation.""" + + prompt_id: int + imp_min_coeffs: list[float] + steps: int + pnorm: float + beta: float + normalize: NormalizeType + ci_threshold: float + mask_type: MaskType + loss_type: LossType + loss_coeff: float + loss_position: int + label_token: int | None = None + adv_pgd_n_steps: int | None = None + adv_pgd_step_size: float | None = None + + +@router.post("/optimized/batch/stream") +@log_errors +def compute_graph_optimized_batch_stream( + body: BatchOptimizedRequest, + loaded: DepLoadedRun, + manager: DepStateManager, +): + """Compute optimized graphs for multiple sparsity coefficients in one batched optimization. + + Returns N graphs (one per imp_min_coeff) via SSE streaming. + All coefficients share the same loss config, steps, and other hyperparameters. 
+ """ + assert len(body.imp_min_coeffs) > 0, "At least one coefficient required" + assert len(body.imp_min_coeffs) <= 20, "Too many coefficients (max 20)" + + loss_config: LossConfig + match body.loss_type: + case "ce": + assert body.label_token is not None, "label_token is required for CE loss" + loss_config = CELossConfig( + coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token + ) + case "kl": + loss_config = KLLossConfig(coeff=body.loss_coeff, position=body.loss_position) + case "logit": + assert body.label_token is not None, "label_token is required for logit loss" + loss_config = LogitLossConfig( + coeff=body.loss_coeff, position=body.loss_position, label_token=body.label_token + ) + + lr = 1e-2 + + db = manager.db + prompt = db.get_prompt(body.prompt_id) + assert prompt is not None, f"prompt {body.prompt_id} not found" + + token_ids = prompt.token_ids + assert body.loss_position < len(token_ids), ( + f"loss_position {body.loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + label_str = ( + loaded.tokenizer.get_tok_display(body.label_token) if body.label_token is not None else None + ) + spans = loaded.tokenizer.get_spans(token_ids) + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + num_tokens = body.loss_position + 1 + spans_sliced = spans[:num_tokens] + + adv_pgd = ( + AdvPGDConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size, init="random") + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None + ) + + configs = [ + OptimCIConfig( + seed=0, + lr=lr, + steps=body.steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + lr_warmup_pct=0.01, + log_freq=max(1, body.steps // 4), + imp_min_config=ImportanceMinimalityLossConfig( + coeff=coeff, pnorm=body.pnorm, beta=body.beta ), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type=body.mask_type, + adv_pgd=adv_pgd, ) + 
for coeff in body.imp_min_coeffs + ] - return stream_computation(work) + def work( + on_progress: ProgressCallback, on_ci_snapshot: CISnapshotCallback | None + ) -> BatchGraphResult: + results = compute_prompt_attributions_optimized_batched( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + configs=configs, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + on_ci_snapshot=on_ci_snapshot, + ) + + graphs: list[GraphDataWithOptimization] = [] + for result, coeff in zip(results, body.imp_min_coeffs, strict=True): + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + opt_params = OptimizationParams( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=loss_config, + pgd=PgdConfig(n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size) + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None, + ) + opt_params.ci_masked_label_prob = result.metrics.ci_masked_label_prob + opt_params.stoch_masked_label_prob = result.metrics.stoch_masked_label_prob + opt_params.adv_pgd_label_prob = result.metrics.adv_pgd_label_prob + + graph_id = db.save_graph( + prompt_id=body.prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + _save_base_intervention_run( + graph_id=graph_id, + model=loaded.model, + tokens=tokens_tensor, + node_ci_vals=result.node_ci_vals, + tokenizer=loaded.tokenizer, + topology=loaded.topology, + db=db, + sampling=loaded.config.sampling, + loss_config=loss_config, + ) + + fg = filter_graph_for_display( + raw_edges=result.edges, + 
node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + tok_display=loaded.tokenizer.get_tok_display, + num_tokens=num_tokens, + ci_threshold=body.ci_threshold, + normalize=body.normalize, + raw_edges_abs=result.edges_abs, + ) + + loss_result: CELossResult | KLLossResult | LogitLossResult + match loss_config: + case CELossConfig(coeff=lc, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = CELossResult( + coeff=lc, position=pos, label_token=label_tok, label_str=label_str + ) + case KLLossConfig(coeff=lc, position=pos): + loss_result = KLLossResult(coeff=lc, position=pos) + case LogitLossConfig(coeff=lc, position=pos, label_token=label_tok): + assert label_str is not None + loss_result = LogitLossResult( + coeff=lc, position=pos, label_token=label_tok, label_str=label_str + ) + + graphs.append( + GraphDataWithOptimization( + id=graph_id, + graphType="optimized", + tokens=spans_sliced, + edges=fg.edges, + edgesAbs=fg.edges_abs, + outputProbs=fg.out_probs, + nodeCiVals=fg.node_ci_vals, + nodeSubcompActs=result.node_subcomp_acts, + maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, + maxAbsSubcompAct=fg.max_abs_subcomp_act, + l0_total=fg.l0_total, + optimization=OptimizationResult( + imp_min_coeff=coeff, + steps=body.steps, + pnorm=body.pnorm, + beta=body.beta, + mask_type=body.mask_type, + loss=loss_result, + metrics=OptimizationMetricsResult( + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + l0_total=result.metrics.l0_total, + ), + pgd=PgdConfig( + n_steps=body.adv_pgd_n_steps, step_size=body.adv_pgd_step_size + ) + if body.adv_pgd_n_steps is not None and body.adv_pgd_step_size is not None + else None, + ), + ) + ) + + return BatchGraphResult(graphs=graphs) + + return 
stream_computation(work, manager._gpu_lock) @dataclass @@ -746,9 +1120,11 @@ class FilteredGraph: """Result of filtering a raw graph for display.""" edges: list[EdgeData] + edges_abs: list[EdgeData] | None # absolute-target variant, None for old graphs node_ci_vals: dict[str, float] # with pseudo nodes out_probs: dict[str, OutputProbability] max_abs_attr: float + max_abs_attr_abs: float | None # max abs for absolute-target edges max_abs_subcomp_act: float l0_total: int @@ -763,8 +1139,8 @@ def filter_graph_for_display( num_tokens: int, ci_threshold: float, normalize: NormalizeType, + raw_edges_abs: list[Edge] | None = None, edge_limit: int = GLOBAL_EDGE_LIMIT, - adv_pgd_out_logits: torch.Tensor | None = None, ) -> FilteredGraph: """Filter and transform a raw attribution graph for display. @@ -775,9 +1151,7 @@ def filter_graph_for_display( 5. Normalize edge strengths (if requested) 6. Cap edges at edge_limit """ - out_probs = _build_out_probs( - ci_masked_out_logits, target_out_logits, tok_display, adv_pgd_out_logits - ) + out_probs = _build_out_probs(ci_masked_out_logits, target_out_logits, tok_display) filtered_node_ci_vals = {k: v for k, v in node_ci_vals.items() if v > ci_threshold} @@ -789,25 +1163,33 @@ def filter_graph_for_display( seq_pos, token_id = key.split(":") node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob - # Filter edges to only those connecting surviving nodes + # Filter, normalize, sort, and truncate an edge list to the surviving node set. 
node_keys = set(node_ci_vals_with_pseudo.keys()) - edges = [e for e in raw_edges if str(e.source) in node_keys and str(e.target) in node_keys] - edges = _normalize_edges(edges=edges, normalize=normalize) - max_abs_attr = compute_max_abs_attr(edges=edges) + def _filter_edges(raw: list[Edge]) -> tuple[list[EdgeData], float]: + filtered = [e for e in raw if str(e.source) in node_keys and str(e.target) in node_keys] + filtered = _normalize_edges(edges=filtered, normalize=normalize) + max_abs = compute_max_abs_attr(edges=filtered) + filtered = sorted(filtered, key=lambda e: abs(e.strength), reverse=True) + if len(filtered) > edge_limit: + logger.warning(f"Edge limit {edge_limit} exceeded ({len(filtered)} edges), truncating") + filtered = filtered[:edge_limit] + return [_edge_to_edge_data(e) for e in filtered], max_abs - # Always sort by abs(strength) desc so frontend can just slice(0, topK) without re-sorting - edges = sorted(edges, key=lambda e: abs(e.strength), reverse=True) + edges_out, max_abs_attr = _filter_edges(raw_edges) - if len(edges) > edge_limit: - logger.warning(f"Edge limit {edge_limit} exceeded ({len(edges)} edges), truncating") - edges = edges[:edge_limit] + edges_abs_out: list[EdgeData] | None = None + max_abs_attr_abs: float | None = None + if raw_edges_abs is not None: + edges_abs_out, max_abs_attr_abs = _filter_edges(raw_edges_abs) return FilteredGraph( - edges=[_edge_to_edge_data(e) for e in edges], + edges=edges_out, + edges_abs=edges_abs_out, node_ci_vals=node_ci_vals_with_pseudo, out_probs=out_probs, max_abs_attr=max_abs_attr, + max_abs_attr_abs=max_abs_attr_abs, max_abs_subcomp_act=compute_max_abs_subcomp_act(node_subcomp_acts), l0_total=len(filtered_node_ci_vals), ) @@ -840,7 +1222,7 @@ def stored_graph_to_response( num_tokens=num_tokens, ci_threshold=ci_threshold, normalize=normalize, - adv_pgd_out_logits=graph.adv_pgd_out_logits, + raw_edges_abs=graph.edges_abs, ) if not is_optimized: @@ -849,10 +1231,12 @@ def stored_graph_to_response( 
graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, ) @@ -861,7 +1245,7 @@ def stored_graph_to_response( opt = graph.optimization_params # Build loss result based on stored config type - loss_result: CELossResult | KLLossResult + loss_result: CELossResult | KLLossResult | LogitLossResult match opt.loss: case CELossConfig(coeff=coeff, position=pos, label_token=label_tok): label_str = tokenizer.get_tok_display(label_tok) @@ -873,16 +1257,26 @@ def stored_graph_to_response( ) case KLLossConfig(coeff=coeff, position=pos): loss_result = KLLossResult(coeff=coeff, position=pos) + case LogitLossConfig(coeff=coeff, position=pos, label_token=label_tok): + label_str = tokenizer.get_tok_display(label_tok) + loss_result = LogitLossResult( + coeff=coeff, + position=pos, + label_token=label_tok, + label_str=label_str, + ) return GraphDataWithOptimization( id=graph.id, graphType=graph.graph_type, tokens=spans, edges=fg.edges, + edgesAbs=fg.edges_abs, outputProbs=fg.out_probs, nodeCiVals=fg.node_ci_vals, nodeSubcompActs=graph.node_subcomp_acts, maxAbsAttr=fg.max_abs_attr, + maxAbsAttrAbs=fg.max_abs_attr_abs, maxAbsSubcompAct=fg.max_abs_subcomp_act, l0_total=fg.l0_total, optimization=OptimizationResult( @@ -892,10 +1286,15 @@ def stored_graph_to_response( beta=opt.beta, mask_type=opt.mask_type, loss=loss_result, - # Metrics not stored in DB for cached graphs - use l0_total from graph - metrics=OptimizationMetricsResult(l0_total=float(fg.l0_total)), - adv_pgd_n_steps=opt.adv_pgd_n_steps, - adv_pgd_step_size=opt.adv_pgd_step_size, + metrics=OptimizationMetricsResult( + l0_total=float(fg.l0_total), + ci_masked_label_prob=opt.ci_masked_label_prob, + stoch_masked_label_prob=opt.stoch_masked_label_prob, + 
adv_pgd_label_prob=opt.adv_pgd_label_prob, + ), + pgd=PgdConfig(n_steps=opt.pgd.n_steps, step_size=opt.pgd.step_size) + if opt.pgd is not None + else None, ), ) diff --git a/spd/app/backend/routers/intervention.py b/spd/app/backend/routers/intervention.py index 1ccdbf86c..e26a73462 100644 --- a/spd/app/backend/routers/intervention.py +++ b/spd/app/backend/routers/intervention.py @@ -4,8 +4,12 @@ from fastapi import APIRouter, HTTPException from pydantic import BaseModel -from spd.app.backend.compute import compute_intervention_forward +from spd.app.backend.compute import ( + InterventionResult, + compute_intervention, +) from spd.app.backend.dependencies import DepDB, DepLoadedRun, DepStateManager +from spd.app.backend.optim_cis import AdvPGDConfig, LossConfig, MeanKLLossConfig from spd.app.backend.utils import log_errors from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device @@ -15,56 +19,19 @@ # ============================================================================= -class InterventionNode(BaseModel): - """A specific node to activate during intervention.""" - - layer: str - seq_pos: int - component_idx: int - - -class InterventionRequest(BaseModel): - """Request for intervention forward pass.""" - - text: str - nodes: list[InterventionNode] - top_k: int - - -class TokenPrediction(BaseModel): - """A single token prediction with probability.""" - - token: str - token_id: int - spd_prob: float - target_prob: float - logit: float - target_logit: float - - -class InterventionResponse(BaseModel): - """Response from intervention forward pass.""" - - input_tokens: list[str] - predictions_per_position: list[list[TokenPrediction]] +class AdvPgdParams(BaseModel): + n_steps: int + step_size: float class RunInterventionRequest(BaseModel): """Request to run and save an intervention.""" graph_id: int - text: str selected_nodes: list[str] # node keys (layer:seq:cIdx) - top_k: int = 10 - - -class 
ForkedInterventionRunSummary(BaseModel): - """Summary of a forked intervention run with modified tokens.""" - - id: int - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - result: InterventionResponse - created_at: str + nodes_to_ablate: list[str] | None = None # node keys to ablate in ablated (omit to skip) + top_k: int + adv_pgd: AdvPgdParams class InterventionRunSummary(BaseModel): @@ -72,16 +39,8 @@ class InterventionRunSummary(BaseModel): id: int selected_nodes: list[str] - result: InterventionResponse + result: InterventionResult created_at: str - forked_runs: list[ForkedInterventionRunSummary] - - -class ForkInterventionRequest(BaseModel): - """Request to fork an intervention run with modified tokens.""" - - token_replacements: list[tuple[int, int]] # [(seq_pos, new_token_id), ...] - top_k: int = 10 router = APIRouter(prefix="/api/intervention", tags=["intervention"]) @@ -104,100 +63,15 @@ def _parse_node_key(key: str, topology: TransformerTopology) -> tuple[str, int, return concrete_path, int(seq_str), int(cidx_str) -def _run_intervention_forward( - text: str, - selected_nodes: list[str], - top_k: int, - loaded: DepLoadedRun, -) -> InterventionResponse: - """Run intervention forward pass and return response.""" - token_ids = loaded.tokenizer.encode(text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [_parse_node_key(key, loaded.topology) for key in selected_nodes] - - seq_len = tokens.shape[1] - for _, seq_pos, _ in active_nodes: - if seq_pos >= seq_len: - raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for 
token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) - - -@router.post("") -@log_errors -def run_intervention(request: InterventionRequest, loaded: DepLoadedRun) -> InterventionResponse: - """Run intervention forward pass with specified nodes active (legacy endpoint).""" - token_ids = loaded.tokenizer.encode(request.text) - tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) - - active_nodes = [ - ( - loaded.topology.canon_to_target(n.layer), - n.seq_pos, - n.component_idx, - ) - for n in request.nodes - ] - - seq_len = tokens.shape[1] +def _parse_and_validate_active_nodes( + selected_nodes: list[str], topology: TransformerTopology, seq_len: int +) -> list[tuple[str, int, int]]: + """Parse node keys and validate sequence bounds for the current prompt.""" + active_nodes = [_parse_node_key(key, topology) for key in selected_nodes] for _, seq_pos, _ in active_nodes: if seq_pos >= seq_len: raise ValueError(f"seq_pos {seq_pos} out of bounds for text with {seq_len} tokens") - - result = compute_intervention_forward( - model=loaded.model, - tokens=tokens, - active_nodes=active_nodes, - top_k=request.top_k, - tokenizer=loaded.tokenizer, - ) - - predictions_per_position = [ - [ - TokenPrediction( - token=token, - token_id=token_id, - spd_prob=spd_prob, - target_prob=target_prob, - logit=logit, - target_logit=target_logit, - ) - for token, token_id, spd_prob, logit, target_prob, target_logit in pos_predictions - ] - for pos_predictions in result.predictions_per_position - ] - - return InterventionResponse( - input_tokens=result.input_tokens, - predictions_per_position=predictions_per_position, - ) + return active_nodes @router.post("/run") @@ -206,19 +80,59 @@ def run_and_save_intervention( request: RunInterventionRequest, loaded: DepLoadedRun, 
db: DepDB, + manager: DepStateManager, ) -> InterventionRunSummary: """Run an intervention and save the result.""" - response = _run_intervention_forward( - text=request.text, - selected_nodes=request.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) + with manager.gpu_lock(): + graph_record = db.get_graph(request.graph_id) + if graph_record is None: + raise HTTPException(status_code=404, detail="Graph not found") + graph, prompt_id = graph_record + + prompt = db.get_prompt(prompt_id) + if prompt is None: + raise HTTPException(status_code=404, detail="Prompt not found") + + token_ids = prompt.token_ids + active_nodes = _parse_and_validate_active_nodes( + request.selected_nodes, loaded.topology, len(token_ids) + ) + nodes_to_ablate = ( + _parse_and_validate_active_nodes( + request.nodes_to_ablate, loaded.topology, len(token_ids) + ) + if request.nodes_to_ablate is not None + else None + ) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + # Use graph's loss config if optimized, else mean KL + loss_config: LossConfig = ( + graph.optimization_params.loss + if graph.optimization_params is not None + else MeanKLLossConfig() + ) + + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=nodes_to_ablate, + tokenizer=loaded.tokenizer, + adv_pgd_config=AdvPGDConfig( + n_steps=request.adv_pgd.n_steps, + step_size=request.adv_pgd.step_size, + init="random", + ), + loss_config=loss_config, + sampling=loaded.config.sampling, + top_k=request.top_k, + ) run_id = db.save_intervention_run( graph_id=request.graph_id, selected_nodes=request.selected_nodes, - result_json=response.model_dump_json(), + result_json=result.model_dump_json(), ) record = db.get_intervention_runs(request.graph_id) @@ -228,41 +142,25 @@ def run_and_save_intervention( return InterventionRunSummary( id=run_id, selected_nodes=request.selected_nodes, - result=response, + result=result, created_at=saved_run.created_at, - 
forked_runs=[], ) @router.get("/runs/{graph_id}") @log_errors def get_intervention_runs(graph_id: int, db: DepDB) -> list[InterventionRunSummary]: - """Get all intervention runs for a graph, including forked runs.""" + """Get all intervention runs for a graph.""" records = db.get_intervention_runs(graph_id) - results = [] - for r in records: - # Get forked runs for this intervention run - forked_records = db.get_forked_intervention_runs(r.id) - forked_runs = [ - ForkedInterventionRunSummary( - id=fr.id, - token_replacements=fr.token_replacements, - result=InterventionResponse.model_validate_json(fr.result_json), - created_at=fr.created_at, - ) - for fr in forked_records - ] - - results.append( - InterventionRunSummary( - id=r.id, - selected_nodes=r.selected_nodes, - result=InterventionResponse.model_validate_json(r.result_json), - created_at=r.created_at, - forked_runs=forked_runs, - ) + return [ + InterventionRunSummary( + id=r.id, + selected_nodes=r.selected_nodes, + result=InterventionResult.model_validate_json(r.result_json), + created_at=r.created_at, ) - return results + for r in records + ] @router.delete("/runs/{run_id}") @@ -271,86 +169,3 @@ def delete_intervention_run(run_id: int, db: DepDB) -> dict[str, bool]: """Delete an intervention run.""" db.delete_intervention_run(run_id) return {"success": True} - - -@router.post("/runs/{run_id}/fork") -@log_errors -def fork_intervention_run( - run_id: int, - request: ForkInterventionRequest, - loaded: DepLoadedRun, - manager: DepStateManager, -) -> ForkedInterventionRunSummary: - """Fork an intervention run with modified tokens. - - Takes the same selected_nodes from the parent run, applies token replacements - to the original prompt, and runs the intervention forward pass. 
- """ - db = manager.db - - # Get the parent intervention run - parent_run = db.get_intervention_run(run_id) - if parent_run is None: - raise HTTPException(status_code=404, detail="Intervention run not found") - - # Get the prompt_id from the graph - conn = db._get_conn() - row = conn.execute( - "SELECT prompt_id FROM graphs WHERE id = ?", (parent_run.graph_id,) - ).fetchone() - if row is None: - raise HTTPException(status_code=404, detail="Graph not found") - prompt_id = row["prompt_id"] - - # Get the prompt to get original token_ids - prompt = db.get_prompt(prompt_id) - if prompt is None: - raise HTTPException(status_code=404, detail="Prompt not found") - - # Apply token replacements to get modified token_ids - modified_token_ids = list(prompt.token_ids) # Make a copy - for seq_pos, new_token_id in request.token_replacements: - if seq_pos < 0 or seq_pos >= len(modified_token_ids): - raise HTTPException( - status_code=400, - detail=f"Invalid seq_pos {seq_pos} for prompt with {len(modified_token_ids)} tokens", - ) - modified_token_ids[seq_pos] = new_token_id - - # Decode the modified tokens back to text - modified_text = loaded.tokenizer.decode(modified_token_ids) - - # Run the intervention forward pass with modified tokens but same selected nodes - response = _run_intervention_forward( - text=modified_text, - selected_nodes=parent_run.selected_nodes, - top_k=request.top_k, - loaded=loaded, - ) - - # Save the forked run - fork_id = db.save_forked_intervention_run( - intervention_run_id=run_id, - token_replacements=request.token_replacements, - result_json=response.model_dump_json(), - ) - - # Get the saved record for created_at - forked_records = db.get_forked_intervention_runs(run_id) - saved_fork = next((f for f in forked_records if f.id == fork_id), None) - assert saved_fork is not None - - return ForkedInterventionRunSummary( - id=fork_id, - token_replacements=request.token_replacements, - result=response, - created_at=saved_fork.created_at, - ) - - 
-@router.delete("/forks/{fork_id}") -@log_errors -def delete_forked_intervention_run(fork_id: int, db: DepDB) -> dict[str, bool]: - """Delete a forked intervention run.""" - db.delete_forked_intervention_run(fork_id) - return {"success": True} diff --git a/spd/app/backend/routers/investigations.py b/spd/app/backend/routers/investigations.py new file mode 100644 index 000000000..0784a0eb4 --- /dev/null +++ b/spd/app/backend/routers/investigations.py @@ -0,0 +1,317 @@ +"""Investigations endpoint for viewing agent investigation results. + +Lists and serves investigation data from SPD_OUT_DIR/investigations/. +Each investigation directory contains findings from a single agent run. +""" + +import json +from datetime import datetime +from pathlib import Path +from typing import Any + +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel + +from spd.app.backend.dependencies import DepLoadedRun +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/investigations", tags=["investigations"]) + +INVESTIGATIONS_DIR = SPD_OUT_DIR / "investigations" + + +class InvestigationSummary(BaseModel): + """Summary of a single investigation.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + has_research_log: bool + has_explanations: bool + event_count: int + last_event_time: str | None + last_event_message: str | None + title: str | None + summary: str | None + status: str | None + + +class EventEntry(BaseModel): + """A single event from events.jsonl.""" + + event_type: str + timestamp: str + message: str + details: dict[str, Any] | None = None + + +class InvestigationDetail(BaseModel): + """Full detail of an investigation including logs.""" + + id: str + wandb_path: str | None + prompt: str | None + created_at: str + research_log: str | None + events: list[EventEntry] + explanations: list[dict[str, Any]] + artifact_ids: list[str] + title: str | None + summary: 
str | None + status: str | None + + +def _parse_metadata(inv_path: Path) -> dict[str, Any] | None: + """Parse metadata.json from an investigation directory.""" + metadata_path = inv_path / "metadata.json" + if not metadata_path.exists(): + return None + try: + data: dict[str, Any] = json.loads(metadata_path.read_text()) + return data + except Exception: + return None + + +def _get_last_event(events_path: Path) -> tuple[str | None, str | None, int]: + """Get the last event timestamp, message, and total count from events.jsonl.""" + if not events_path.exists(): + return None, None, 0 + + last_time = None + last_msg = None + count = 0 + + try: + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + count += 1 + try: + event = json.loads(line) + last_time = event.get("timestamp") + last_msg = event.get("message") + except json.JSONDecodeError: + continue + except Exception: + pass + + return last_time, last_msg, count + + +def _parse_task_summary(inv_path: Path) -> tuple[str | None, str | None, str | None]: + """Parse summary.json from an investigation directory. 
Returns (title, summary, status).""" + summary_path = inv_path / "summary.json" + if not summary_path.exists(): + return None, None, None + try: + data: dict[str, Any] = json.loads(summary_path.read_text()) + return data.get("title"), data.get("summary"), data.get("status") + except Exception: + return None, None, None + + +def _list_artifact_ids(inv_path: Path) -> list[str]: + """List all artifact IDs for an investigation.""" + artifacts_dir = inv_path / "artifacts" + if not artifacts_dir.exists(): + return [] + return [f.stem for f in sorted(artifacts_dir.glob("graph_*.json"))] + + +def _get_created_at(inv_path: Path, metadata: dict[str, Any] | None) -> str: + """Get creation time for an investigation.""" + events_path = inv_path / "events.jsonl" + if events_path.exists(): + try: + with open(events_path) as f: + first_line = f.readline().strip() + if first_line: + event = json.loads(first_line) + if "timestamp" in event: + return event["timestamp"] + except Exception: + pass + + if metadata and "created_at" in metadata: + return metadata["created_at"] + + return datetime.fromtimestamp(inv_path.stat().st_mtime).isoformat() + + +@router.get("") +def list_investigations(loaded: DepLoadedRun) -> list[InvestigationSummary]: + """List investigations for the currently loaded run.""" + if not INVESTIGATIONS_DIR.exists(): + return [] + + wandb_path = loaded.run.wandb_path + results = [] + + for inv_path in INVESTIGATIONS_DIR.iterdir(): + if not inv_path.is_dir() or not inv_path.name.startswith("inv-"): + continue + + inv_id = inv_path.name + metadata = _parse_metadata(inv_path) + + meta_wandb_path = metadata.get("wandb_path") if metadata else None + if meta_wandb_path is None: + continue + # Normalize to canonical form for comparison (strips "runs/", "wandb:" prefix, etc.) 
+ try: + e, p, r = parse_wandb_run_path(meta_wandb_path) + canonical_meta_path = f"{e}/{p}/{r}" + except ValueError: + continue + if canonical_meta_path != wandb_path: + continue + + events_path = inv_path / "events.jsonl" + last_time, last_msg, event_count = _get_last_event(events_path) + title, summary, status = _parse_task_summary(inv_path) + + explanations_path = inv_path / "explanations.jsonl" + + results.append( + InvestigationSummary( + id=inv_id, + wandb_path=meta_wandb_path, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + has_research_log=(inv_path / "research_log.md").exists(), + has_explanations=explanations_path.exists() + and explanations_path.stat().st_size > 0, + event_count=event_count, + last_event_time=last_time, + last_event_message=last_msg, + title=title, + summary=summary, + status=status, + ) + ) + + results.sort(key=lambda x: x.created_at, reverse=True) + return results + + +@router.get("/{inv_id}") +def get_investigation(inv_id: str) -> InvestigationDetail: + """Get full details of an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + + if not inv_path.exists() or not inv_path.is_dir(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + + metadata = _parse_metadata(inv_path) + + research_log = None + research_log_path = inv_path / "research_log.md" + if research_log_path.exists(): + research_log = research_log_path.read_text() + + events = [] + events_path = inv_path / "events.jsonl" + if events_path.exists(): + with open(events_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + events.append( + EventEntry( + event_type=event.get("event_type", "unknown"), + timestamp=event.get("timestamp", ""), + message=event.get("message", ""), + details=event.get("details"), + ) + ) + except json.JSONDecodeError: + continue + + explanations: list[dict[str, Any]] = [] + explanations_path = 
inv_path / "explanations.jsonl" + if explanations_path.exists(): + with open(explanations_path) as f: + for line in f: + line = line.strip() + if not line: + continue + try: + explanations.append(json.loads(line)) + except json.JSONDecodeError: + continue + + title, summary, status = _parse_task_summary(inv_path) + artifact_ids = _list_artifact_ids(inv_path) + + return InvestigationDetail( + id=inv_id, + wandb_path=metadata.get("wandb_path") if metadata else None, + prompt=metadata.get("prompt") if metadata else None, + created_at=_get_created_at(inv_path, metadata), + research_log=research_log, + events=events, + explanations=explanations, + artifact_ids=artifact_ids, + title=title, + summary=summary, + status=status, + ) + + +class LaunchRequest(BaseModel): + prompt: str + + +class LaunchResponse(BaseModel): + inv_id: str + job_id: str + + +@router.post("/launch") +def launch_investigation_endpoint(request: LaunchRequest, loaded: DepLoadedRun) -> LaunchResponse: + """Launch a new investigation for the currently loaded run.""" + from spd.investigate.scripts.run_slurm import launch_investigation + + result = launch_investigation( + wandb_path=loaded.run.wandb_path, + prompt=request.prompt, + context_length=loaded.context_length, + max_turns=50, + partition="h200-reserved", + time="8:00:00", + job_suffix=None, + ) + return LaunchResponse(inv_id=result.inv_id, job_id=result.job_id) + + +@router.get("/{inv_id}/artifacts") +def list_artifacts(inv_id: str) -> list[str]: + """List all artifact IDs for an investigation.""" + inv_path = INVESTIGATIONS_DIR / inv_id + if not inv_path.exists(): + raise HTTPException(status_code=404, detail=f"Investigation {inv_id} not found") + return _list_artifact_ids(inv_path) + + +@router.get("/{inv_id}/artifacts/{artifact_id}") +def get_artifact(inv_id: str, artifact_id: str) -> dict[str, Any]: + """Get a specific artifact by ID.""" + inv_path = INVESTIGATIONS_DIR / inv_id + artifact_path = inv_path / "artifacts" / f"{artifact_id}.json" 
+ + if not artifact_path.exists(): + raise HTTPException( + status_code=404, + detail=f"Artifact {artifact_id} not found in {inv_id}", + ) + + data: dict[str, Any] = json.loads(artifact_path.read_text()) + return data diff --git a/spd/app/backend/routers/mcp.py b/spd/app/backend/routers/mcp.py new file mode 100644 index 000000000..5094bc92c --- /dev/null +++ b/spd/app/backend/routers/mcp.py @@ -0,0 +1,1638 @@ +"""MCP (Model Context Protocol) endpoint for Claude Code integration. + +This router implements the MCP JSON-RPC protocol over HTTP, allowing Claude Code +to use SPD tools directly with proper schemas and streaming progress. + +MCP Spec: https://modelcontextprotocol.io/specification/2025-06-18/basic/transports +""" + +import inspect +import json +import queue +import threading +import traceback +from collections.abc import Callable, Generator +from dataclasses import dataclass +from datetime import UTC, datetime +from pathlib import Path +from typing import Any, Literal + +import torch +from fastapi import APIRouter, Request +from fastapi.responses import JSONResponse, StreamingResponse +from pydantic import BaseModel + +from spd.app.backend.compute import ( + compute_ci_only, + compute_prompt_attributions_optimized, +) +from spd.app.backend.database import StoredGraph +from spd.app.backend.optim_cis import CELossConfig, OptimCIConfig +from spd.app.backend.routers.graphs import _build_out_probs +from spd.app.backend.routers.pretrain_info import _get_pretrain_info +from spd.app.backend.state import StateManager +from spd.configs import ImportanceMinimalityLossConfig +from spd.harvest import analysis +from spd.log import logger +from spd.utils.distributed_utils import get_device + +router = APIRouter(tags=["mcp"]) + +DEVICE = get_device() + +# MCP protocol version +MCP_PROTOCOL_VERSION = "2024-11-05" + + +@dataclass +class InvestigationConfig: + """Configuration for investigation mode. 
All paths are required when in investigation mode.""" + + events_log_path: Path + investigation_dir: Path + + +_investigation_config: InvestigationConfig | None = None + + +def set_investigation_config(config: InvestigationConfig) -> None: + """Configure MCP for investigation mode.""" + global _investigation_config + _investigation_config = config + + +def _log_event(event_type: str, message: str, details: dict[str, Any] | None = None) -> None: + """Log an event to the events file if in investigation mode.""" + if _investigation_config is None: + return + event = { + "event_type": event_type, + "timestamp": datetime.now(UTC).isoformat(), + "message": message, + "details": details or {}, + } + with open(_investigation_config.events_log_path, "a") as f: + f.write(json.dumps(event) + "\n") + + +# ============================================================================= +# MCP Protocol Types +# ============================================================================= + + +class MCPRequest(BaseModel): + """JSON-RPC 2.0 request.""" + + jsonrpc: Literal["2.0"] + id: int | str | None = None + method: str + params: dict[str, Any] | None = None + + +class MCPResponse(BaseModel): + """JSON-RPC 2.0 response. + + Per JSON-RPC 2.0 spec, exactly one of result/error must be present (not both, not neither). + Use model_dump(exclude_none=True) when serializing to avoid including null fields. + """ + + jsonrpc: Literal["2.0"] = "2.0" + id: int | str | None + result: Any | None = None + error: dict[str, Any] | None = None + + +class ToolDefinition(BaseModel): + """MCP tool definition.""" + + name: str + description: str + inputSchema: dict[str, Any] + + +# ============================================================================= +# Tool Definitions +# ============================================================================= + +TOOLS: list[ToolDefinition] = [ + ToolDefinition( + name="optimize_graph", + description="""Optimize a sparse circuit for a specific behavior. 
+ +Given a prompt and target token, finds the minimal set of components that produce the target prediction. +Returns the optimized graph with component CI values and edges showing information flow. + +This is the primary tool for understanding how the model produces a specific output.""", + inputSchema={ + "type": "object", + "properties": { + "prompt_text": { + "type": "string", + "description": "The input text to analyze (e.g., 'The boy said that')", + }, + "target_token": { + "type": "string", + "description": "The token to predict (e.g., ' he'). Include leading space if needed.", + }, + "loss_position": { + "type": "integer", + "description": "Position to optimize prediction at (0-indexed, usually last position). If not specified, uses the last position.", + }, + "steps": { + "type": "integer", + "description": "Optimization steps (default: 100, more = sparser but slower)", + "default": 100, + }, + "ci_threshold": { + "type": "number", + "description": "CI threshold for including components (default: 0.5, lower = more components)", + "default": 0.5, + }, + }, + "required": ["prompt_text", "target_token"], + }, + ), + ToolDefinition( + name="get_component_info", + description="""Get detailed information about a component. + +Returns the component's interpretation (what it does), token statistics (what tokens +activate it and what it predicts), and correlated components. 
+ +Use this to understand what role a component plays in a circuit.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up', '2.attn.o')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "top_k": { + "type": "integer", + "description": "Number of top tokens/correlations to return (default: 20)", + "default": 20, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="run_ablation", + description="""Run an ablation experiment with only selected components active. + +Tests a hypothesis by running the model with a sparse set of components. +Returns predictions showing what the circuit produces vs the full model. + +Use this to verify that identified components are necessary and sufficient.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "Input text for the ablation", + }, + "selected_nodes": { + "type": "array", + "items": {"type": "string"}, + "description": "Node keys to keep active (format: 'layer:seq_pos:component_idx')", + }, + "top_k": { + "type": "integer", + "description": "Number of top predictions to return per position (default: 10)", + "default": 10, + }, + }, + "required": ["text", "selected_nodes"], + }, + ), + ToolDefinition( + name="search_dataset", + description="""Search the SimpleStories training dataset for patterns. + +Finds stories containing the query string. Use this to find examples of +specific linguistic patterns (pronouns, verb forms, etc.) 
for investigation.""", + inputSchema={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "Text to search for (case-insensitive)", + }, + "limit": { + "type": "integer", + "description": "Maximum results to return (default: 20)", + "default": 20, + }, + }, + "required": ["query"], + }, + ), + ToolDefinition( + name="create_prompt", + description="""Create a prompt for analysis. + +Tokenizes the text and returns token IDs and next-token probabilities. +The returned prompt_id can be used with other tools.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The text to create a prompt from", + }, + }, + "required": ["text"], + }, + ), + ToolDefinition( + name="update_research_log", + description="""Append content to your research log. + +Use this to document your investigation progress, findings, and next steps. +The research log is your primary output for humans to follow your work. + +Call this frequently (every few minutes) with updates on what you're doing.""", + inputSchema={ + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Markdown content to append to the research log", + }, + }, + "required": ["content"], + }, + ), + ToolDefinition( + name="save_explanation", + description="""Save a complete behavior explanation. + +Use this when you have finished investigating a behavior and want to document +your findings. This creates a structured record of the behavior, the components +involved, and your explanation of how they work together. 
+ +Only call this for complete, validated explanations - not preliminary hypotheses.""", + inputSchema={ + "type": "object", + "properties": { + "subject_prompt": { + "type": "string", + "description": "A prompt that demonstrates the behavior", + }, + "behavior_description": { + "type": "string", + "description": "Clear description of the behavior", + }, + "components_involved": { + "type": "array", + "items": { + "type": "object", + "properties": { + "component_key": { + "type": "string", + "description": "Component key (e.g., '0.mlp.up:5')", + }, + "role": { + "type": "string", + "description": "The role this component plays", + }, + "interpretation": { + "type": "string", + "description": "Auto-interp label if available", + }, + }, + "required": ["component_key", "role"], + }, + "description": "List of components and their roles", + }, + "explanation": { + "type": "string", + "description": "How the components work together", + }, + "supporting_evidence": { + "type": "array", + "items": { + "type": "object", + "properties": { + "evidence_type": { + "type": "string", + "enum": [ + "ablation", + "attribution", + "activation_pattern", + "correlation", + "other", + ], + }, + "description": {"type": "string"}, + "details": {"type": "object"}, + }, + "required": ["evidence_type", "description"], + }, + "description": "Evidence supporting this explanation", + }, + "confidence": { + "type": "string", + "enum": ["high", "medium", "low"], + "description": "Your confidence level", + }, + "alternative_hypotheses": { + "type": "array", + "items": {"type": "string"}, + "description": "Other hypotheses you considered", + }, + "limitations": { + "type": "array", + "items": {"type": "string"}, + "description": "Known limitations of this explanation", + }, + }, + "required": [ + "subject_prompt", + "behavior_description", + "components_involved", + "explanation", + "confidence", + ], + }, + ), + ToolDefinition( + name="set_investigation_summary", + description="""Set a title and 
summary for your investigation. + +Call this when you've completed your investigation (or periodically as you make progress) +to provide a human-readable title and summary that will be shown in the investigations UI. + +The title should be short and descriptive. The summary should be 1-3 sentences +explaining what you investigated and what you found.""", + inputSchema={ + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "Short title for the investigation (e.g., 'Gendered Pronoun Circuit')", + }, + "summary": { + "type": "string", + "description": "Brief summary of findings (1-3 sentences)", + }, + "status": { + "type": "string", + "enum": ["in_progress", "completed", "inconclusive"], + "description": "Current status of the investigation", + "default": "in_progress", + }, + }, + "required": ["title", "summary"], + }, + ), + ToolDefinition( + name="save_graph_artifact", + description="""Save a graph as an artifact for inclusion in your research report. + +After calling optimize_graph and getting a graph_id, call this to save the graph +as an artifact. Then reference it in your research log using the spd:graph syntax: + +```spd:graph +artifact: graph_001 +``` + +This allows humans reviewing your investigation to see interactive circuit visualizations +inline with your research notes.""", + inputSchema={ + "type": "object", + "properties": { + "graph_id": { + "type": "integer", + "description": "The graph ID returned by optimize_graph", + }, + "caption": { + "type": "string", + "description": "Optional caption describing what this graph shows", + }, + }, + "required": ["graph_id"], + }, + ), + ToolDefinition( + name="probe_component", + description="""Fast CI probing on custom text. + +Computes causal importance values and subcomponent activations for a specific component +across all positions in the input text. Also returns next-token probabilities. 
+ +Use this for quick, targeted analysis of how a component responds to specific inputs.""", + inputSchema={ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "The input text to probe", + }, + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + }, + "required": ["text", "layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_activation_examples", + description="""Get activation examples from harvest data for a component. + +Returns examples showing token windows where the component fires, along with +CI values and activation strengths at each position. + +Use this to understand what inputs activate a component.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up')", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "limit": { + "type": "integer", + "description": "Maximum number of examples to return (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_component_attributions", + description="""Get dataset-level component dependencies from pre-computed attributions. + +Returns the top source and target components that this component attributes to/from, +aggregated over the training dataset. Both positive and negative attributions are returned. 
+ +Use this to understand a component's role in the broader network.""", + inputSchema={ + "type": "object", + "properties": { + "layer": { + "type": "string", + "description": "Canonical layer name (e.g., '0.mlp.up') or 'output'", + }, + "component_idx": { + "type": "integer", + "description": "Component index within the layer", + }, + "k": { + "type": "integer", + "description": "Number of top attributions to return per direction (default: 10)", + "default": 10, + }, + }, + "required": ["layer", "component_idx"], + }, + ), + ToolDefinition( + name="get_model_info", + description="""Get architecture details about the pretrained model. + +Returns model type, summary, target model config, topology, and pretrain info. +No parameters required.""", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), + ToolDefinition( + name="get_attribution_strength", + description="""Query the attribution strength between two specific components. + +Returns the dataset-aggregated attribution value from source to target. 
+ +Use this to check if two components have a strong connection.""", + inputSchema={ + "type": "object", + "properties": { + "source_layer": { + "type": "string", + "description": "Canonical layer name of source component (e.g., '0.mlp.up')", + }, + "source_idx": { + "type": "integer", + "description": "Source component index", + }, + "target_layer": { + "type": "string", + "description": "Canonical layer name of target component (e.g., '1.attn.q') or 'output'", + }, + "target_idx": { + "type": "integer", + "description": "Target component index", + }, + }, + "required": ["source_layer", "source_idx", "target_layer", "target_idx"], + }, + ), +] + + +# ============================================================================= +# Tool Implementations +# ============================================================================= + + +def _get_state(): + """Get state manager and loaded run, raising clear errors if not available.""" + manager = StateManager.get() + if manager.run_state is None: + raise ValueError("No run loaded. The backend must load a run first.") + return manager, manager.run_state + + +def _concrete_key(canonical_layer: str, component_idx: int, loaded: Any) -> str: + """Translate canonical layer + idx to concrete storage key.""" + if canonical_layer == "output": + return f"output:{component_idx}" + concrete = loaded.topology.canon_to_target(canonical_layer) + return f"{concrete}:{component_idx}" + + +def _canonicalize_layer(layer: str, loaded: Any) -> str: + """Translate concrete layer name to canonical, passing through 'output'.""" + if layer == "output": + return layer + return loaded.topology.target_to_canon(layer) + + +def _canonicalize_key(concrete_key: str, loaded: Any) -> str: + """Translate concrete component key (e.g. 
'h.0.mlp.c_fc:444') to canonical ('0.mlp.up:444').""" + layer, idx = concrete_key.rsplit(":", 1) + return f"{_canonicalize_layer(layer, loaded)}:{idx}" + + +def _tool_optimize_graph(params: dict[str, Any]) -> Generator[dict[str, Any]]: + """Optimize a sparse circuit for a behavior. Yields progress events.""" + manager, loaded = _get_state() + + prompt_text = params["prompt_text"] + target_token = params["target_token"] + steps = params.get("steps", 100) + ci_threshold = params.get("ci_threshold", 0.5) + + # Tokenize prompt + token_ids = loaded.tokenizer.encode(prompt_text) + if not token_ids: + raise ValueError("Prompt text produced no tokens") + + # Find target token ID + target_token_ids = loaded.tokenizer.encode(target_token) + if len(target_token_ids) != 1: + raise ValueError( + f"Target token '{target_token}' tokenizes to {len(target_token_ids)} tokens, expected 1. " + f"Token IDs: {target_token_ids}" + ) + label_token = target_token_ids[0] + + # Determine loss position + loss_position = params.get("loss_position") + if loss_position is None: + loss_position = len(token_ids) - 1 + + if loss_position >= len(token_ids): + raise ValueError( + f"loss_position {loss_position} out of bounds for prompt with {len(token_ids)} tokens" + ) + + _log_event( + "tool_start", + f"optimize_graph: '{prompt_text}' → '{target_token}'", + {"steps": steps, "loss_position": loss_position}, + ) + + yield {"type": "progress", "current": 0, "total": steps, "stage": "starting optimization"} + + # Create prompt in DB + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Build optimization config + loss_config = CELossConfig(coeff=1.0, position=loss_position, label_token=label_token) + + optim_config = OptimCIConfig( + adv_pgd=None, # AdvPGDConfig(n_steps=10, step_size=0.01, init="random"), + seed=0, + lr=1e-2, + steps=steps, + weight_decay=0.0, + lr_schedule="cosine", + lr_exponential_halflife=None, + 
lr_warmup_pct=0.01, + log_freq=max(1, steps // 10), + imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=0.0), + loss_config=loss_config, + sampling=loaded.config.sampling, + ce_kl_rounding_threshold=0.5, + mask_type="ci", + ) + + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + progress_queue: queue.Queue[dict[str, Any]] = queue.Queue() + + def on_progress(current: int, total: int, stage: str) -> None: + progress_queue.put({"current": current, "total": total, "stage": stage}) + + # Run optimization in thread + result_holder: list[Any] = [] + error_holder: list[Exception] = [] + + def compute(): + try: + with manager.gpu_lock(): + result = compute_prompt_attributions_optimized( + model=loaded.model, + topology=loaded.topology, + tokens=tokens_tensor, + sources_by_target=loaded.sources_by_target, + optim_config=optim_config, + output_prob_threshold=0.01, + device=DEVICE, + on_progress=on_progress, + ) + result_holder.append(result) + except Exception as e: + error_holder.append(e) + + thread = threading.Thread(target=compute) + thread.start() + + # Yield progress events (throttle logging to every 10% or 10 steps) + last_logged_step = -1 + log_interval = max(1, steps // 10) + + while thread.is_alive() or not progress_queue.empty(): + try: + progress = progress_queue.get(timeout=0.1) + current = progress["current"] + # Log to events.jsonl at intervals (for human monitoring) + if current - last_logged_step >= log_interval or current == progress["total"]: + _log_event( + "optimization_progress", + f"optimize_graph: step {current}/{progress['total']} ({progress['stage']})", + {"prompt": prompt_text, "target": target_token, **progress}, + ) + last_logged_step = current + # Always yield to SSE stream (for Claude) + yield {"type": "progress", **progress} + except queue.Empty: + continue + + thread.join() + + if error_holder: + raise error_holder[0] + + if not result_holder: + raise RuntimeError("Optimization completed but no result was 
produced") + + result = result_holder[0] + + ci_masked_out_logits = result.ci_masked_out_logits.cpu() + target_out_logits = result.target_out_logits.cpu() + + # Build output probs for response + out_probs = _build_out_probs( + ci_masked_out_logits, + target_out_logits, + loaded.tokenizer.get_tok_display, + ) + + # Save graph to DB + from spd.app.backend.database import OptimizationParams + + opt_params = OptimizationParams( + imp_min_coeff=0.1, + steps=steps, + pnorm=0.5, + beta=0.0, + mask_type="ci", + loss=loss_config, + ci_masked_label_prob=result.metrics.ci_masked_label_prob, + stoch_masked_label_prob=result.metrics.stoch_masked_label_prob, + adv_pgd_label_prob=result.metrics.adv_pgd_label_prob, + ) + graph_id = manager.db.save_graph( + prompt_id=prompt_id, + graph=StoredGraph( + graph_type="optimized", + edges=result.edges, + edges_abs=result.edges_abs, + ci_masked_out_logits=ci_masked_out_logits, + target_out_logits=target_out_logits, + node_ci_vals=result.node_ci_vals, + node_subcomp_acts=result.node_subcomp_acts, + optimization_params=opt_params, + ), + ) + + # Filter nodes by CI threshold + active_components = {k: v for k, v in result.node_ci_vals.items() if v >= ci_threshold} + + # Get target token probability + target_key = f"{loss_position}:{label_token}" + target_prob = out_probs.get(target_key) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + final_result = { + "graph_id": graph_id, + "prompt_id": prompt_id, + "tokens": token_strings, + "target_token": target_token, + "target_token_id": label_token, + "target_position": loss_position, + "target_probability": target_prob.prob if target_prob else None, + "target_probability_baseline": target_prob.target_prob if target_prob else None, + "active_components": active_components, + "total_active": len(active_components), + "output_probs": {k: {"prob": v.prob, "token": v.token} for k, v in out_probs.items()}, + } + + _log_event( + "tool_complete", + f"optimize_graph complete: 
{len(active_components)} active components", + {"graph_id": graph_id, "target_prob": target_prob.prob if target_prob else None}, + ) + + yield {"type": "result", "data": final_result} + + +def _tool_get_component_info(params: dict[str, Any]) -> dict[str, Any]: + """Get detailed information about a component.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + top_k = params.get("top_k", 20) + canonical_key = f"{layer}:{component_idx}" + + # Harvest/interp repos store concrete keys (e.g. "h.0.mlp.c_fc:444") + concrete_layer = loaded.topology.canon_to_target(layer) + concrete_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_info: {canonical_key}", + {"layer": layer, "idx": component_idx}, + ) + + result: dict[str, Any] = {"component_key": canonical_key} + + # Get interpretation + if loaded.interp is not None: + interp = loaded.interp.get_interpretation(concrete_key) + if interp is not None: + result["interpretation"] = { + "label": interp.label, + "confidence": interp.confidence, + "reasoning": interp.reasoning, + } + else: + result["interpretation"] = None + else: + result["interpretation"] = None + + # Get token stats + assert loaded.harvest is not None, "harvest data not loaded" + token_stats = loaded.harvest.get_token_stats() + if token_stats is not None: + input_stats = analysis.get_input_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + output_stats = analysis.get_output_token_stats( + token_stats, concrete_key, loaded.tokenizer, top_k + ) + if input_stats and output_stats: + result["token_stats"] = { + "input": { + "top_recall": input_stats.top_recall, + "top_precision": input_stats.top_precision, + "top_pmi": input_stats.top_pmi, + }, + "output": { + "top_recall": output_stats.top_recall, + "top_precision": output_stats.top_precision, + "top_pmi": output_stats.top_pmi, + "bottom_pmi": output_stats.bottom_pmi, + }, + } + else: + 
result["token_stats"] = None + else: + result["token_stats"] = None + + # Get correlations (return canonical keys) + correlations = loaded.harvest.get_correlations() + if correlations is not None and analysis.has_component(correlations, concrete_key): + result["correlated_components"] = { + "precision": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "precision", top_k + ) + ], + "pmi": [ + {"key": _canonicalize_key(c.component_key, loaded), "score": c.score} + for c in analysis.get_correlated_components( + correlations, concrete_key, "pmi", top_k + ) + ], + } + else: + result["correlated_components"] = None + + return result + + +def _tool_run_ablation(params: dict[str, Any]) -> dict[str, Any]: + """Run ablation with selected components.""" + from spd.app.backend.compute import ( + DEFAULT_EVAL_PGD_CONFIG, + compute_intervention, + ) + from spd.app.backend.optim_cis import MeanKLLossConfig + + manager, loaded = _get_state() + + text = params["text"] + selected_nodes = params["selected_nodes"] + top_k = params.get("top_k", 10) + + _log_event( + "tool_call", + f"run_ablation: '{text[:50]}...' 
with {len(selected_nodes)} nodes", + {"text": text, "n_nodes": len(selected_nodes)}, + ) + + token_ids = loaded.tokenizer.encode(text) + tokens = torch.tensor([token_ids], dtype=torch.long, device=DEVICE) + + active_nodes = [] + for key in selected_nodes: + parts = key.split(":") + if len(parts) != 3: + raise ValueError(f"Invalid node key format: {key!r} (expected 'layer:seq:cIdx')") + layer, seq_str, cidx_str = parts + if layer in ("wte", "embed", "output"): + raise ValueError(f"Cannot intervene on {layer!r} nodes - only internal layers allowed") + active_nodes.append((layer, int(seq_str), int(cidx_str))) + + with manager.gpu_lock(): + result = compute_intervention( + model=loaded.model, + tokens=tokens, + active_nodes=active_nodes, + nodes_to_ablate=None, + tokenizer=loaded.tokenizer, + adv_pgd_config=DEFAULT_EVAL_PGD_CONFIG, + loss_config=MeanKLLossConfig(), + sampling=loaded.config.sampling, + top_k=top_k, + ) + + predictions = [] + for pos_predictions in result.ci: + pos_result = [] + for pred in pos_predictions: + pos_result.append( + { + "token": pred.token, + "token_id": pred.token_id, + "circuit_prob": round(pred.prob, 6), + "full_model_prob": round(pred.target_prob, 6), + } + ) + predictions.append(pos_result) + + return { + "input_tokens": result.input_tokens, + "predictions_per_position": predictions, + "selected_nodes": selected_nodes, + } + + +def _tool_search_dataset(params: dict[str, Any]) -> dict[str, Any]: + """Search the SimpleStories dataset.""" + import time + + from datasets import Dataset, load_dataset + + query = params["query"] + limit = params.get("limit", 20) + search_query = query.lower() + + _log_event("tool_call", f"search_dataset: '{query}'", {"query": query, "limit": limit}) + + start_time = time.time() + dataset = load_dataset("lennart-finke/SimpleStories", split="train") + assert isinstance(dataset, Dataset) + + filtered = dataset.filter( + lambda x: search_query in x["story"].lower(), + num_proc=4, + ) + + results = [] + for i, 
item in enumerate(filtered): + if i >= limit: + break + item_dict: dict[str, Any] = dict(item) + story: str = item_dict["story"] + results.append( + { + "story": story[:500] + "..." if len(story) > 500 else story, + "occurrence_count": story.lower().count(search_query), + } + ) + + return { + "query": query, + "total_matches": len(filtered), + "returned": len(results), + "search_time_seconds": round(time.time() - start_time, 2), + "results": results, + } + + +def _tool_create_prompt(params: dict[str, Any]) -> dict[str, Any]: + """Create a prompt from text.""" + manager, loaded = _get_state() + + text = params["text"] + + _log_event("tool_call", f"create_prompt: '{text[:50]}...'", {"text": text}) + + token_ids = loaded.tokenizer.encode(text) + if not token_ids: + raise ValueError("Text produced no tokens") + + prompt_id = manager.db.add_custom_prompt( + run_id=loaded.run.id, + token_ids=token_ids, + context_length=loaded.context_length, + ) + + # Compute next token probs + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + with torch.no_grad(): + logits = loaded.model(tokens_tensor) + probs = torch.softmax(logits, dim=-1) + + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "prompt_id": prompt_id, + "text": text, + "tokens": token_strings, + "token_ids": token_ids, + "next_token_probs": next_token_probs, + } + + +def _require_investigation_config() -> InvestigationConfig: + """Get investigation config, raising if not in investigation mode.""" + assert _investigation_config is not None, "Not running in investigation mode" + return _investigation_config + + +def _tool_update_research_log(params: dict[str, Any]) -> dict[str, Any]: + """Append content to the research log.""" + config = 
_require_investigation_config() + content = params["content"] + research_log_path = config.investigation_dir / "research_log.md" + + _log_event( + "tool_call", f"update_research_log: {len(content)} chars", {"preview": content[:100]} + ) + + with open(research_log_path, "a") as f: + f.write(content) + if not content.endswith("\n"): + f.write("\n") + + return {"status": "ok", "path": str(research_log_path)} + + +def _tool_save_explanation(params: dict[str, Any]) -> dict[str, Any]: + """Save a behavior explanation to explanations.jsonl.""" + from spd.investigate.schemas import BehaviorExplanation, ComponentInfo, Evidence + + config = _require_investigation_config() + + _log_event( + "tool_call", + f"save_explanation: '{params['behavior_description'][:50]}...'", + {"prompt": params["subject_prompt"]}, + ) + + components = [ + ComponentInfo( + component_key=c["component_key"], + role=c["role"], + interpretation=c.get("interpretation"), + ) + for c in params["components_involved"] + ] + + evidence = [ + Evidence( + evidence_type=e["evidence_type"], + description=e["description"], + details=e.get("details", {}), + ) + for e in params.get("supporting_evidence", []) + ] + + explanation = BehaviorExplanation( + subject_prompt=params["subject_prompt"], + behavior_description=params["behavior_description"], + components_involved=components, + explanation=params["explanation"], + supporting_evidence=evidence, + confidence=params["confidence"], + alternative_hypotheses=params.get("alternative_hypotheses", []), + limitations=params.get("limitations", []), + ) + + explanations_path = config.investigation_dir / "explanations.jsonl" + with open(explanations_path, "a") as f: + f.write(explanation.model_dump_json() + "\n") + + _log_event( + "explanation", + f"Saved explanation: {params['behavior_description']}", + {"confidence": params["confidence"], "n_components": len(components)}, + ) + + return {"status": "ok", "path": str(explanations_path)} + + +def 
_tool_set_investigation_summary(params: dict[str, Any]) -> dict[str, Any]: + """Set the investigation title and summary.""" + config = _require_investigation_config() + + summary = { + "title": params["title"], + "summary": params["summary"], + "status": params.get("status", "in_progress"), + "updated_at": datetime.now(UTC).isoformat(), + } + + _log_event( + "tool_call", + f"set_investigation_summary: {params['title']}", + summary, + ) + + summary_path = config.investigation_dir / "summary.json" + summary_path.write_text(json.dumps(summary, indent=2)) + + return {"status": "ok", "path": str(summary_path)} + + +def _tool_save_graph_artifact(params: dict[str, Any]) -> dict[str, Any]: + """Save a graph as an artifact for the research report. + + Uses the same filtering logic as the main graph API: + 1. Filter nodes by CI threshold + 2. Add pseudo nodes (wte, output) + 3. Filter edges to only active nodes + 4. Apply edge limit + """ + config = _require_investigation_config() + manager, loaded = _get_state() + + graph_id = params["graph_id"] + caption = params.get("caption") + ci_threshold = params.get("ci_threshold", 0.5) + edge_limit = params.get("edge_limit", 5000) + + _log_event( + "tool_call", + f"save_graph_artifact: graph_id={graph_id}", + {"graph_id": graph_id, "caption": caption}, + ) + + # Fetch graph from DB + result = manager.db.get_graph(graph_id) + if result is None: + raise ValueError(f"Graph with id={graph_id} not found") + + graph, prompt_id = result + + # Get tokens from prompt + prompt_record = manager.db.get_prompt(prompt_id) + if prompt_record is None: + raise ValueError(f"Prompt with id={prompt_id} not found") + + tokens = [loaded.tokenizer.get_tok_display(tid) for tid in prompt_record.token_ids] + num_tokens = len(tokens) + + # Create artifacts directory + artifacts_dir = config.investigation_dir / "artifacts" + artifacts_dir.mkdir(exist_ok=True) + + # Generate artifact ID (find max existing number to avoid collisions) + existing_nums = [] + for f 
in artifacts_dir.glob("graph_*.json"): + try: + num = int(f.stem.split("_")[1]) + existing_nums.append(num) + except (IndexError, ValueError): + continue + artifact_num = max(existing_nums, default=0) + 1 + artifact_id = f"graph_{artifact_num:03d}" + + # Compute out_probs from stored logits + out_probs = _build_out_probs( + graph.ci_masked_out_logits, + graph.target_out_logits, + loaded.tokenizer.get_tok_display, + ) + + # Step 1: Filter nodes by CI threshold (same as main graph API) + filtered_ci_vals = {k: v for k, v in graph.node_ci_vals.items() if v > ci_threshold} + l0_total = len(filtered_ci_vals) + + # Step 2: Add pseudo nodes (embed and output) - same as _add_pseudo_layer_nodes + node_ci_vals_with_pseudo = dict(filtered_ci_vals) + for seq_pos in range(num_tokens): + node_ci_vals_with_pseudo[f"embed:{seq_pos}:0"] = 1.0 + for key, out_prob in out_probs.items(): + seq_pos, token_id = key.split(":") + node_ci_vals_with_pseudo[f"output:{seq_pos}:{token_id}"] = out_prob.prob + + # Step 3: Filter edges to only active nodes + active_node_keys = set(node_ci_vals_with_pseudo.keys()) + filtered_edges = [ + e + for e in graph.edges + if str(e.source) in active_node_keys and str(e.target) in active_node_keys + ] + + # Step 4: Sort by strength and apply edge limit + filtered_edges.sort(key=lambda e: abs(e.strength), reverse=True) + filtered_edges = filtered_edges[:edge_limit] + + # Build edges data + edges_data = [ + { + "src": str(e.source), + "tgt": str(e.target), + "val": e.strength, + } + for e in filtered_edges + ] + + # Compute max abs attr from filtered edges + max_abs_attr = max((abs(e.strength) for e in filtered_edges), default=0.0) + + # Filter nodeSubcompActs to match nodeCiVals + filtered_subcomp_acts = { + k: v for k, v in graph.node_subcomp_acts.items() if k in node_ci_vals_with_pseudo + } + + # Build artifact data (self-contained GraphData, same structure as API response) + artifact = { + "type": "graph", + "id": artifact_id, + "caption": caption, + 
"graph_id": graph_id, + "data": { + "tokens": tokens, + "edges": edges_data, + "outputProbs": { + k: { + "prob": v.prob, + "logit": v.logit, + "target_prob": v.target_prob, + "target_logit": v.target_logit, + "token": v.token, + } + for k, v in out_probs.items() + }, + "nodeCiVals": node_ci_vals_with_pseudo, + "nodeSubcompActs": filtered_subcomp_acts, + "maxAbsAttr": max_abs_attr, + "l0_total": l0_total, + }, + } + + # Save artifact + artifact_path = artifacts_dir / f"{artifact_id}.json" + artifact_path.write_text(json.dumps(artifact, indent=2)) + + _log_event( + "artifact_saved", + f"Saved graph artifact: {artifact_id}", + {"artifact_id": artifact_id, "graph_id": graph_id, "path": str(artifact_path)}, + ) + + return {"artifact_id": artifact_id, "path": str(artifact_path)} + + +def _tool_probe_component(params: dict[str, Any]) -> dict[str, Any]: + """Fast CI probing on custom text for a specific component.""" + manager, loaded = _get_state() + + text = params["text"] + layer = params["layer"] + component_idx = params["component_idx"] + + _log_event( + "tool_call", + f"probe_component: '{text[:50]}...' 
layer={layer} idx={component_idx}", + {"text": text, "layer": layer, "component_idx": component_idx}, + ) + + token_ids = loaded.tokenizer.encode(text) + assert token_ids, "Text produced no tokens" + tokens_tensor = torch.tensor([token_ids], device=DEVICE) + + concrete_layer = loaded.topology.canon_to_target(layer) + + with manager.gpu_lock(): + result = compute_ci_only( + model=loaded.model, tokens=tokens_tensor, sampling=loaded.config.sampling + ) + + ci_values = result.ci_lower_leaky[concrete_layer][0, :, component_idx].tolist() + subcomp_acts = result.component_acts[concrete_layer][0, :, component_idx].tolist() + + # Get next token probs from target model output + next_token_probs = [] + for i in range(len(token_ids) - 1): + next_token_id = token_ids[i + 1] + prob = result.target_out_probs[0, i, next_token_id].item() + next_token_probs.append(round(prob, 6)) + next_token_probs.append(None) + + token_strings = [loaded.tokenizer.get_tok_display(t) for t in token_ids] + + return { + "tokens": token_strings, + "ci_values": ci_values, + "subcomp_acts": subcomp_acts, + "next_token_probs": next_token_probs, + } + + +def _tool_get_component_activation_examples(params: dict[str, Any]) -> dict[str, Any]: + """Get activation examples from harvest data.""" + _, loaded = _get_state() + + layer = params["layer"] + component_idx = params["component_idx"] + limit = params.get("limit", 10) + + concrete_layer = loaded.topology.canon_to_target(layer) + component_key = f"{concrete_layer}:{component_idx}" + + _log_event( + "tool_call", + f"get_component_activation_examples: {component_key}", + {"layer": layer, "component_idx": component_idx, "limit": limit}, + ) + + assert loaded.harvest is not None, "harvest data not loaded" + canonical_key = f"{layer}:{component_idx}" + comp = loaded.harvest.get_component(component_key) + if comp is None: + return {"component_key": canonical_key, "examples": [], "total": 0} + + examples = [] + for ex in comp.activation_examples[:limit]: + 
token_strings = [loaded.tokenizer.get_tok_display(t) for t in ex.token_ids] + examples.append( + { + "tokens": token_strings, + "ci_values": ex.activations["causal_importance"], + "component_acts": ex.activations["component_activation"], + } + ) + + return { + "component_key": canonical_key, + "examples": examples, + "total": len(comp.activation_examples), + "mean_ci": comp.mean_activations["causal_importance"], + } + + +# def _tool_get_component_attributions(params: dict[str, Any]) -> dict[str, Any]: +# """Get dataset-level component dependencies.""" +# _, loaded = _get_state() + +# layer = params["layer"] +# component_idx = params["component_idx"] +# k = params.get("k", 10) + +# assert loaded.attributions is not None, "dataset attributions not loaded" +# storage = loaded.attributions.get_attributions() + +# concrete_layer = loaded.topology.canon_to_target(layer) if layer != "output" else "output" +# component_key = f"{concrete_layer}:{component_idx}" + +# _log_event( +# "tool_call", +# f"get_component_attributions: {component_key}", +# {"layer": layer, "component_idx": component_idx, "k": k}, +# ) + +# is_source = storage.has_source(component_key) +# is_target = storage.has_target(component_key) + +# assert is_source or is_target, f"Component {component_key} not found in attributions" + +# w_unembed = loaded.topology.get_unembed_weight() if is_source else None + +# def _entries_to_dicts( +# entries: list[Any], +# ) -> list[dict[str, Any]]: +# return [ +# { +# "component_key": f"{_canonicalize_layer(e.layer, loaded)}:{e.component_idx}", +# "layer": _canonicalize_layer(e.layer, loaded), +# "component_idx": e.component_idx, +# "value": e.value, +# } +# for e in entries +# ] + +# positive_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "positive")) +# if is_target +# else [] +# ) +# negative_sources = ( +# _entries_to_dicts(storage.get_top_sources(component_key, k, "negative")) +# if is_target +# else [] +# ) +# positive_targets = ( +# 
_entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "positive", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) +# negative_targets = ( +# _entries_to_dicts( +# storage.get_top_targets( +# component_key, +# k, +# "negative", +# w_unembed=w_unembed, +# include_outputs=w_unembed is not None, +# ) +# ) +# if is_source +# else [] +# ) + +# return { +# "component_key": component_key, +# "positive_sources": positive_sources, +# "negative_sources": negative_sources, +# "positive_targets": positive_targets, +# "negative_targets": negative_targets, +# } + + +def _tool_get_model_info(_params: dict[str, Any]) -> dict[str, Any]: + """Get architecture details about the pretrained model.""" + _, loaded = _get_state() + + _log_event("tool_call", "get_model_info", {}) + + info = _get_pretrain_info(loaded.config) + return info.model_dump() + + +def _tool_get_attribution_strength(params: dict[str, Any]) -> dict[str, Any]: + """Query attribution between two specific components.""" + _, loaded = _get_state() + + source_layer = params["source_layer"] + source_idx = params["source_idx"] + target_layer = params["target_layer"] + target_idx = params["target_idx"] + + assert loaded.attributions is not None, "dataset attributions not loaded" + storage = loaded.attributions.get_attributions() + + source_key = _concrete_key(source_layer, source_idx, loaded) + target_key = _concrete_key(target_layer, target_idx, loaded) + + _log_event( + "tool_call", + f"get_attribution_strength: {source_key} → {target_key}", + {"source": source_key, "target": target_key}, + ) + + value = storage.get_attribution(source_key, target_key) + + return {"value": value} + + +# ============================================================================= +# MCP Protocol Handler +# ============================================================================= + + +_STREAMING_TOOLS: dict[str, Callable[..., Generator[dict[str, Any]]]] = { + 
"optimize_graph": _tool_optimize_graph, +} + +_SIMPLE_TOOLS: dict[str, Callable[..., dict[str, Any]]] = { + "get_component_info": _tool_get_component_info, + "run_ablation": _tool_run_ablation, + "search_dataset": _tool_search_dataset, + "create_prompt": _tool_create_prompt, + "update_research_log": _tool_update_research_log, + "save_explanation": _tool_save_explanation, + "set_investigation_summary": _tool_set_investigation_summary, + "save_graph_artifact": _tool_save_graph_artifact, + "probe_component": _tool_probe_component, + "get_component_activation_examples": _tool_get_component_activation_examples, + # "get_component_attributions": _tool_get_component_attributions, + "get_model_info": _tool_get_model_info, + "get_attribution_strength": _tool_get_attribution_strength, +} + + +def _handle_initialize(_params: dict[str, Any] | None) -> dict[str, Any]: + """Handle initialize request.""" + return { + "protocolVersion": MCP_PROTOCOL_VERSION, + "capabilities": {"tools": {}}, + "serverInfo": {"name": "spd-app", "version": "1.0.0"}, + } + + +def _handle_tools_list() -> dict[str, Any]: + """Handle tools/list request.""" + return {"tools": [t.model_dump() for t in TOOLS]} + + +def _handle_tools_call( + params: dict[str, Any], +) -> Generator[dict[str, Any]] | dict[str, Any]: + """Handle tools/call request. May return generator for streaming tools.""" + name = params.get("name") + arguments = params.get("arguments", {}) + + if name in _STREAMING_TOOLS: + return _STREAMING_TOOLS[name](arguments) + + if name in _SIMPLE_TOOLS: + result = _SIMPLE_TOOLS[name](arguments) + return {"content": [{"type": "text", "text": json.dumps(result, indent=2)}]} + + raise ValueError(f"Unknown tool: {name}") + + +@router.post("/mcp") +async def mcp_endpoint(request: Request): + """MCP JSON-RPC endpoint. + + Handles initialize, tools/list, and tools/call methods. + Returns SSE stream for streaming tools, JSON for others. 
+ """ + try: + body = await request.json() + mcp_request = MCPRequest(**body) + except Exception as e: + return JSONResponse( + status_code=400, + content=MCPResponse( + id=None, error={"code": -32700, "message": f"Parse error: {e}"} + ).model_dump(exclude_none=True), + ) + + logger.info(f"[MCP] {mcp_request.method} (id={mcp_request.id})") + + try: + if mcp_request.method == "initialize": + result = _handle_initialize(mcp_request.params) + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True), + headers={"Mcp-Session-Id": "spd-session"}, + ) + + elif mcp_request.method == "notifications/initialized": + # Client confirms initialization + return JSONResponse(status_code=202, content={}) + + elif mcp_request.method == "tools/list": + result = _handle_tools_list() + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump(exclude_none=True) + ) + + elif mcp_request.method == "tools/call": + if mcp_request.params is None: + raise ValueError("tools/call requires params") + + result = _handle_tools_call(mcp_request.params) + + # Check if result is a generator (streaming) + if inspect.isgenerator(result): + # Streaming response via SSE + gen = result # Capture for closure + + def generate_sse() -> Generator[str]: + try: + final_result = None + for event in gen: + if event.get("type") == "progress": + # Send progress notification + progress_msg = { + "jsonrpc": "2.0", + "method": "notifications/progress", + "params": event, + } + yield f"data: {json.dumps(progress_msg)}\n\n" + elif event.get("type") == "result": + final_result = event["data"] + + # Send final response + response = MCPResponse( + id=mcp_request.id, + result={ + "content": [ + {"type": "text", "text": json.dumps(final_result, indent=2)} + ] + }, + ) + yield f"data: {json.dumps(response.model_dump(exclude_none=True))}\n\n" + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Tool error: {e}\n{tb}") + 
error_response = MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ) + yield f"data: {json.dumps(error_response.model_dump(exclude_none=True))}\n\n" + + return StreamingResponse(generate_sse(), media_type="text/event-stream") + + else: + # Non-streaming response + return JSONResponse( + content=MCPResponse(id=mcp_request.id, result=result).model_dump( + exclude_none=True + ) + ) + + else: + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32601, "message": f"Method not found: {mcp_request.method}"}, + ).model_dump(exclude_none=True) + ) + + except Exception as e: + tb = traceback.format_exc() + logger.error(f"[MCP] Error handling {mcp_request.method}: {e}\n{tb}") + return JSONResponse( + content=MCPResponse( + id=mcp_request.id, + error={"code": -32000, "message": str(e)}, + ).model_dump(exclude_none=True) + ) diff --git a/spd/app/backend/routers/pretrain_info.py b/spd/app/backend/routers/pretrain_info.py index 2872423c9..424f7b035 100644 --- a/spd/app/backend/routers/pretrain_info.py +++ b/spd/app/backend/routers/pretrain_info.py @@ -38,6 +38,7 @@ class TopologyInfo(BaseModel): class PretrainInfoResponse(BaseModel): model_type: str summary: str + dataset_short: str | None target_model_config: dict[str, Any] | None pretrain_config: dict[str, Any] | None pretrain_wandb_path: str | None @@ -161,6 +162,27 @@ def _build_summary(model_type: str, target_model_config: dict[str, Any] | None) return " · ".join(parts) +_DATASET_SHORT_NAMES: dict[str, str] = { + "simplestories": "SS", + "pile": "Pile", + "tinystories": "TS", +} + + +def _get_dataset_short(pretrain_config: dict[str, Any] | None) -> str | None: + """Extract a short dataset label from the pretrain config.""" + if pretrain_config is None: + return None + dataset_name: str = ( + pretrain_config.get("train_dataset_config", {}).get("name", "") + or pretrain_config.get("dataset", "") + ).lower() + for key, short in _DATASET_SHORT_NAMES.items(): + if key 
in dataset_name: + return short + return None + + def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: """Extract pretrain info from an SPD config.""" model_class_name = spd_config.pretrained_model_class @@ -190,10 +212,12 @@ def _get_pretrain_info(spd_config: Config) -> PretrainInfoResponse: n_blocks = target_model_config.get("n_layer", 0) if target_model_config else 0 topology = _build_topology(model_type, n_blocks) summary = _build_summary(model_type, target_model_config) + dataset_short = _get_dataset_short(pretrain_config) return PretrainInfoResponse( model_type=model_type, summary=summary, + dataset_short=dataset_short, target_model_config=target_model_config, pretrain_config=pretrain_config, pretrain_wandb_path=pretrain_wandb_path, diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 7f06bcf68..5aec55ea5 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -91,6 +91,14 @@ def list_prompts(manager: DepStateManager, loaded: DepLoadedRun) -> list[PromptP return results +@router.delete("/{prompt_id}") +@log_errors +def delete_prompt(prompt_id: int, manager: DepStateManager) -> dict[str, bool]: + """Delete a prompt and all its graphs and intervention runs.""" + manager.db.delete_prompt(prompt_id) + return {"success": True} + + @router.post("/custom") @log_errors def add_custom_prompt(text: str, manager: DepStateManager, loaded: DepLoadedRun) -> PromptPreview: diff --git a/spd/app/backend/routers/run_registry.py b/spd/app/backend/routers/run_registry.py new file mode 100644 index 000000000..d44a7108f --- /dev/null +++ b/spd/app/backend/routers/run_registry.py @@ -0,0 +1,95 @@ +"""Run registry endpoint. + +Returns architecture and data availability for requested SPD runs. +The canonical run list lives in the frontend; the backend just hydrates it. 
+""" + +import asyncio +from pathlib import Path + +from fastapi import APIRouter +from pydantic import BaseModel + +from spd.app.backend.routers.pretrain_info import _get_pretrain_info, _load_spd_config_lightweight +from spd.app.backend.utils import log_errors +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.wandb_utils import parse_wandb_run_path + +router = APIRouter(prefix="/api/run_registry", tags=["run_registry"]) + + +class DataAvailability(BaseModel): + harvest: bool + autointerp: bool + attributions: bool + graph_interp: bool + + +class RunInfoResponse(BaseModel): + wandb_run_id: str + architecture: str | None + availability: DataAvailability + + +def _has_glob_match(pattern_dir: Path, glob_pattern: str) -> bool: + """Check if any file matches a glob pattern under a directory.""" + if not pattern_dir.exists(): + return False + return next(pattern_dir.glob(glob_pattern), None) is not None + + +def _check_availability(run_id: str) -> DataAvailability: + """Lightweight filesystem checks for post-processing data availability.""" + harvest_dir = SPD_OUT_DIR / "harvest" / run_id + autointerp_dir = SPD_OUT_DIR / "autointerp" / run_id + attributions_dir = SPD_OUT_DIR / "dataset_attributions" / run_id + graph_interp_dir = SPD_OUT_DIR / "graph_interp" / run_id + + return DataAvailability( + harvest=_has_glob_match(harvest_dir, "h-*/harvest.db"), + autointerp=_has_glob_match(autointerp_dir, "a-*/.done"), + attributions=_has_glob_match(attributions_dir, "da-*/dataset_attributions.pt"), + graph_interp=_has_glob_match(graph_interp_dir, "*/interp.db"), + ) + + +def _get_architecture_summary(wandb_path: str) -> str | None: + """Get a short architecture label for a run. 
Returns None on failure.""" + try: + spd_config = _load_spd_config_lightweight(wandb_path) + info = _get_pretrain_info(spd_config) + parts: list[str] = [] + if info.dataset_short: + parts.append(info.dataset_short) + parts.append(info.model_type) + cfg = info.target_model_config + if cfg: + n_layer = cfg.get("n_layer") + n_embd = cfg.get("n_embd") + if n_layer is not None: + parts.append(f"{n_layer}L") + if n_embd is not None: + parts.append(f"d{n_embd}") + return " ".join(parts) + except Exception: + logger.exception(f"[run_registry] Failed to get architecture for {wandb_path}") + return None + + +def _build_run_info(wandb_run_id: str) -> RunInfoResponse: + _, _, run_id = parse_wandb_run_path(wandb_run_id) + return RunInfoResponse( + wandb_run_id=wandb_run_id, + architecture=_get_architecture_summary(wandb_run_id), + availability=_check_availability(run_id), + ) + + +@router.post("") +@log_errors +async def get_run_info(wandb_run_ids: list[str]) -> list[RunInfoResponse]: + """Return architecture and availability for the requested runs.""" + loop = asyncio.get_running_loop() + tasks = [loop.run_in_executor(None, _build_run_info, wid) for wid in wandb_run_ids] + return list(await asyncio.gather(*tasks)) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index 0989cea54..b0e323cc2 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -10,11 +10,13 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.dependencies import DepStateManager +from spd.app.backend.routers.graph_interp import MOCK_MODE as _GRAPH_INTERP_MOCK_MODE from spd.app.backend.state import RunState from spd.app.backend.utils import log_errors from spd.autointerp.repo import InterpRepo from spd.configs import LMTaskConfig from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.log import logger from 
spd.models.component_model import ComponentModel, SPDRunInfo @@ -42,6 +44,8 @@ class LoadedRun(BaseModel): backend_user: str dataset_attributions_available: bool dataset_search_enabled: bool + graph_interp_available: bool + autointerp_available: bool router = APIRouter(prefix="/api", tags=["runs"]) @@ -128,6 +132,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): harvest=HarvestRepo.open_most_recent(run_id), interp=InterpRepo.open(run_id), attributions=AttributionRepo.open(run_id), + graph_interp=GraphInterpRepo.open(run_id), ) logger.info(f"[API] Run {run.id} loaded on {DEVICE}") @@ -165,6 +170,10 @@ def get_status(manager: DepStateManager) -> LoadedRun | None: backend_user=getpass.getuser(), dataset_attributions_available=manager.run_state.attributions is not None, dataset_search_enabled=dataset_search_enabled, + # TODO(oli): Remove MOCK_MODE import after real data available + graph_interp_available=manager.run_state.graph_interp is not None + or _GRAPH_INTERP_MOCK_MODE, + autointerp_available=manager.run_state.interp is not None, ) diff --git a/spd/app/backend/schemas.py b/spd/app/backend/schemas.py index bc99dc831..61fa3d9d2 100644 --- a/spd/app/backend/schemas.py +++ b/spd/app/backend/schemas.py @@ -19,8 +19,6 @@ class OutputProbability(BaseModel): logit: float # CI-masked (SPD model) raw logit target_prob: float # Target model probability target_logit: float # Target model raw logit - adv_pgd_prob: float | None = None # Adversarial PGD probability - adv_pgd_logit: float | None = None # Adversarial PGD raw logit token: str diff --git a/spd/app/backend/server.py b/spd/app/backend/server.py index 3804ce756..89ac602b3 100644 --- a/spd/app/backend/server.py +++ b/spd/app/backend/server.py @@ -32,14 +32,19 @@ data_sources_router, dataset_attributions_router, dataset_search_router, + graph_interp_router, graphs_router, intervention_router, + investigations_router, + mcp_router, pretrain_info_router, prompts_router, + 
run_registry_router, runs_router, ) from spd.app.backend.state import StateManager from spd.log import logger +from spd.settings import SPD_APP_DEFAULT_RUN from spd.utils.distributed_utils import get_device DEVICE = get_device() @@ -48,6 +53,11 @@ @asynccontextmanager async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] """Initialize DB connection at startup. Model loaded on-demand via /api/runs/load.""" + import os + from pathlib import Path + + from spd.app.backend.routers.mcp import InvestigationConfig, set_investigation_config + manager = StateManager.get() db = PromptAttrDB(check_same_thread=False) @@ -58,6 +68,24 @@ async def lifespan(app: FastAPI): # pyright: ignore[reportUnusedParameter] logger.info(f"[STARTUP] Device: {DEVICE}") logger.info(f"[STARTUP] CUDA available: {torch.cuda.is_available()}") + # Configure MCP for investigation mode (derives paths from investigation dir) + investigation_dir = os.environ.get("SPD_INVESTIGATION_DIR") + if investigation_dir: + inv_dir = Path(investigation_dir) + set_investigation_config( + InvestigationConfig( + events_log_path=inv_dir / "events.jsonl", + investigation_dir=inv_dir, + ) + ) + logger.info(f"[STARTUP] Investigation mode enabled: dir={investigation_dir}") + + if SPD_APP_DEFAULT_RUN is not None: + from spd.app.backend.routers.runs import load_run + + logger.info(f"[STARTUP] Auto-loading default run: {SPD_APP_DEFAULT_RUN}") + load_run(SPD_APP_DEFAULT_RUN, context_length=512, manager=manager) + yield manager.close() @@ -157,8 +185,12 @@ async def global_exception_handler(request: Request, exc: Exception) -> JSONResp app.include_router(dataset_search_router) app.include_router(dataset_attributions_router) app.include_router(agents_router) +app.include_router(investigations_router) +app.include_router(mcp_router) app.include_router(data_sources_router) +app.include_router(graph_interp_router) app.include_router(pretrain_info_router) +app.include_router(run_registry_router) def cli(port: int 
= 8000) -> None: diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index cf71c2bc6..2cdabda73 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -5,14 +5,20 @@ - StateManager: Singleton managing app-wide state with proper lifecycle """ +import threading +from collections.abc import Generator +from contextlib import contextmanager from dataclasses import dataclass, field from typing import Any +from fastapi import HTTPException + from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB, Run from spd.autointerp.repo import InterpRepo from spd.configs import Config from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.repo import GraphInterpRepo from spd.harvest.repo import HarvestRepo from spd.models.component_model import ComponentModel from spd.topology import TransformerTopology @@ -32,6 +38,7 @@ class RunState: harvest: HarvestRepo | None interp: InterpRepo | None attributions: AttributionRepo | None + graph_interp: GraphInterpRepo | None @dataclass @@ -62,6 +69,7 @@ class StateManager: def __init__(self) -> None: self._state: AppState | None = None + self._gpu_lock = threading.Lock() @classmethod def get(cls) -> "StateManager": @@ -104,3 +112,21 @@ def close(self) -> None: """Clean up resources.""" if self._state is not None: self._state.db.close() + + @contextmanager + def gpu_lock(self) -> Generator[None]: + """Acquire GPU lock or fail with 503 if another GPU operation is in progress. + + Use this for GPU-intensive endpoints to prevent concurrent operations + that would cause the server to hang. + """ + acquired = self._gpu_lock.acquire(blocking=False) + if not acquired: + raise HTTPException( + status_code=503, + detail="GPU operation already in progress. 
Please wait and retry.", + ) + try: + yield + finally: + self._gpu_lock.release() diff --git a/spd/app/frontend/package-lock.json b/spd/app/frontend/package-lock.json index b6c451303..32da0218c 100644 --- a/spd/app/frontend/package-lock.json +++ b/spd/app/frontend/package-lock.json @@ -7,6 +7,9 @@ "": { "name": "frontend", "version": "0.0.0", + "dependencies": { + "marked": "^17.0.1" + }, "devDependencies": { "@eslint/js": "^9.38.0", "@sveltejs/vite-plugin-svelte": "^6.2.1", @@ -2347,6 +2350,18 @@ "@jridgewell/sourcemap-codec": "^1.5.5" } }, + "node_modules/marked": { + "version": "17.0.1", + "resolved": "https://registry.npmjs.org/marked/-/marked-17.0.1.tgz", + "integrity": "sha512-boeBdiS0ghpWcSwoNm/jJBwdpFaMnZWRzjA6SkUMYb40SVaN1x7mmfGKp0jvexGcx+7y2La5zRZsYFZI6Qpypg==", + "license": "MIT", + "bin": { + "marked": "bin/marked.js" + }, + "engines": { + "node": ">= 20" + } + }, "node_modules/merge2": { "version": "1.4.1", "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", diff --git a/spd/app/frontend/package.json b/spd/app/frontend/package.json index f54e1bb3d..f298885ce 100644 --- a/spd/app/frontend/package.json +++ b/spd/app/frontend/package.json @@ -27,5 +27,8 @@ "typescript": "~5.9.3", "typescript-eslint": "^8.46.2", "vite": "^7.1.7" + }, + "dependencies": { + "marked": "^17.0.1" } } diff --git a/spd/app/frontend/src/app.css b/spd/app/frontend/src/app.css index 8bb0c490f..bf6649aee 100644 --- a/spd/app/frontend/src/app.css +++ b/spd/app/frontend/src/app.css @@ -1,22 +1,22 @@ :root { - /* Punchy Research - crisp whites, bold contrasts */ + /* Goodfire-inspired - warm whites, navy text, vibrant blue accent */ --bg-base: #ffffff; --bg-surface: #ffffff; --bg-elevated: #ffffff; - --bg-inset: #f8f9fa; - --bg-hover: #f3f4f6; + --bg-inset: #f7f6f2; + --bg-hover: #f0efeb; - --border-subtle: #e0e0e0; - --border-default: #c0c0c0; - --border-strong: #888888; + --border-subtle: #e5e3dc; + --border-default: #c8c5bc; + --border-strong: #8a8780; - 
--text-primary: #111111; - --text-secondary: #555555; - --text-muted: #999999; + --text-primary: #1d272a; + --text-secondary: #646464; + --text-muted: #b4b4b4; - --accent-primary: #2563eb; - --accent-primary-bright: #3b82f6; - --accent-primary-dim: #1d4ed8; + --accent-primary: #7c4d33; + --accent-primary-bright: #96613f; + --accent-primary-dim: #5e3a27; --status-positive: #16a34a; --status-positive-bright: #22c55e; @@ -24,8 +24,10 @@ --status-negative-bright: #ef4444; --status-warning: #eab308; --status-warning-bright: #facc15; - --status-info: #2563eb; - --status-info-bright: #3b82f6; + --status-info: #4d65ff; + --status-info-bright: #6b7fff; + + --focus-ring: #4d65ff; /* Typography - Clean system fonts with mono for code */ --font-mono: "SF Mono", "Menlo", "Monaco", "Consolas", monospace; diff --git a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte index e20ba1adf..c9c304950 100644 --- a/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte +++ b/spd/app/frontend/src/components/ActivationContextsPagedTable.svelte @@ -1,23 +1,29 @@ @@ -106,21 +112,22 @@
- @@ -129,58 +136,69 @@
-
- {#if displaySettings.centerOnPeak} -
- {#each paginatedIndices as idx (idx)} - {@const fp = firingPositions[idx]} -
-
- -
-
- + {#if loading} +
+
+ {#each Array(pageSize) as _, i (i)} +
+ {/each} +
+
+ {:else} + {@const d = loaded!} +
+ {#if displaySettings.centerOnPeak} +
+ {#each paginatedIndices as idx (idx)} + {@const fp = firingPositions[idx]} +
+
+ +
+
+ +
+
+ +
-
+ {/each} +
+ {:else} +
+ {#each paginatedIndices as idx (idx)} +
-
- {/each} -
- {:else} -
- {#each paginatedIndices as idx (idx)} -
- -
- {/each} -
- {/if} -
+ {/each} +
+ {/if} +
+ {/if}
diff --git a/spd/app/frontend/src/components/ActivationContextsViewer.svelte b/spd/app/frontend/src/components/ActivationContextsViewer.svelte index d20831c1a..232e4cd39 100644 --- a/spd/app/frontend/src/components/ActivationContextsViewer.svelte +++ b/spd/app/frontend/src/components/ActivationContextsViewer.svelte @@ -1,15 +1,16 @@
@@ -288,7 +304,7 @@
@@ -412,26 +428,22 @@ {/if} - +
+ + {#if currentGraphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "loading"} -
Loading component data...
- {:else if componentData.componentDetail.status === "loaded"} - - {:else if componentData.componentDetail.status === "error"} - Error loading component data: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading component data: {String(activationExamples.error)} {:else} - Something went wrong loading component data. + {/if} + import { getContext, onMount } from "svelte"; + import { computeMaxAbsComponentAct } from "../lib/colors"; + import { mapLoadable } from "../lib/index"; + import { anyCorrelationStatsEnabled } from "../lib/displaySettings.svelte"; + import { useComponentDataExpectCached } from "../lib/useComponentDataExpectCached.svelte"; + import { RUN_KEY, type RunContext } from "../lib/useRun.svelte"; + import ActivationContextsPagedTable, { type ActivationExamplesData } from "./ActivationContextsPagedTable.svelte"; + import ComponentProbeInput from "./ComponentProbeInput.svelte"; + import ComponentCorrelationMetrics from "./ui/ComponentCorrelationMetrics.svelte"; + import DatasetAttributionsSection from "./ui/DatasetAttributionsSection.svelte"; + import InterpretationBadge from "./ui/InterpretationBadge.svelte"; + import SectionHeader from "./ui/SectionHeader.svelte"; + import StatusText from "./ui/StatusText.svelte"; + import TokenStatsSection from "./ui/TokenStatsSection.svelte"; + + const runState = getContext(RUN_KEY); + + type Props = { + layer: string; + cIdx: number; + }; + + let { layer, cIdx }: Props = $props(); + + const intruderScore = $derived(runState.getIntruderScore(`${layer}:${cIdx}`)); + + const componentData = useComponentDataExpectCached(); + + onMount(() => { + componentData.load(layer, cIdx); + }); + + const inputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + return [ + { + title: "Top Precision", + mathNotation: "P(component fires | token)", + items: 
tokenStats.data.input.top_precision.map(([token, value]) => ({ + token, + value, + })), + maxScale: 1, + }, + ]; + }); + + const outputTokenLists = $derived.by(() => { + const tokenStats = componentData.tokenStats; + if (tokenStats.status !== "loaded" || tokenStats.data === null) return null; + const maxAbsPmi = Math.max( + tokenStats.data.output.top_pmi[0]?.[1] ?? 0, + Math.abs(tokenStats.data.output.bottom_pmi?.[0]?.[1] ?? 0), + ); + return [ + { + title: "Top PMI", + mathNotation: "positive association with predictions", + items: tokenStats.data.output.top_pmi.map(([token, value]) => ({ token, value })), + maxScale: maxAbsPmi, + }, + { + title: "Bottom PMI", + mathNotation: "negative association with predictions", + items: tokenStats.data.output.bottom_pmi.map(([token, value]) => ({ + token, + value, + })), + maxScale: maxAbsPmi, + }, + ]; + }); + + function formatNumericalValue(val: number): string { + return Math.abs(val) < 0.001 ? val.toExponential(2) : val.toFixed(3); + } + + const maxAbsComponentAct = $derived.by(() => { + if (componentData.componentDetail.status !== "loaded") return 1; + return computeMaxAbsComponentAct(componentData.componentDetail.data.example_component_acts); + }); + + const activationExamples = $derived( + mapLoadable( + componentData.componentDetail, + (d): ActivationExamplesData => ({ + tokens: d.example_tokens, + ci: d.example_ci, + componentActs: d.example_component_acts, + maxAbsComponentAct: computeMaxAbsComponentAct(d.example_component_acts), + }), + ), + ); + + +
+
+

{layer}:{cIdx}

+
+ {#if componentData.componentDetail.status === "loaded"} + Mean CI: {formatNumericalValue(componentData.componentDetail.data.mean_ci)} + {/if} + {#if intruderScore !== null} + Intruder: {Math.round(intruderScore * 100)}% + {/if} +
+
+ + + +
+ + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + + {/if} +
+ + + + {#if componentData.datasetAttributions.status === "uninitialized"} + uninitialized + {:else if componentData.datasetAttributions.status === "loaded"} + {#if componentData.datasetAttributions.data !== null} + + {:else} + No dataset attributions available. + {/if} + {:else if componentData.datasetAttributions.status === "loading"} +
+ + Loading... +
+ {:else if componentData.datasetAttributions.status === "error"} +
+ + Error: {String(componentData.datasetAttributions.error)} +
+ {/if} + +
+ +
+ {#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} + Loading token stats... + {:else if componentData.tokenStats.status === "error"} + Error: {String(componentData.tokenStats.error)} + {:else} + + + + {/if} +
+
+ + {#if anyCorrelationStatsEnabled()} +
+ + {#if componentData.correlations.status === "loading"} + Loading... + {:else if componentData.correlations.status === "loaded" && componentData.correlations.data} + + {:else if componentData.correlations.status === "error"} + Error loading correlations: {String(componentData.correlations.error)} + {:else} + No correlations available. + {/if} +
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClusterPathInput.svelte b/spd/app/frontend/src/components/ClusterPathInput.svelte index 27b066c2c..6adcfb2b3 100644 --- a/spd/app/frontend/src/components/ClusterPathInput.svelte +++ b/spd/app/frontend/src/components/ClusterPathInput.svelte @@ -1,8 +1,8 @@ + +
+ {#if clusterMapping} + + {:else} + No clusters loaded. Use the cluster path input in the header bar to load a cluster mapping. + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ClustersViewer.svelte b/spd/app/frontend/src/components/ClustersViewer.svelte new file mode 100644 index 000000000..324b04e77 --- /dev/null +++ b/spd/app/frontend/src/components/ClustersViewer.svelte @@ -0,0 +1,252 @@ + + +
+ {#if selectedClusterId === null} +
+

Clusters ({clusterGroups.sorted.length})

+ {#each clusterGroups.sorted as [clusterId, members] (clusterId)} + {@const previewLabels = getPreviewLabels(members)} + + {/each} + {#if clusterGroups.singletons.length > 0} + + {/if} +
+ {:else} +
+
+ +

+ {selectedClusterId === "unclustered" ? "Unclustered" : `Cluster ${selectedClusterId}`} +

+ {selectedMembers.length} components +
+
+ {#each selectedMembers as member (`${member.layer}:${member.cIdx}`)} +
+ +
+ {/each} +
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/DataSourcesTab.svelte b/spd/app/frontend/src/components/DataSourcesTab.svelte index bc9282c27..c54a1fa0a 100644 --- a/spd/app/frontend/src/components/DataSourcesTab.svelte +++ b/spd/app/frontend/src/components/DataSourcesTab.svelte @@ -29,7 +29,7 @@ }); function formatConfigValue(value: unknown): string { - if (value === null || value === undefined) return "—"; + if (value === null || value === undefined) return "\u2014"; if (typeof value === "object") return JSON.stringify(value); return String(value); } @@ -50,116 +50,91 @@ } -
- {#if runState.run.status === "loaded" && runState.run.data.config_yaml} -
-

Run Config

-
{runState.run.data.config_yaml}
-
- {/if} - - - {#if pretrainData.status === "loading"} -
-

Target Model

-

Loading target model info...

-
- {:else if pretrainData.status === "loaded"} - {@const pt = pretrainData.data} -
-

Target Model

-
- Architecture - {pt.summary} - - {#if pt.pretrain_wandb_path} - Pretrain run - {pt.pretrain_wandb_path} - {/if} -
+
+ +
+ {#if runState.run.status === "loaded" && runState.run.data.config_yaml} +
+

Run Config

+
{runState.run.data.config_yaml}
+
+ {/if} - {#if pt.topology} -
-

Topology

- +
+

Target Model

+ {#if pretrainData.status === "loading"} +

Loading...

+ {:else if pretrainData.status === "loaded"} + {@const pt = pretrainData.data} +
+ Architecture + {pt.summary} + {#if pt.pretrain_wandb_path} + Pretrain run + {pt.pretrain_wandb_path} + {/if}
+ {#if pt.topology} +
+ +
+ {/if} + {#if pt.pretrain_config} +
+ Pretraining config +
{formatPretrainConfigYaml(pt.pretrain_config)}
+
+ {/if} + {:else if pretrainData.status === "error"} +

Failed to load target model info

{/if} - - {#if pt.pretrain_config} -
- Pretraining config -
{formatPretrainConfigYaml(pt.pretrain_config)}
-
- {/if} -
- {:else if pretrainData.status === "error"} -
-

Target Model

-

Failed to load target model info

- {/if} - - {#if data.status === "loading"} -

Loading data sources...

- {:else if data.status === "error"} -

Failed to load data sources: {data.error}

- {:else if data.status === "loaded"} - {@const { harvest, autointerp, attributions } = data.data} - - {#if !harvest && !autointerp && !attributions} -

No pipeline data available for this run.

- {/if} - - {#if harvest} -
-

Harvest

+
+ + +
+ +
+
+ +

Harvest

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.harvest} + {@const harvest = data.data.harvest}
Subrun {harvest.subrun_id} - Components {harvest.n_components.toLocaleString()} - Intruder eval {harvest.has_intruder_scores ? "yes" : "no"} - {#each Object.entries(harvest.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - - {#if attributions} -
-

Dataset Attributions

-
- Subrun - {attributions.subrun_id} - - Batches - {attributions.n_batches_processed.toLocaleString()} - - Tokens - {attributions.n_tokens_processed.toLocaleString()} - - CI threshold - {attributions.ci_threshold} -
-
- {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
- {#if autointerp} -
-

Autointerp

+ +
+
+ +

Autointerp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.autointerp} + {@const autointerp = data.data.autointerp}
Subrun {autointerp.subrun_id} - Interpretations {autointerp.n_interpretations.toLocaleString()} - Eval scores {#if autointerp.eval_scores.length > 0} @@ -168,65 +143,156 @@ none {/if} - {#each Object.entries(autointerp.config) as [key, value] (key)} {key} {formatConfigValue(value)} {/each}
-
- {/if} - {/if} + {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Dataset Attributions

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.attributions} + {@const attributions = data.data.attributions} +
+ Subrun + {attributions.subrun_id} + Tokens + {attributions.n_tokens_processed.toLocaleString()} + CI threshold + {attributions.ci_threshold} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+ + +
+
+ +

Graph Interp

+
+ {#if data.status === "loading"} +

Loading...

+ {:else if data.status === "loaded" && data.data.graph_interp} + {@const graph_interp = data.data.graph_interp} +
+ Subrun + {graph_interp.subrun_id} + {#each Object.entries(graph_interp.label_counts) as [key, value] (key)} + {key} labels + {value.toLocaleString()} + {/each} + {#if graph_interp.config} + {#each Object.entries(graph_interp.config) as [key, value] (key)} + {key} + {formatConfigValue(value)} + {/each} + {/if} +
+ {:else if data.status === "loaded"} +

Not available

+ {/if} +
+
diff --git a/spd/app/frontend/src/components/InvestigationsTab.svelte b/spd/app/frontend/src/components/InvestigationsTab.svelte new file mode 100644 index 000000000..b7752cb5f --- /dev/null +++ b/spd/app/frontend/src/components/InvestigationsTab.svelte @@ -0,0 +1,645 @@ + + +
+ {#if selected?.status === "loaded"} + +
+ +

{selected.data.title || formatId(selected.data.id)}

+ + {#if selected.data.status} + + {selected.data.status} + + {/if} +
+ + {#if selected.data.summary} +

{selected.data.summary}

+ {/if} + + +

+ {formatId(selected.data.id)} · Started {formatDate(selected.data.created_at)} + {#if selected.data.wandb_path} + · {selected.data.wandb_path} + {/if} +

+ +
+ + +
+ +
+ {#if activeTab === "research"} + {#if selected.data.research_log} + + {:else} +

No research log available

+ {/if} + {:else} +
+ {#each selected.data.events as event, i (i)} +
+ + {event.event_type} + + {formatDate(event.timestamp)} + {event.message} + {#if event.details && Object.keys(event.details).length > 0} +
+ Details +
{JSON.stringify(event.details, null, 2)}
+
+ {/if} +
+ {:else} +

No events recorded

+ {/each} +
+ {/if} +
+ {:else if selected?.status === "loading"} +
Loading investigation...
+ {:else} + +
+

Investigations

+ +
+ +
{ + e.preventDefault(); + handleLaunch(); + }} + > + + +
+ {#if launchState.status === "error"} +
{launchState.error}
+ {/if} + + {#if investigations.status === "loading"} +
Loading investigations...
+ {:else if investigations.status === "error"} +
{investigations.error}
+ {:else if investigations.status === "loaded"} +
+ {#each investigations.data as inv (inv.id)} + + {:else} +

+ No investigations found. Run spd-investigate to create one. +

+ {/each} +
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ModelGraph.svelte b/spd/app/frontend/src/components/ModelGraph.svelte new file mode 100644 index 000000000..ae49a25e3 --- /dev/null +++ b/spd/app/frontend/src/components/ModelGraph.svelte @@ -0,0 +1,520 @@ + + +
+ +
+
+ + + +
+
+ +
+
+ +
+
+ +
+
+ {filteredNodes.length} nodes, {visibleEdges.length} edges +
+
+ + + +
+
+ + + + + + {#if tooltipNode} +
+
{tooltipNode.label}
+
+ {tooltipNode.confidence} + {tooltipNode.key} +
+
+ {/if} + + + {#if selectedNodeKey} + {@const selectedNode = layout.nodes.get(selectedNodeKey)} + {#if selectedNode} +
+
+ {selectedNode.label} + {selectedNode.confidence} + +
+
{selectedNode.key}
+
+ {#if selectedNodeEdges.length > 0} +
+ {#each selectedNodeEdges as e, i (i)} + {@const other = e.source === selectedNodeKey ? e.target : e.source} + {@const otherNode = layout.nodes.get(other)} +
+ {e.source === selectedNodeKey ? "to" : "from"} + {otherNode?.label ?? other} + {e.attribution.toFixed(3)} +
+ {/each} +
+ {:else} + No edges + {/if} +
+
+ {/if} + {/if} +
+ + diff --git a/spd/app/frontend/src/components/ProbColoredTokens.svelte b/spd/app/frontend/src/components/ProbColoredTokens.svelte index 9a81b4858..7b30870c3 100644 --- a/spd/app/frontend/src/components/ProbColoredTokens.svelte +++ b/spd/app/frontend/src/components/ProbColoredTokens.svelte @@ -1,5 +1,7 @@ {#each tokens as tok, i (i)}{tok}{#each tokens as tok, i (i)}{@const prob = getProbAtPosition(nextTokenProbs, i)}{/each} @@ -33,34 +26,10 @@ display: inline-flex; flex-wrap: wrap; gap: 1px; - font-family: var(--font-mono); } - .prob-token { - padding: 1px 2px; + .prob-token-wrapper { border-right: 1px solid var(--border-subtle); - position: relative; - white-space: pre; - } - - .prob-token::after { - content: attr(data-tooltip); - position: absolute; - top: calc(100% + 4px); - left: 0; - background: var(--bg-elevated); - border: 1px solid var(--border-strong); - color: var(--text-primary); - padding: var(--space-1) var(--space-2); - font-size: var(--text-xs); - font-family: var(--font-mono); - white-space: nowrap; - opacity: 0; - pointer-events: none; - z-index: 1000; - } - - .prob-token:hover::after { - opacity: 1; + padding: 1px 0; } diff --git a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte index da0d0e935..c5e182917 100644 --- a/spd/app/frontend/src/components/PromptAttributionsGraph.svelte +++ b/spd/app/frontend/src/components/PromptAttributionsGraph.svelte @@ -1,6 +1,6 @@
@@ -40,7 +57,15 @@ {/if} @@ -93,6 +128,10 @@ {runState.run.error}
{/if} + +
+ +
{#if runState.prompts.status === "loaded"}
@@ -109,6 +148,11 @@
+ {#if runState.clusterMapping} +
+ +
+ {/if} {:else if runState.run.status === "loading" || runState.prompts.status === "loading"}

Loading run...

diff --git a/spd/app/frontend/src/components/TokenHighlights.svelte b/spd/app/frontend/src/components/TokenHighlights.svelte index 4916c1f00..456cdbc72 100644 --- a/spd/app/frontend/src/components/TokenHighlights.svelte +++ b/spd/app/frontend/src/components/TokenHighlights.svelte @@ -1,5 +1,6 @@ + +
+ {#if caption} +
{caption}
+ {/if} + + +
+ + +
+ + + {#each Object.entries(layout.layerYPositions) as [layer, y] (layer)} + + {getRowLabel(getRowKey(layer))} + + {/each} + + +
+ +
+ + + + + {@html edgesSvg} + + + + {#each Object.entries(layout.nodePositions) as [key, pos] (key)} + {@const style = nodeStyles[key]} + {@const [layer, seqIdxStr, cIdxStr] = key.split(":")} + {@const seqIdx = parseInt(seqIdxStr)} + {@const cIdx = parseInt(cIdxStr)} + + handleNodeHover(e, layer, seqIdx, cIdx)} + onmouseleave={handleNodeLeave} + /> + + + {/each} + + + + +
+ + + {#each data.tokens as token, i (i)} + {@const colLeft = layout.seqXStarts[i] + 8} + + {token} + + [{i}] + {/each} + + +
+
+
+ +
+ L0: {data.l0_total} · Edges: {filteredEdges.length} +
+ + + {#if hoveredNode && runState} + (isHoveringTooltip = true)} + onMouseLeave={() => { + isHoveringTooltip = false; + hoveredNode = null; + }} + /> + {/if} +
+ + diff --git a/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte new file mode 100644 index 000000000..ef9b8d40d --- /dev/null +++ b/spd/app/frontend/src/components/investigations/ResearchLogViewer.svelte @@ -0,0 +1,223 @@ + + +
+ {#each contentBlocks as block, i (i)} + {#if block.type === "html"} + +
{@html block.content}
+ {:else if block.type === "graph"} + {@const artifact = artifacts[block.artifactId]} + {#if artifact} + + {:else if artifactsLoading} +
+ Loading graph: {block.artifactId}... +
+ {:else} +
+ Graph artifact not found: {block.artifactId} +
+ {/if} + {/if} + {/each} +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte index a0d663208..91083d851 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte @@ -1,6 +1,7 @@
@@ -208,30 +212,26 @@
- +
+ + {#if graphInterpLabel && componentData.graphInterpDetail.status === "loaded" && componentData.graphInterpDetail.data} + + {/if} +
- {#if componentData.componentDetail.status === "uninitialized"} - uninitialized - {:else if componentData.componentDetail.status === "loading"} - Loading details... - {:else if componentData.componentDetail.status === "loaded"} - {#if componentData.componentDetail.data.example_tokens.length > 0} - - {/if} - {:else if componentData.componentDetail.status === "error"} - Error loading details: {String(componentData.componentDetail.error)} + {#if activationExamples.status === "error"} + Error loading details: {String(activationExamples.error)} + {:else if activationExamples.status === "loaded" && activationExamples.data.tokens.length === 0} + + {:else} + {/if}
@@ -243,34 +243,29 @@ title="Prompt Attributions" incomingLabel="Incoming" outgoingLabel="Outgoing" - {incomingPositive} - {incomingNegative} - {outgoingPositive} - {outgoingNegative} + {incoming} + {outgoing} pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} - {#if componentData.datasetAttributions.status === "uninitialized"} - uninitialized + {#if componentData.datasetAttributions.status === "loading" || componentData.datasetAttributions.status === "uninitialized"} +
+ +
+
+
+
+
{:else if componentData.datasetAttributions.status === "loaded"} {#if componentData.datasetAttributions.data !== null} - {:else} - No dataset attributions available. {/if} - {:else if componentData.datasetAttributions.status === "loading"} -
- - Loading... -
{:else if componentData.datasetAttributions.status === "error"}
@@ -282,7 +277,12 @@
{#if componentData.tokenStats === null || componentData.tokenStats.status === "loading"} - Loading token stats... +
+
+
+
+
+
{:else if componentData.tokenStats.status === "error"} Error: {String(componentData.tokenStats.error)} {:else} @@ -306,7 +306,10 @@
{#if componentData.correlations.status === "loading"} - Loading... +
+
+
+
{:else if componentData.correlations.status === "loaded" && componentData.correlations.data} diff --git a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte index 1f3c09a01..23619c281 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComputeProgressOverlay.svelte @@ -1,46 +1,54 @@
-
- {#each state.stages as stage, i (i)} - {@const isCurrent = i === state.currentStage} - {@const isComplete = i < state.currentStage} -
-
- {i + 1} - {stage.name} - {#if isComplete} - - {/if} -
- {#if isCurrent} -
- {#if stage.progress !== null} -
- {:else} -
+
+ {#if ciSnapshot} + + {/if} +
+ {#each state.stages as stage, i (i)} + {@const isCurrent = i === state.currentStage} + {@const isComplete = i < state.currentStage} +
+
+ {i + 1} + {stage.name} + {#if isComplete} + {/if}
- {:else if isComplete} -
-
-
- {:else} -
- {/if} -
- {/each} + {#if isCurrent} +
+ {#if stage.progress !== null} +
+ {:else} +
+ {/if} +
+ {:else if isComplete} +
+
+
+ {:else} +
+ {/if} +
+ {/each} +
@@ -54,6 +62,13 @@ z-index: 100; } + .content { + display: flex; + flex-direction: column; + align-items: center; + gap: var(--space-6); + } + .stages { display: flex; flex-direction: column; diff --git a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte index 95433d57e..8f85b9a16 100644 --- a/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte +++ b/spd/app/frontend/src/components/prompt-attr/InterventionsView.svelte @@ -1,10 +1,11 @@
+ {#if displaySettings.showEdgeAttributions && wteOutgoing.length > 0} + {}} + /> + {/if} {:else if isOutput} - + {:else if !hideNodeCard} {#key `${hoveredNode.layer}:${hoveredNode.cIdx}`} diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte new file mode 100644 index 000000000..3fa6a0213 --- /dev/null +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationGrid.svelte @@ -0,0 +1,134 @@ + + +
+
+ + Step {snapshot.step}/{snapshot.total_steps} + + + L0: {Math.round(snapshot.l0_total)} / {initialL0} + ({(fractionRemaining * 100).toFixed(0)}%) + + {#if snapshot.loss > 0} + loss: {snapshot.loss.toFixed(4)} + {/if} +
+ +
+ + diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte index fa3413320..83d9f0594 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationParams.svelte @@ -21,42 +21,86 @@
- steps{optimization.steps} - imp_min{optimization.imp_min_coeff} - pnorm{optimization.pnorm} - beta{optimization.beta} - mask{optimization.mask_type} - + steps{optimization.steps} + imp_min{optimization.imp_min_coeff} + pnorm{optimization.pnorm} + beta{optimization.beta} + mask{optimization.mask_type} + {optimization.loss.type}{optimization.loss.coeff} - - pos{optimization.loss.position}{#if tokenAtPos !== null} - ({tokenAtPos}){/if} + + pos + {optimization.loss.position} + {#if tokenAtPos !== null} + ({tokenAtPos}) + {/if} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + label({optimization.loss.label_str}) {/if} - {#if optimization.adv_pgd_n_steps !== null} - - adv_steps{optimization.adv_pgd_n_steps} + {#if optimization.pgd} + + pgd_steps{optimization.pgd.n_steps} - - adv_lr{optimization.adv_pgd_step_size} + + pgd_lr{optimization.pgd.step_size} {/if} - + L0{optimization.metrics.l0_total.toFixed(1)} - {#if optimization.loss.type === "ce"} - + {#if optimization.loss.type === "ce" || optimization.loss.type === "logit"} + CI prob{formatProb(optimization.metrics.ci_masked_label_prob)} - + stoch prob{formatProb(optimization.metrics.stoch_masked_label_prob)} + + adv prob{formatProb(optimization.metrics.adv_pgd_label_prob)} + {/if}
diff --git a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte index f762cdff9..3c15034b4 100644 --- a/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OptimizationSettings.svelte @@ -1,15 +1,19 @@
@@ -68,48 +68,57 @@ /> Cross-Entropy +
- -
- At position - { - if (e.currentTarget.value === "") return; - const position = parseInt(e.currentTarget.value); - onChange({ ...config, loss: { ...config.loss, position } }); - }} - min={0} - max={tokens.length - 1} - step={1} - /> - {#if tokenAtSeqPos !== null} - ({tokenAtSeqPos}) - {/if} - {#if config.loss.type === "ce"} - , predict - { - if (config.loss.type !== "ce") - throw new Error( - "inconsistent state: Token dropdown rendered but loss not type CE but no label token", - ); - - if (tokenId !== null) { - onChange({ - ...config, - loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, - }); - } - }} - placeholder="token..." - /> - {/if} + +
+ +
+ {#each tokens as tok, i (i)} + {@const prob = getProbAtPosition(nextTokenProbs, i)} + + {/each} +
+
+ pos {config.loss.position} + {#if config.loss.type === "ce" || config.loss.type === "logit"} + {config.loss.type === "logit" ? "maximize" : "predict"} + { + if (config.loss.type !== "ce" && config.loss.type !== "logit") + throw new Error("inconsistent state: Token dropdown rendered but loss type has no label"); + + if (tokenId !== null) { + onChange({ + ...config, + loss: { ...config.loss, labelTokenId: tokenId, labelTokenText: tokenString }, + }); + } + }} + placeholder="token..." + /> + {/if} +
@@ -256,7 +265,7 @@ display: flex; flex-direction: column; gap: var(--space-3); - max-width: 400px; + max-width: 500px; } .loss-type-options { @@ -296,44 +305,99 @@ color: var(--text-primary); } - .target-section { + .position-section { display: flex; - align-items: center; + flex-direction: column; gap: var(--space-2); - flex-wrap: wrap; - padding: var(--space-2); - background: var(--bg-surface); - border: 1px solid var(--border-default); } - .target-label { + .section-label { font-size: var(--text-xs); + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.05em; color: var(--text-muted); } - .pos-input { - width: 50px; - padding: var(--space-1) var(--space-2); + .token-strip { + display: flex; + flex-wrap: wrap; + gap: 2px; + padding: var(--space-2); + background: var(--bg-inset); border: 1px solid var(--border-default); - background: var(--bg-base); - color: var(--text-primary); - font-size: var(--text-sm); font-family: var(--font-mono); + font-size: var(--text-sm); } - .pos-input:focus { - outline: none; - border-color: var(--accent-primary-dim); + .strip-token { + padding: 2px 2px; + border: 1px solid var(--border-subtle); + border-radius: 2px; + cursor: pointer; + white-space: pre; + font-family: inherit; + font-size: inherit; + color: var(--text-primary); + background: transparent; + position: relative; + transition: + border-color var(--transition-fast), + box-shadow var(--transition-fast); } - .token { - white-space: pre; + .strip-token:hover { + border-color: var(--border-strong); + } + + .strip-token.selected { + border-color: var(--accent-primary); + box-shadow: 0 0 0 1px var(--accent-primary); + z-index: 1; + } + + .strip-token::after { + content: attr(title); + position: absolute; + bottom: calc(100% + 4px); + left: 50%; + transform: translateX(-50%); + background: var(--bg-elevated); + border: 1px solid var(--border-strong); + color: var(--text-primary); + padding: var(--space-1) var(--space-2); + font-size: var(--text-xs); + 
white-space: nowrap; + opacity: 0; + pointer-events: none; + z-index: 100; + border-radius: var(--radius-sm); + } + + .strip-token:hover::after { + opacity: 1; + } + + .position-info { + display: flex; + align-items: center; + gap: var(--space-2); + } + + .pos-label { + font-size: var(--text-xs); font-family: var(--font-mono); + color: var(--text-muted); background: var(--bg-inset); - padding: 0 var(--space-1); + padding: var(--space-1) var(--space-2); border-radius: var(--radius-sm); } + .predict-label { + font-size: var(--text-xs); + color: var(--text-muted); + } + .slider-section { display: flex; flex-direction: column; @@ -346,14 +410,6 @@ align-items: center; } - .section-label { - font-size: var(--text-xs); - font-weight: 600; - text-transform: uppercase; - letter-spacing: 0.05em; - color: var(--text-muted); - } - .imp-min-input { width: 80px; padding: var(--space-1) var(--space-2); diff --git a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte index 4fdd57b1b..328e5449d 100644 --- a/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/OutputNodeCard.svelte @@ -1,14 +1,38 @@
@@ -61,13 +77,6 @@ 2, )})
- {#if singlePosEntry.adv_pgd_prob !== null && singlePosEntry.adv_pgd_logit !== null} -
- Adversarial: {(singlePosEntry.adv_pgd_prob * 100).toFixed(1)}% (logit: {singlePosEntry.adv_pgd_logit.toFixed( - 2, - )}) -
- {/if}

Position: @@ -83,10 +92,6 @@ Logit Target Logit - {#if hasAdvPgd} - Adv - Logit - {/if} @@ -97,15 +102,22 @@ {pos.logit.toFixed(2)} {(pos.target_prob * 100).toFixed(2)}% {pos.target_logit.toFixed(2)} - {#if hasAdvPgd} - {pos.adv_pgd_prob !== null ? (pos.adv_pgd_prob * 100).toFixed(2) + "%" : "—"} - {pos.adv_pgd_logit !== null ? pos.adv_pgd_logit.toFixed(2) : "—"} - {/if} {/each} {/if} + {#if displaySettings.showEdgeAttributions && outputIncoming.length > 0} + {}} + /> + {/if}

diff --git a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte index 077917234..63a1062d0 100644 --- a/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte +++ b/spd/app/frontend/src/components/ui/DisplaySettingsDropdown.svelte @@ -1,16 +1,19 @@
Center on peak +
+
+

Edge Variant

+

Attribution target: value or |value|

+
+ {#each edgeVariants as variant (variant)} + + {/each} +
+

Component Filtering

Filter components in Components tab by mean CI

diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index 844cb7c04..ad3f821f5 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@ -{#if hasAnyIncoming} +{#if incoming.length > 0}
-
- {#if incomingPositive.length > 0} -
- -
- {/if} - {#if incomingNegative.length > 0} -
- -
- {/if} -
+
{/if} -{#if hasAnyOutgoing} +{#if outgoing.length > 0}
-
- {#if outgoingPositive.length > 0} -
- -
- {/if} - {#if outgoingNegative.length > 0} -
- -
- {/if} -
+
{/if} @@ -110,17 +36,4 @@ flex-direction: column; gap: var(--space-2); } - - .pos-neg-row { - display: grid; - grid-template-columns: 1fr 1fr; - gap: var(--space-3); - } - - .edge-list { - min-width: 0; - display: flex; - flex-direction: column; - gap: var(--space-1); - } diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte index aa98848b5..3809aceaf 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte @@ -1,7 +1,8 @@ + +
+ + + {#if expanded && detail} +
+
+
+ Input + {#if detail.input?.reasoning} +

{detail.input.reasoning}

+ {/if} + {#each incomingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+ Output + {#if detail.output?.reasoning} +

{detail.output.reasoning}

+ {/if} + {#each outgoingEdges as edge (edge.related_key)} +
+ {formatComponentKey(edge.related_key, edge.token_str)} + 0} + class:negative={edge.attribution < 0} + > + {edge.attribution > 0 ? "+" : ""}{edge.attribution.toFixed(3)} + + {#if edge.related_label} + {edge.related_label} + {/if} +
+ {/each} +
+
+
+ {/if} +
+ + diff --git a/spd/app/frontend/src/components/ui/TokenSpan.svelte b/spd/app/frontend/src/components/ui/TokenSpan.svelte new file mode 100644 index 000000000..4a5fa9b81 --- /dev/null +++ b/spd/app/frontend/src/components/ui/TokenSpan.svelte @@ -0,0 +1,43 @@ + + +{sanitizeToken(token)} + + diff --git a/spd/app/frontend/src/lib/api/correlations.ts b/spd/app/frontend/src/lib/api/correlations.ts index 2e56c3c7e..8dcc63f04 100644 --- a/spd/app/frontend/src/lib/api/correlations.ts +++ b/spd/app/frontend/src/lib/api/correlations.ts @@ -3,7 +3,7 @@ */ import type { SubcomponentCorrelationsResponse, TokenStatsResponse } from "../promptAttributionsTypes"; -import { apiUrl, fetchJson } from "./index"; +import { ApiError, apiUrl, fetchJson } from "./index"; export async function getComponentCorrelations( layer: string, @@ -47,10 +47,18 @@ export async function getIntruderScores(): Promise> { return fetchJson>("/api/correlations/intruder_scores"); } -export async function getInterpretationDetail(layer: string, componentIdx: number): Promise { - return fetchJson( - `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, - ); +export async function getInterpretationDetail( + layer: string, + componentIdx: number, +): Promise { + try { + return await fetchJson( + `/api/correlations/interpretations/${encodeURIComponent(layer)}/${componentIdx}`, + ); + } catch (e) { + if (e instanceof ApiError && e.status === 404) return null; + throw e; + } } export async function requestComponentInterpretation( diff --git a/spd/app/frontend/src/lib/api/dataSources.ts b/spd/app/frontend/src/lib/api/dataSources.ts index e715af1b1..ac20b7220 100644 --- a/spd/app/frontend/src/lib/api/dataSources.ts +++ b/spd/app/frontend/src/lib/api/dataSources.ts @@ -20,15 +20,21 @@ export type AutointerpInfo = { export type AttributionsInfo = { subrun_id: string; - n_batches_processed: number; n_tokens_processed: number; ci_threshold: number; }; +export type GraphInterpInfoDS = { + 
subrun_id: string; + config: Record | null; + label_counts: Record; +}; + export type DataSourcesResponse = { harvest: HarvestInfo | null; autointerp: AutointerpInfo | null; attributions: AttributionsInfo | null; + graph_interp: GraphInterpInfoDS | null; }; export async function fetchDataSources(): Promise { diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts b/spd/app/frontend/src/lib/api/datasetAttributions.ts index f995a33f6..030eae6c6 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -9,15 +9,23 @@ export type DatasetAttributionEntry = { layer: string; component_idx: number; value: number; + token_str: string | null; }; -export type ComponentAttributions = { +export type SignedAttributions = { positive_sources: DatasetAttributionEntry[]; negative_sources: DatasetAttributionEntry[]; positive_targets: DatasetAttributionEntry[]; negative_targets: DatasetAttributionEntry[]; }; +export type AttrMetric = "attr" | "attr_abs"; + +export type AllMetricAttributions = { + attr: SignedAttributions; + attr_abs: SignedAttributions; +}; + export type DatasetAttributionsMetadata = { available: boolean; }; @@ -30,8 +38,8 @@ export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise { +): Promise { const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - return fetchJson(url.toString()); + return fetchJson(url.toString()); } diff --git a/spd/app/frontend/src/lib/api/graphInterp.ts b/spd/app/frontend/src/lib/api/graphInterp.ts new file mode 100644 index 000000000..8229e757c --- /dev/null +++ b/spd/app/frontend/src/lib/api/graphInterp.ts @@ -0,0 +1,81 @@ +/** + * API client for /api/graph_interp endpoints. 
+ */ + +import { fetchJson } from "./index"; + +export type GraphInterpHeadline = { + label: string; + confidence: string; + output_label: string | null; + input_label: string | null; +}; + +export type LabelDetail = { + label: string; + confidence: string; + reasoning: string; + prompt: string; +}; + +export type GraphInterpDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; +}; + +export type PromptEdgeResponse = { + related_key: string; + pass_name: string; + attribution: number; + related_label: string | null; + related_confidence: string | null; + token_str: string | null; +}; + +export type GraphInterpComponentDetail = { + output: LabelDetail | null; + input: LabelDetail | null; + unified: LabelDetail | null; + edges: PromptEdgeResponse[]; +}; + +export type GraphNode = { + component_key: string; + label: string; + confidence: string; +}; + +export type GraphEdge = { + source: string; + target: string; + attribution: number; + pass_name: string; +}; + +export type ModelGraphResponse = { + nodes: GraphNode[]; + edges: GraphEdge[]; +}; + +export type GraphInterpInfo = { + subrun_id: string; + config: Record | null; + label_counts: Record; +}; + +export async function getAllGraphInterpLabels(): Promise> { + return fetchJson>("/api/graph_interp/labels"); +} + +export async function getGraphInterpDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/labels/${layer}/${cIdx}`); +} + +export async function getGraphInterpComponentDetail(layer: string, cIdx: number): Promise { + return fetchJson(`/api/graph_interp/detail/${layer}/${cIdx}`); +} + +export async function getModelGraph(): Promise { + return fetchJson("/api/graph_interp/graph"); +} diff --git a/spd/app/frontend/src/lib/api/graphs.ts b/spd/app/frontend/src/lib/api/graphs.ts index 42490d531..125d0cd0e 100644 --- a/spd/app/frontend/src/lib/api/graphs.ts +++ b/spd/app/frontend/src/lib/api/graphs.ts @@ -2,11 +2,25 @@ * API client for 
/api/graphs endpoints. */ -import type { GraphData, TokenizeResponse, TokenInfo } from "../promptAttributionsTypes"; +import type { GraphData, EdgeData, TokenizeResponse, TokenSearchResult, CISnapshot } from "../promptAttributionsTypes"; import { buildEdgeIndexes } from "../promptAttributionsTypes"; -import { setArchitecture } from "../layerAliasing"; import { apiUrl, ApiError, fetchJson } from "./index"; +/** Hydrate a raw API graph response into a full GraphData with edge indexes. */ +function hydrateGraph(raw: Record): GraphData { + const g = raw as Omit; + const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); + const edgesAbs = (g.edgesAbs as EdgeData[] | null) ?? null; + let edgesAbsBySource: Map | null = null; + let edgesAbsByTarget: Map | null = null; + if (edgesAbs) { + const absIndexes = buildEdgeIndexes(edgesAbs); + edgesAbsBySource = absIndexes.edgesBySource; + edgesAbsByTarget = absIndexes.edgesByTarget; + } + return { ...g, edgesBySource, edgesByTarget, edgesAbs, edgesAbsBySource, edgesAbsByTarget } as GraphData; +} + export type NormalizeType = "none" | "target" | "layer"; export type GraphProgress = { @@ -30,6 +44,7 @@ export type ComputeGraphParams = { async function parseGraphSSEStream( response: Response, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise { const reader = response.body?.getReader(); if (!reader) { @@ -56,19 +71,12 @@ async function parseGraphSSEStream( if (data.type === "progress" && onProgress) { onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); } else if (data.type === "error") { throw new ApiError(data.error, 500); } else if (data.type === "complete") { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of data.data.edges) { - layerNames.add(edge.src.split(":")[0]); - 
layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(data.data.edges); - result = { ...data.data, edgesBySource, edgesByTarget }; + result = hydrateGraph(data.data); await reader.cancel(); break; } @@ -106,7 +114,7 @@ export async function computeGraphStream( } export type MaskType = "stochastic" | "ci"; -export type LossType = "ce" | "kl"; +export type LossType = "ce" | "kl" | "logit"; export type ComputeGraphOptimizedParams = { promptId: number; @@ -128,6 +136,7 @@ export type ComputeGraphOptimizedParams = { export async function computeGraphOptimizedStream( params: ComputeGraphOptimizedParams, onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, ): Promise { const url = apiUrl("/api/graphs/optimized/stream"); url.searchParams.set("prompt_id", String(params.promptId)); @@ -157,26 +166,121 @@ export async function computeGraphOptimizedStream( throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); } - return parseGraphSSEStream(response, onProgress); + return parseGraphSSEStream(response, onProgress, onCISnapshot); +} + +export type ComputeGraphOptimizedBatchParams = { + promptId: number; + impMinCoeffs: number[]; + steps: number; + pnorm: number; + beta: number; + normalize: NormalizeType; + ciThreshold: number; + maskType: MaskType; + lossType: LossType; + lossCoeff: number; + lossPosition: number; + labelToken?: number; + advPgdNSteps?: number; + advPgdStepSize?: number; +}; + +export async function computeGraphOptimizedBatchStream( + params: ComputeGraphOptimizedBatchParams, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise { + const url = apiUrl("/api/graphs/optimized/batch/stream"); + + const body: Record = { + prompt_id: params.promptId, + imp_min_coeffs: params.impMinCoeffs, + steps: params.steps, + pnorm: params.pnorm, + beta: 
params.beta, + normalize: params.normalize, + ci_threshold: params.ciThreshold, + mask_type: params.maskType, + loss_type: params.lossType, + loss_coeff: params.lossCoeff, + loss_position: params.lossPosition, + }; + if (params.labelToken !== undefined) body.label_token = params.labelToken; + if (params.advPgdNSteps !== undefined) body.adv_pgd_n_steps = params.advPgdNSteps; + if (params.advPgdStepSize !== undefined) body.adv_pgd_step_size = params.advPgdStepSize; + + const response = await fetch(url.toString(), { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(body), + }); + if (!response.ok) { + const error = await response.json(); + throw new ApiError(error.detail || `HTTP ${response.status}`, response.status); + } + + return parseBatchGraphSSEStream(response, onProgress, onCISnapshot); +} + +async function parseBatchGraphSSEStream( + response: Response, + onProgress?: (progress: GraphProgress) => void, + onCISnapshot?: (snapshot: CISnapshot) => void, +): Promise { + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("Response body is not readable"); + } + + const decoder = new TextDecoder(); + let buffer = ""; + let result: GraphData[] | null = null; + + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + buffer += decoder.decode(value, { stream: true }); + + const lines = buffer.split("\n\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (!line.trim() || !line.startsWith("data: ")) continue; + + const data = JSON.parse(line.substring(6)); + + if (data.type === "progress" && onProgress) { + onProgress({ current: data.current, total: data.total, stage: data.stage }); + } else if (data.type === "ci_snapshot" && onCISnapshot) { + onCISnapshot(data as CISnapshot); + } else if (data.type === "error") { + throw new ApiError(data.error, 500); + } else if (data.type === "complete") { + const graphs: GraphData[] = data.data.graphs.map((g: Record) 
=> hydrateGraph(g)); + result = graphs; + await reader.cancel(); + break; + } + } + + if (result) break; + } + + if (!result) { + throw new Error("No result received from stream"); + } + + return result; } export async function getGraphs(promptId: number, normalize: NormalizeType, ciThreshold: number): Promise { const url = apiUrl(`/api/graphs/${promptId}`); url.searchParams.set("normalize", normalize); url.searchParams.set("ci_threshold", String(ciThreshold)); - const graphs = await fetchJson[]>(url.toString()); - return graphs.map((g) => { - // Extract all unique layer names from edges to detect architecture - const layerNames = new Set(); - for (const edge of g.edges) { - layerNames.add(edge.src.split(":")[0]); - layerNames.add(edge.tgt.split(":")[0]); - } - setArchitecture(Array.from(layerNames)); - - const { edgesBySource, edgesByTarget } = buildEdgeIndexes(g.edges); - return { ...g, edgesBySource, edgesByTarget }; - }); + const graphs = await fetchJson[]>(url.toString()); + return graphs.map((g) => hydrateGraph(g)); } export async function tokenizeText(text: string): Promise { @@ -185,15 +289,17 @@ export async function tokenizeText(text: string): Promise { return fetchJson(url.toString(), { method: "POST" }); } -export async function getAllTokens(): Promise { - const response = await fetchJson<{ tokens: TokenInfo[] }>("/api/graphs/tokens"); - return response.tokens; -} - -export async function searchTokens(query: string, limit: number = 10): Promise { +export async function searchTokens( + query: string, + promptId: number, + position: number, + limit: number = 20, +): Promise { const url = apiUrl("/api/graphs/tokens/search"); url.searchParams.set("q", query); url.searchParams.set("limit", String(limit)); - const response = await fetchJson<{ tokens: TokenInfo[] }>(url.toString()); + url.searchParams.set("prompt_id", String(promptId)); + url.searchParams.set("position", String(position)); + const response = await fetchJson<{ tokens: TokenSearchResult[] 
}>(url.toString()); return response.tokens; } diff --git a/spd/app/frontend/src/lib/api/index.ts b/spd/app/frontend/src/lib/api/index.ts index 773663636..d2d810283 100644 --- a/spd/app/frontend/src/lib/api/index.ts +++ b/spd/app/frontend/src/lib/api/index.ts @@ -51,5 +51,8 @@ export * from "./datasetAttributions"; export * from "./intervention"; export * from "./dataset"; export * from "./clusters"; +export * from "./investigations"; export * from "./dataSources"; +export * from "./graphInterp"; export * from "./pretrainInfo"; +export * from "./runRegistry"; diff --git a/spd/app/frontend/src/lib/api/intervention.ts b/spd/app/frontend/src/lib/api/intervention.ts index 689c29cc1..154228181 100644 --- a/spd/app/frontend/src/lib/api/intervention.ts +++ b/spd/app/frontend/src/lib/api/intervention.ts @@ -2,11 +2,7 @@ * API client for /api/intervention endpoints. */ -import type { - ForkedInterventionRunSummary, - InterventionRunSummary, - RunInterventionRequest, -} from "../interventionTypes"; +import type { InterventionRunSummary, RunInterventionRequest } from "../interventionTypes"; export async function runAndSaveIntervention(request: RunInterventionRequest): Promise { const response = await fetch("/api/intervention/run", { @@ -39,30 +35,3 @@ export async function deleteInterventionRun(runId: number): Promise { throw new Error(error.detail || "Failed to delete intervention run"); } } - -export async function forkInterventionRun( - runId: number, - tokenReplacements: [number, number][], - topK: number = 10, -): Promise { - const response = await fetch(`/api/intervention/runs/${runId}/fork`, { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ token_replacements: tokenReplacements, top_k: topK }), - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to fork intervention run"); - } - return (await response.json()) as ForkedInterventionRunSummary; -} - -export async function 
deleteForkedInterventionRun(forkId: number): Promise { - const response = await fetch(`/api/intervention/forks/${forkId}`, { - method: "DELETE", - }); - if (!response.ok) { - const error = await response.json(); - throw new Error(error.detail || "Failed to delete forked intervention run"); - } -} diff --git a/spd/app/frontend/src/lib/api/investigations.ts b/spd/app/frontend/src/lib/api/investigations.ts new file mode 100644 index 000000000..42f1fb1f3 --- /dev/null +++ b/spd/app/frontend/src/lib/api/investigations.ts @@ -0,0 +1,101 @@ +/** + * API client for investigation results. + */ + +export interface InvestigationSummary { + id: string; // inv_id (e.g., "inv-abc12345") + wandb_path: string | null; + prompt: string | null; + created_at: string; + has_research_log: boolean; + has_explanations: boolean; + event_count: number; + last_event_time: string | null; + last_event_message: string | null; + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; // in_progress, completed, inconclusive +} + +export interface EventEntry { + event_type: string; + timestamp: string; + message: string; + details: Record | null; +} + +export interface InvestigationDetail { + id: string; + wandb_path: string | null; + prompt: string | null; + created_at: string; + research_log: string | null; + events: EventEntry[]; + explanations: Record[]; + artifact_ids: string[]; // List of artifact IDs available for this investigation + // Agent-provided summary + title: string | null; + summary: string | null; + status: string | null; +} + +import type { EdgeData, OutputProbability } from "../promptAttributionsTypes"; + +/** Data for a graph artifact (subset of GraphData, self-contained for offline viewing) */ +export interface ArtifactGraphData { + tokens: string[]; + edges: EdgeData[]; + outputProbs: Record; + nodeCiVals: Record; + nodeSubcompActs: Record; + maxAbsAttr: number; + l0_total: number; +} + +export interface GraphArtifact { + type: 
"graph"; + id: string; + caption: string | null; + graph_id: number; + data: ArtifactGraphData; +} + +export interface LaunchResponse { + inv_id: string; + job_id: string; +} + +export async function launchInvestigation(prompt: string): Promise { + const res = await fetch("/api/investigations/launch", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ prompt }), + }); + if (!res.ok) throw new Error(`Failed to launch investigation: ${res.statusText}`); + return res.json(); +} + +export async function listInvestigations(): Promise { + const res = await fetch("/api/investigations"); + if (!res.ok) throw new Error(`Failed to list investigations: ${res.statusText}`); + return res.json(); +} + +export async function getInvestigation(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}`); + if (!res.ok) throw new Error(`Failed to get investigation: ${res.statusText}`); + return res.json(); +} + +export async function listArtifacts(invId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts`); + if (!res.ok) throw new Error(`Failed to list artifacts: ${res.statusText}`); + return res.json(); +} + +export async function getArtifact(invId: string, artifactId: string): Promise { + const res = await fetch(`/api/investigations/${invId}/artifacts/${artifactId}`); + if (!res.ok) throw new Error(`Failed to get artifact: ${res.statusText}`); + return res.json(); +} diff --git a/spd/app/frontend/src/lib/api/pretrainInfo.ts b/spd/app/frontend/src/lib/api/pretrainInfo.ts index 7092c735a..0cd66bd97 100644 --- a/spd/app/frontend/src/lib/api/pretrainInfo.ts +++ b/spd/app/frontend/src/lib/api/pretrainInfo.ts @@ -20,6 +20,7 @@ export type TopologyInfo = { export type PretrainInfoResponse = { model_type: string; summary: string; + dataset_short: string | null; target_model_config: Record | null; pretrain_config: Record | null; pretrain_wandb_path: string | null; diff --git 
a/spd/app/frontend/src/lib/api/prompts.ts b/spd/app/frontend/src/lib/api/prompts.ts index 25763f989..5b5ee276a 100644 --- a/spd/app/frontend/src/lib/api/prompts.ts +++ b/spd/app/frontend/src/lib/api/prompts.ts @@ -14,3 +14,7 @@ export async function createCustomPrompt(text: string): Promise { url.searchParams.set("text", text); return fetchJson(url.toString(), { method: "POST" }); } + +export async function deletePrompt(promptId: number): Promise { + await fetchJson(`/api/prompts/${promptId}`, { method: "DELETE" }); +} diff --git a/spd/app/frontend/src/lib/api/runRegistry.ts b/spd/app/frontend/src/lib/api/runRegistry.ts new file mode 100644 index 000000000..c727f4dcc --- /dev/null +++ b/spd/app/frontend/src/lib/api/runRegistry.ts @@ -0,0 +1,26 @@ +/** + * API client for /api/run_registry endpoint. + */ + +import { fetchJson } from "./index"; + +export type DataAvailability = { + harvest: boolean; + autointerp: boolean; + attributions: boolean; + graph_interp: boolean; +}; + +export type RunInfoResponse = { + wandb_run_id: string; + architecture: string | null; + availability: DataAvailability; +}; + +export async function fetchRunInfo(wandbRunIds: string[]): Promise { + return fetchJson("/api/run_registry", { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(wandbRunIds), + }); +} diff --git a/spd/app/frontend/src/lib/api/runs.ts b/spd/app/frontend/src/lib/api/runs.ts index 1430632a4..d898c8671 100644 --- a/spd/app/frontend/src/lib/api/runs.ts +++ b/spd/app/frontend/src/lib/api/runs.ts @@ -14,6 +14,8 @@ export type LoadedRun = { backend_user: string; dataset_attributions_available: boolean; dataset_search_enabled: boolean; + graph_interp_available: boolean; + autointerp_available: boolean; }; export async function getStatus(): Promise { diff --git a/spd/app/frontend/src/lib/colors.ts b/spd/app/frontend/src/lib/colors.ts index e64cc696d..d15462693 100644 --- a/spd/app/frontend/src/lib/colors.ts +++ 
b/spd/app/frontend/src/lib/colors.ts @@ -7,17 +7,17 @@ */ export const colors = { - // Text - punchy contrast (matches --text-*) - textPrimary: "#111111", - textSecondary: "#555555", - textMuted: "#999999", + // Text - warm navy contrast (matches --text-*) + textPrimary: "#1d272a", + textSecondary: "#646464", + textMuted: "#b4b4b4", // Status colors for edges/data (matches --accent-primary, --status-negative) - positive: "#2563eb", + positive: "#4d65ff", negative: "#dc2626", // RGB components for dynamic opacity - positiveRgb: { r: 37, g: 99, b: 235 }, // blue - matches --accent-primary + positiveRgb: { r: 77, g: 101, b: 255 }, // vibrant blue - matches --accent-primary negativeRgb: { r: 220, g: 38, b: 38 }, // red - matches --status-negative // Output node gradient (green) - matches --status-positive @@ -28,10 +28,10 @@ export const colors = { tokenHighlightOpacity: 0.4, // Node default - nodeDefault: "#6b7280", + nodeDefault: "#8a8780", // Accent (for active states) - matches --accent-primary - accent: "#2563eb", + accent: "#7C4D33", // Set overlap visualization (A/B/intersection) setOverlap: { diff --git a/spd/app/frontend/src/lib/componentKeys.ts b/spd/app/frontend/src/lib/componentKeys.ts new file mode 100644 index 000000000..ff83bda06 --- /dev/null +++ b/spd/app/frontend/src/lib/componentKeys.ts @@ -0,0 +1,17 @@ +/** + * Utilities for component key display (e.g. rendering embed/output keys with token strings). + */ + +export function isTokenNode(key: string): boolean { + const layer = key.split(":")[0]; + return layer === "embed" || layer === "output"; +} + +export function formatComponentKey(key: string, tokenStr: string | null): string { + if (tokenStr && isTokenNode(key)) { + const layer = key.split(":")[0]; + const label = layer === "embed" ? 
"input" : "output"; + return `'${tokenStr}' (${label})`; + } + return key; +} diff --git a/spd/app/frontend/src/lib/displaySettings.svelte.ts b/spd/app/frontend/src/lib/displaySettings.svelte.ts index db3e3f7c9..6998214ee 100644 --- a/spd/app/frontend/src/lib/displaySettings.svelte.ts +++ b/spd/app/frontend/src/lib/displaySettings.svelte.ts @@ -13,6 +13,14 @@ export const NODE_COLOR_MODE_LABELS: Record = { subcomp_act: "Subcomp Act", }; +// Edge variant for attribution graphs +export type EdgeVariant = "signed" | "abs_target"; + +export const EDGE_VARIANT_LABELS: Record = { + signed: "Signed", + abs_target: "Abs Target", +}; + // Example color mode for activation contexts viewer export type ExampleColorMode = "ci" | "component_act" | "both"; @@ -48,6 +56,8 @@ export const displaySettings = $state({ meanCiCutoff: 1e-7, centerOnPeak: false, showAutoInterpPromptButton: false, + curvedEdges: true, + edgeVariant: "signed" as EdgeVariant, }); export function anyCorrelationStatsEnabled() { diff --git a/spd/app/frontend/src/lib/interventionTypes.ts b/spd/app/frontend/src/lib/interventionTypes.ts index c364be243..de8e1af4d 100644 --- a/spd/app/frontend/src/lib/interventionTypes.ts +++ b/spd/app/frontend/src/lib/interventionTypes.ts @@ -1,46 +1,110 @@ /** Types for the intervention forward pass feature */ -export type InterventionNode = { - layer: string; - seq_pos: number; - component_idx: number; -}; +/** Default eval PGD settings (distinct from training PGD which is an optimization regularizer) */ +export const EVAL_PGD_N_STEPS = 4; +export const EVAL_PGD_STEP_SIZE = 1.0; export type TokenPrediction = { token: string; token_id: number; - spd_prob: number; - target_prob: number; + prob: number; logit: number; + target_prob: number; target_logit: number; }; -export type InterventionResponse = { - input_tokens: string[]; - predictions_per_position: TokenPrediction[][]; +export type LabelPredictions = { + position: number; + ci: TokenPrediction; + stochastic: TokenPrediction; 
+ adversarial: TokenPrediction; + ablated: TokenPrediction | null; }; -/** A forked intervention run with modified tokens */ -export type ForkedInterventionRunSummary = { - id: number; - token_replacements: [number, number][]; // [(seq_pos, new_token_id), ...] - result: InterventionResponse; - created_at: string; +export type InterventionResult = { + input_tokens: string[]; + ci: TokenPrediction[][]; + stochastic: TokenPrediction[][]; + adversarial: TokenPrediction[][]; + ablated: TokenPrediction[][] | null; + ci_loss: number; + stochastic_loss: number; + adversarial_loss: number; + ablated_loss: number | null; + label: LabelPredictions | null; }; /** Persisted intervention run from the server */ export type InterventionRunSummary = { id: number; selected_nodes: string[]; // node keys (layer:seq:cIdx) - result: InterventionResponse; + result: InterventionResult; created_at: string; - forked_runs?: ForkedInterventionRunSummary[]; // child runs with modified tokens }; /** Request to run and save an intervention */ export type RunInterventionRequest = { graph_id: number; - text: string; selected_nodes: string[]; - top_k?: number; + nodes_to_ablate?: string[]; + top_k: number; + adv_pgd: { n_steps: number; step_size: number }; +}; + +// --- Frontend-only run lifecycle types --- + +import { SvelteSet } from "svelte/reactivity"; +import { isInterventableNode } from "./promptAttributionsTypes"; + +/** Draft run: cloned from a parent, editable node selection. No forwarded results yet. */ +export type DraftRun = { + kind: "draft"; + parentId: number; + selectedNodes: SvelteSet; +}; + +/** Baked run: forwarded and immutable. Wraps a persisted InterventionRunSummary. 
*/ +export type BakedRun = { + kind: "baked"; + id: number; + selectedNodes: Set; + result: InterventionResult; + createdAt: string; +}; + +export type InterventionRun = DraftRun | BakedRun; + +export type InterventionState = { + runs: InterventionRun[]; + activeIndex: number; }; + +/** Whether a run's selection is editable */ +export function isRunEditable(run: InterventionRun): run is DraftRun { + return run.kind === "draft"; +} + +/** Build initial InterventionState from persisted runs. + * The first persisted run is the base run (all CI > 0 nodes), auto-created during graph computation. */ +export function buildInterventionState(persistedRuns: InterventionRunSummary[]): InterventionState { + if (persistedRuns.length === 0) throw new Error("Graph must have at least one intervention run (the base run)"); + const runs: InterventionRun[] = persistedRuns.map( + (r): BakedRun => ({ + kind: "baked", + id: r.id, + selectedNodes: new Set(r.selected_nodes), + result: r.result, + createdAt: r.created_at, + }), + ); + return { runs, activeIndex: 0 }; +} + +/** Get all interventable node keys with CI > 0 from a nodeCiVals record */ +export function getInterventableNodes(nodeCiVals: Record): Set { + const nodes = new Set(); + for (const [nodeKey, ci] of Object.entries(nodeCiVals)) { + if (isInterventableNode(nodeKey) && ci > 0) nodes.add(nodeKey); + } + return nodes; +} diff --git a/spd/app/frontend/src/lib/layerAliasing.ts b/spd/app/frontend/src/lib/layerAliasing.ts deleted file mode 100644 index 2c5269543..000000000 --- a/spd/app/frontend/src/lib/layerAliasing.ts +++ /dev/null @@ -1,219 +0,0 @@ -/** - * Layer aliasing system - transforms internal module names to human-readable aliases. 
- * - * Formats: - * - Internal: "h.0.mlp.c_fc", "h.1.attn.q_proj" - * - Aliased: "L0.mlp.in", "L1.attn.q" - * - * Handles multiple architectures: - * - GPT-2: c_fc -> mlp.in, down_proj -> mlp.out - * - Llama SwiGLU: gate_proj -> mlp.gate, up_proj -> mlp.up, down_proj -> mlp.down - * - Attention: q_proj -> attn.q, k_proj -> attn.k, v_proj -> attn.v, o_proj -> attn.o - * - Special: lm_head -> W_U, embed/output unchanged - */ - -type Architecture = "gpt2" | "llama" | "unknown"; - -/** Mapping of internal module names to aliases by architecture */ -const ALIASES: Record> = { - gpt2: { - // MLP - c_fc: "in", - down_proj: "out", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - llama: { - // MLP (SwiGLU) - gate_proj: "gate", - up_proj: "up", - down_proj: "down", - // Attention - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, - unknown: { - // Fallback - just do attention mappings - q_proj: "q", - k_proj: "k", - v_proj: "v", - o_proj: "o", - }, -}; - -/** Special layers with fixed display names */ -const SPECIAL_LAYERS: Record = { - lm_head: "W_U", - embed: "embed", - output: "output", -}; - -// Cache for detected architecture from the full model -let cachedArchitecture: Architecture | null = null; - -/** - * Detect architecture from a collection of layer names. - * Llama has gate_proj/up_proj, GPT-2 has c_fc. - * - * This should be called once with all available layer names to establish - * the architecture for the session, ensuring down_proj is aliased correctly. - */ -export function detectArchitectureFromLayers(layers: string[]): Architecture { - const hasLlamaLayers = layers.some((layer) => layer.includes("gate_proj") || layer.includes("up_proj")); - if (hasLlamaLayers) { - return "llama"; - } - - const hasGPT2Layers = layers.some((layer) => layer.includes("c_fc")); - if (hasGPT2Layers) { - return "gpt2"; - } - - return "unknown"; -} - -/** - * Set the architecture for aliasing operations. 
- * Call this when you have access to all layer names (e.g., when loading a graph). - */ -export function setArchitecture(layers: string[]): void { - cachedArchitecture = detectArchitectureFromLayers(layers); -} - -/** - * Detect architecture from layer name. - * Uses cached architecture if available (set via setArchitecture()), - * otherwise falls back to single-layer detection. - * - * Note: down_proj appears in both architectures with different meanings: - * - GPT-2: down_proj -> "out" (second MLP projection) - * - Llama: down_proj -> "down" (third MLP projection after gate/up) - * - * Single-layer detection cannot distinguish these cases reliably. - */ -function detectArchitecture(layer: string): Architecture { - // Use cached architecture if available - if (cachedArchitecture !== null) { - return cachedArchitecture; - } - - // Fallback: single-layer detection (less reliable for down_proj) - if (layer.includes("gate_proj") || layer.includes("up_proj")) { - return "llama"; - } - if (layer.includes("c_fc")) { - return "gpt2"; - } - // down_proj is ambiguous without context, default to GPT-2 - if (layer.includes("down_proj")) { - return "gpt2"; - } - return "unknown"; -} - -/** - * Parse a layer name into components. - * Returns null for special layers (embed, output, lm_head) or unrecognized formats. - */ -function parseLayerName(layer: string): { block: number; moduleType: string; submodule: string } | null { - if (layer in SPECIAL_LAYERS) { - return null; - } - - const match = layer.match(/^h\.(\d+)\.(attn|mlp)\.(\w+)$/); - if (!match) { - return null; - } - - const [, blockStr, moduleType, submodule] = match; - return { - block: parseInt(blockStr), - moduleType, - submodule, - }; -} - -/** - * Transform a layer name to its aliased form. 
- * - * Examples: - * - "h.0.mlp.c_fc" -> "L0.mlp.in" - * - "h.2.attn.q_proj" -> "L2.attn.q" - * - "lm_head" -> "W_U" - * - "embed" -> "embed" - */ -export function getLayerAlias(layer: string): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Get a row label for grouped display in graphs. - * - * @param layer - Internal layer name (e.g., "h.0.mlp.c_fc") - * @param isQkvGroup - Whether this represents a grouped QKV row - * @returns Label (e.g., "L0.mlp.in", "L2.attn.qkv") - * - * @example - * getAliasedRowLabel("h.0.mlp.c_fc") // => "L0.mlp.in" - * getAliasedRowLabel("h.2.attn.q_proj", true) // => "L2.attn.qkv" - */ -export function getAliasedRowLabel(layer: string, isQkvGroup = false): string { - if (layer in SPECIAL_LAYERS) { - return SPECIAL_LAYERS[layer]; - } - - const parsed = parseLayerName(layer); - if (!parsed) { - return layer; - } - - if (isQkvGroup) { - return `L${parsed.block}.${parsed.moduleType}.qkv`; - } - - const arch = detectArchitecture(layer); - const alias = ALIASES[arch][parsed.submodule]; - - if (!alias) { - return `L${parsed.block}.${parsed.moduleType}.${parsed.submodule}`; - } - - return `L${parsed.block}.${parsed.moduleType}.${alias}`; -} - -/** - * Format a node key with aliased layer names. - * - * Node keys are "layer:seq:cIdx" or "layer:cIdx" format. 
- * - * Examples: - * - "h.0.mlp.c_fc:3:5" -> "L0.mlp.in:3:5" - * - "h.1.attn.q_proj:2:10" -> "L1.attn.q:2:10" - */ -export function formatNodeKeyWithAliases(nodeKey: string): string { - const parts = nodeKey.split(":"); - const layer = parts[0]; - const aliasedLayer = getLayerAlias(layer); - return [aliasedLayer, ...parts.slice(1)].join(":"); -} diff --git a/spd/app/frontend/src/lib/promptAttributionsTypes.ts b/spd/app/frontend/src/lib/promptAttributionsTypes.ts index fc705fad0..b40a63641 100644 --- a/spd/app/frontend/src/lib/promptAttributionsTypes.ts +++ b/spd/app/frontend/src/lib/promptAttributionsTypes.ts @@ -20,6 +20,7 @@ export type EdgeAttribution = { key: string; // "layer:seq:cIdx" for prompt or "layer:cIdx" for dataset value: number; // raw attribution value (positive or negative) normalizedMagnitude: number; // |value| / maxAbsValue, for color intensity (0-1) + tokenStr: string | null; // resolved token string for embed/output layers }; export type OutputProbability = { @@ -27,11 +28,20 @@ export type OutputProbability = { logit: number; // CI-masked (SPD model) raw logit target_prob: number; // Target model probability target_logit: number; // Target model raw logit - adv_pgd_prob: number | null; // Adversarial PGD probability - adv_pgd_logit: number | null; // Adversarial PGD raw logit token: string; }; +export type CISnapshot = { + step: number; + total_steps: number; + layers: string[]; + seq_len: number; + initial_alive: number[][]; + current_alive: number[][]; + l0_total: number; + loss: number; +}; + export type GraphType = "standard" | "optimized" | "manual"; export type GraphData = { @@ -41,10 +51,15 @@ export type GraphData = { edges: EdgeData[]; edgesBySource: Map; // nodeKey -> edges where this node is source edgesByTarget: Map; // nodeKey -> edges where this node is target + // Absolute-target variant (∂|y|/∂x · x), null for old graphs + edgesAbs: EdgeData[] | null; + edgesAbsBySource: Map | null; + edgesAbsByTarget: Map | null; 
outputProbs: Record; // key is "seq:cIdx" nodeCiVals: Record; // node key -> CI value (or output prob for output nodes or 1 for wte node) nodeSubcompActs: Record; // node key -> subcomponent activation (v_i^T @ a) maxAbsAttr: number; // max absolute edge value + maxAbsAttrAbs: number | null; // max absolute edge value for abs-target variant maxAbsSubcompAct: number; // max absolute subcomponent activation for normalization l0_total: number; // total active components at current CI threshold optimization?: OptimizationResult; @@ -93,7 +108,15 @@ export type KLLossResult = { position: number; }; -export type LossResult = CELossResult | KLLossResult; +export type LogitLossResult = { + type: "logit"; + coeff: number; + position: number; + label_token: number; + label_str: string; +}; + +export type LossResult = CELossResult | KLLossResult | LogitLossResult; export type OptimizationMetrics = { ci_masked_label_prob: number | null; // Probability of label under CI mask (CE loss only) @@ -102,6 +125,11 @@ export type OptimizationMetrics = { l0_total: number; // Total L0 (active components) }; +export type PgdConfig = { + n_steps: number; + step_size: number; +}; + export type OptimizationResult = { imp_min_coeff: number; steps: number; @@ -110,8 +138,7 @@ export type OptimizationResult = { mask_type: MaskType; loss: LossResult; metrics: OptimizationMetrics; - adv_pgd_n_steps: number | null; - adv_pgd_step_size: number | null; + pgd: PgdConfig | null; }; export type SubcomponentMetadata = { @@ -169,11 +196,33 @@ export type TokenizeResponse = { next_token_probs: (number | null)[]; // Probability of next token (last is null) }; -export type TokenInfo = { +export type TokenSearchResult = { id: number; string: string; + prob: number; }; +/** Select active edge set based on variant preference. Falls back to signed if abs unavailable. 
*/ +export function getActiveEdges( + data: GraphData, + variant: "signed" | "abs_target", +): { edges: EdgeData[]; bySource: Map; byTarget: Map; maxAbsAttr: number } { + if (variant === "abs_target" && data.edgesAbs) { + return { + edges: data.edgesAbs, + bySource: data.edgesAbsBySource!, + byTarget: data.edgesAbsByTarget!, + maxAbsAttr: data.maxAbsAttrAbs || 1, + }; + } + return { + edges: data.edges, + bySource: data.edgesBySource, + byTarget: data.edgesByTarget, + maxAbsAttr: data.maxAbsAttr || 1, + }; +} + // Client-side computed types export type NodePosition = { @@ -233,7 +282,7 @@ export function formatNodeKeyForDisplay(nodeKey: string, displayNames: Record>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); // Current coords being loaded/displayed (for interpretation lookup) let currentCoords = $state(null); @@ -132,20 +134,40 @@ export function useComponentData() { datasetAttributions = { status: "loaded", data: null }; } - // Fetch interpretation detail (404 = no interpretation for this component) - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status === 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; 
interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail (skip if not available for this run) + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } /** @@ -159,6 +181,7 @@ export function useComponentData() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache - reactive to both coords and cache @@ -212,6 +235,9 @@ export function useComponentData() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts index f32dab70a..d76c5da9e 100644 --- a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts @@ -6,7 +6,7 @@ * examples (200). Dataset attributions and interpretation detail are on-demand. 
*/ -import { getContext } from "svelte"; +import { getContext, untrack } from "svelte"; import type { Loadable } from "."; import { ApiError, @@ -14,10 +14,11 @@ import { getComponentAttributions, getComponentCorrelations, getComponentTokenStats, + getGraphInterpComponentDetail, getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; +import type { AllMetricAttributions, GraphInterpComponentDetail, InterpretationDetail } from "./api"; import type { SubcomponentCorrelationsResponse, SubcomponentActivationContexts, @@ -29,7 +30,7 @@ const DATASET_ATTRIBUTIONS_TOP_K = 20; /** Fetch more activation examples in background after initial cached load */ const ACTIVATION_EXAMPLES_FULL_LIMIT = 200; -export type { ComponentAttributions as DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; @@ -39,8 +40,9 @@ export function useComponentDataExpectCached() { let componentDetail = $state>({ status: "uninitialized" }); let correlations = $state>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); + let graphInterpDetail = $state>({ status: "uninitialized" }); let currentCoords = $state(null); let requestId = 0; @@ -87,21 +89,41 @@ export function useComponentDataExpectCached() { datasetAttributions = { status: "loaded", data: null }; } - // Fetch interpretation detail on-demand (not cached) - interpretationDetail = { status: "loading" }; - getInterpretationDetail(layer, cIdx) - .then((data) => { - if (isStale()) return; - interpretationDetail = { status: "loaded", data }; - }) - .catch((error) => { - if (isStale()) return; - if (error instanceof ApiError && error.status 
=== 404) { - interpretationDetail = { status: "loaded", data: null }; - } else { + const interpState = untrack(() => runState.getInterpretation(`${layer}:${cIdx}`)); + if (interpState.status === "loaded" && interpState.data.status !== "none") { + interpretationDetail = { status: "loading" }; + getInterpretationDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + interpretationDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; interpretationDetail = { status: "error", error }; - } - }); + }); + } else { + interpretationDetail = { status: "loaded", data: null }; + } + + // Fetch graph interp detail + if (runState.graphInterpAvailable) { + graphInterpDetail = { status: "loading" }; + getGraphInterpComponentDetail(layer, cIdx) + .then((data) => { + if (isStale()) return; + graphInterpDetail = { status: "loaded", data }; + }) + .catch((error) => { + if (isStale()) return; + if (error instanceof ApiError && error.status === 404) { + graphInterpDetail = { status: "loaded", data: null }; + } else { + graphInterpDetail = { status: "error", error }; + } + }); + } else { + graphInterpDetail = { status: "loaded", data: null }; + } } function load(layer: string, cIdx: number) { @@ -144,6 +166,7 @@ export function useComponentDataExpectCached() { tokenStats = { status: "uninitialized" }; datasetAttributions = { status: "uninitialized" }; interpretationDetail = { status: "uninitialized" }; + graphInterpDetail = { status: "uninitialized" }; } // Interpretation is derived from the global cache @@ -197,6 +220,9 @@ export function useComponentDataExpectCached() { get interpretationDetail() { return interpretationDetail; }, + get graphInterpDetail() { + return graphInterpDetail; + }, load, reset, generateInterpretation, diff --git a/spd/app/frontend/src/lib/useRun.svelte.ts b/spd/app/frontend/src/lib/useRun.svelte.ts index de6d20c7d..1cfc3cca6 100644 --- a/spd/app/frontend/src/lib/useRun.svelte.ts +++ 
b/spd/app/frontend/src/lib/useRun.svelte.ts @@ -7,13 +7,8 @@ import type { Loadable } from "."; import * as api from "./api"; -import type { LoadedRun as RunData, InterpretationHeadline } from "./api"; -import type { - PromptPreview, - SubcomponentActivationContexts, - TokenInfo, - SubcomponentMetadata, -} from "./promptAttributionsTypes"; +import type { LoadedRun as RunData, InterpretationHeadline, GraphInterpHeadline } from "./api"; +import type { PromptPreview, SubcomponentActivationContexts, SubcomponentMetadata } from "./promptAttributionsTypes"; /** Maps component keys to cluster IDs. Singletons (unclustered components) have null values. */ export type ClusterMappingData = Record; @@ -46,17 +41,15 @@ export function useRun() { /** Intruder eval scores keyed by component key */ let intruderScores = $state>>({ status: "uninitialized" }); + /** Graph interp labels keyed by component key (layer:cIdx) */ + let graphInterpLabels = $state>>({ status: "uninitialized" }); + /** Cluster mapping for the current run */ let clusterMapping = $state(null); /** Available prompts for the current run */ let prompts = $state>({ status: "uninitialized" }); - /** All tokens in the tokenizer for the current run */ - let allTokens = $state>({ status: "uninitialized" }); - - /** Model topology info for frontend layout */ - /** Activation contexts summary (null = harvest not available) */ let activationContextsSummary = $state | null>>({ status: "uninitialized", @@ -68,9 +61,9 @@ export function useRun() { /** Reset all run-scoped state */ function resetRunScopedState() { prompts = { status: "uninitialized" }; - allTokens = { status: "uninitialized" }; interpretations = { status: "uninitialized" }; intruderScores = { status: "uninitialized" }; + graphInterpLabels = { status: "uninitialized" }; activationContextsSummary = { status: "uninitialized" }; _componentDetailsCache = {}; clusterMapping = null; @@ -88,6 +81,9 @@ export function useRun() { api.getIntruderScores() .then((data) => 
(intruderScores = { status: "loaded", data })) .catch((error) => (intruderScores = { status: "error", error })); + api.getAllGraphInterpLabels() + .then((data) => (graphInterpLabels = { status: "loaded", data })) + .catch((error) => (graphInterpLabels = { status: "error", error })); api.getAllInterpretations() .then((i) => { interpretations = { @@ -106,14 +102,6 @@ export function useRun() { .catch((error) => (interpretations = { status: "error", error })); } - /** Fetch tokens - must complete before run is considered loaded */ - async function fetchTokens(): Promise { - allTokens = { status: "loading" }; - const tokens = await api.getAllTokens(); - allTokens = { status: "loaded", data: tokens }; - return tokens; - } - async function loadRun(wandbPath: string, contextLength: number) { run = { status: "loading" }; try { @@ -122,8 +110,6 @@ export function useRun() { if (status) { run = { status: "loaded", data: status }; fetchRunScopedData(); - // Fetch tokens in background (no longer blocks UI - used only by token search) - fetchTokens(); } else { run = { status: "error", error: "Failed to load run" }; } @@ -142,10 +128,6 @@ export function useRun() { try { const status = await api.getStatus(); if (status) { - // Fetch tokens and model info if we don't have them (e.g., page refresh) - if (allTokens.status === "uninitialized") { - await fetchTokens(); - } run = { status: "loaded", data: status }; // Fetch other run-scoped data if we don't have it if (interpretations.status === "uninitialized") { @@ -230,6 +212,11 @@ export function useRun() { return clusterMapping?.data[key] ?? null; } + function getGraphInterpLabel(componentKey: string): GraphInterpHeadline | null { + if (graphInterpLabels.status !== "loaded") return null; + return graphInterpLabels.data[componentKey] ?? 
null; + } + return { get run() { return run; @@ -237,21 +224,27 @@ export function useRun() { get interpretations() { return interpretations; }, + get graphInterpLabels() { + return graphInterpLabels; + }, get clusterMapping() { return clusterMapping; }, get prompts() { return prompts; }, - get allTokens() { - return allTokens; - }, get activationContextsSummary() { return activationContextsSummary; }, get datasetAttributionsAvailable() { return run.status === "loaded" && run.data.dataset_attributions_available; }, + get graphInterpAvailable() { + return run.status === "loaded" && run.data.graph_interp_available; + }, + get autoInterpAvailable() { + return run.status === "loaded" && run.data.autointerp_available; + }, loadRun, clearRun, syncStatus, @@ -259,6 +252,7 @@ export function useRun() { getInterpretation, setInterpretation, getIntruderScore, + getGraphInterpLabel, getActivationContextDetail, loadActivationContextsSummary, setClusterMapping, diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts index fc72bbc92..a08d086fb 100644 --- a/spd/app/frontend/vite.config.ts +++ b/spd/app/frontend/vite.config.ts @@ -9,6 +9,7 @@ const backendUrl = process.env.BACKEND_URL || "http://localhost:8000"; export default defineConfig({ plugins: [svelte()], server: { + hmr: false, proxy: { "/api": { target: backendUrl, diff --git a/spd/app/run_app.py b/spd/app/run_app.py index 6aff0ce4c..c61174d1e 100755 --- a/spd/app/run_app.py +++ b/spd/app/run_app.py @@ -303,7 +303,7 @@ def spawn_frontend( return proc def monitor_child_liveness(self) -> None: - log_lines_to_show = 5 + log_lines_to_show = 20 prev_lines: list[str] = [] while True: diff --git a/spd/autointerp/config.py b/spd/autointerp/config.py index 1bb9db6ee..4d60a1705 100644 --- a/spd/autointerp/config.py +++ b/spd/autointerp/config.py @@ -33,8 +33,8 @@ class CompactSkepticalConfig(BaseConfig): include_pmi: bool = True include_spd_context: bool = True include_dataset_description: bool = True - 
label_max_words: int = 5 - forbidden_words: list[str] = FORBIDDEN_WORDS_DEFAULT + label_max_words: int = 8 + forbidden_words: list[str] | None = None class DualViewConfig(BaseConfig): @@ -51,7 +51,7 @@ class DualViewConfig(BaseConfig): include_pmi: bool = True include_dataset_description: bool = True label_max_words: int = 8 - forbidden_words: list[str] = FORBIDDEN_WORDS_DEFAULT + forbidden_words: list[str] | None = None StrategyConfig = CompactSkepticalConfig | DualViewConfig diff --git a/spd/autointerp/db.py b/spd/autointerp/db.py index 66d681333..72cc7bca9 100644 --- a/spd/autointerp/db.py +++ b/spd/autointerp/db.py @@ -31,6 +31,8 @@ ); """ +DONE_MARKER = ".done" + class InterpDB: def __init__(self, db_path: Path, readonly: bool = False) -> None: @@ -41,8 +43,12 @@ def __init__(self, db_path: Path, readonly: bool = False) -> None: else: self._conn = sqlite3.connect(str(db_path), check_same_thread=False) self._conn.executescript(_SCHEMA) + self._db_path = db_path self._conn.row_factory = sqlite3.Row + def mark_done(self) -> None: + (self._db_path.parent / DONE_MARKER).touch() + def save_interpretation(self, result: InterpretationResult) -> None: self._conn.execute( "INSERT OR REPLACE INTO interpretations VALUES (?, ?, ?, ?, ?, ?)", diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index 89e992902..94ea692d4 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -95,20 +95,18 @@ def run_interpret( db_path: Path, tokenizer_name: str, ) -> list[InterpretationResult]: - components = harvest.get_all_components() + summary = harvest.get_summary() + logger.info(f"Loaded summary for {len(summary)} components") token_stats = harvest.get_token_stats() assert token_stats is not None, "token_stats.pt not found. Run harvest first." app_tok = AppTokenizer.from_pretrained(tokenizer_name) - # Sort by firing density descending, just as an easy proxy for doing the most useful work first. 
- # NOTE: this doesn't necessarily align with mean causal importance as we use for sorting in the - # app, but it's a close enough proxy. - eligible = sorted(components, key=lambda c: c.firing_density, reverse=True) + eligible_keys = sorted(summary, key=lambda k: summary[k].firing_density, reverse=True) if limit is not None: - eligible = eligible[:limit] + eligible_keys = eligible_keys[:limit] async def _run() -> list[InterpretationResult]: db = InterpDB(db_path) @@ -118,19 +116,17 @@ async def _run() -> list[InterpretationResult]: if completed: logger.info(f"Resuming: {len(completed)} already completed") - remaining = [c for c in eligible if c.component_key not in completed] - logger.info(f"Interpreting {len(remaining)} components") + remaining_keys = [k for k in eligible_keys if k not in completed] + logger.info(f"Interpreting {len(remaining_keys)} components") schema = INTERPRETATION_SCHEMA def build_jobs() -> Iterable[LLMJob]: - for component in remaining: - input_stats = get_input_token_stats( - token_stats, component.component_key, app_tok, top_k=20 - ) - output_stats = get_output_token_stats( - token_stats, component.component_key, app_tok, top_k=50 - ) + for key in remaining_keys: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest" + input_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) + output_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) assert input_stats is not None assert output_stats is not None prompt = format_prompt( @@ -141,7 +137,7 @@ def build_jobs() -> Iterable[LLMJob]: input_token_stats=input_stats, output_token_stats=output_stats, ) - yield LLMJob(prompt=prompt, schema=schema, key=component.component_key) + yield LLMJob(prompt=prompt, schema=schema, key=key) results: list[InterpretationResult] = [] n_errors = 0 @@ -156,7 +152,7 @@ def build_jobs() -> Iterable[LLMJob]: max_requests_per_minute=max_requests_per_minute, 
cost_limit_usd=cost_limit_usd, response_schema=schema, - n_total=len(remaining), + n_total=len(remaining_keys), ): match outcome: case LLMResult(job=job, parsed=parsed, raw=raw): @@ -187,10 +183,11 @@ def build_jobs() -> Iterable[LLMJob]: # 10 is a magic number - just trying to avoid low sample size causing this to false alarm if error_rate > 0.2 and n_errors > 10: raise RuntimeError( - f"Error rate {error_rate:.0%} ({n_errors}/{len(remaining)}) exceeds 20% threshold" + f"Error rate {error_rate:.0%} ({n_errors}/{len(remaining_keys)}) exceeds 20% threshold" ) finally: + db.mark_done() db.close() logger.info(f"Completed {len(results)} interpretations -> {db_path}") diff --git a/spd/autointerp/llm_api.py b/spd/autointerp/llm_api.py index c58170b97..3464e9eb6 100644 --- a/spd/autointerp/llm_api.py +++ b/spd/autointerp/llm_api.py @@ -91,7 +91,7 @@ class LLMError: @dataclass -class _CostTracker: +class CostTracker: input_tokens: int = 0 output_tokens: int = 0 input_price_per_token: float = 0.0 @@ -167,9 +167,6 @@ def _get_retry_after(e: Exception) -> float | None: # --------------------------------------------------------------------------- -# TODO(oli) check this merge - - async def map_llm_calls( openrouter_api_key: str, model: str, @@ -181,6 +178,7 @@ async def map_llm_calls( cost_limit_usd: float | None, response_schema: dict[str, Any], n_total: int | None = None, + cost_tracker: CostTracker | None = None, ) -> AsyncGenerator[LLMResult | LLMError]: """Fan out LLM calls concurrently, yielding results as they complete. @@ -190,17 +188,24 @@ async def map_llm_calls( Jobs can be a lazy iterable (e.g. a generator). Prompt building in the generator body naturally interleaves with async HTTP calls. + + Pass a shared CostTracker to accumulate costs across multiple calls. 
""" if n_total is None and isinstance(jobs, Sized): n_total = len(jobs) async with OpenRouter(api_key=openrouter_api_key) as api: input_price, output_price = await _get_model_pricing(api, model) - cost = _CostTracker( - input_price_per_token=input_price, - output_price_per_token=output_price, - limit_usd=cost_limit_usd, - ) + if cost_tracker is not None: + cost = cost_tracker + cost.input_price_per_token = input_price + cost.output_price_per_token = output_price + else: + cost = CostTracker( + input_price_per_token=input_price, + output_price_per_token=output_price, + limit_usd=cost_limit_usd, + ) rate_limiter = AsyncLimiter(max_rate=max_requests_per_minute, time_period=60) backoff = _GlobalBackoff() reasoning = Reasoning(effort=reasoning_effort) @@ -262,7 +267,6 @@ async def chat(prompt: str, context_label: str) -> str: raise RuntimeError(f"Max retries exceeded for {context_label}: {last_error}") queue: asyncio.Queue[LLMResult | LLMError | None] = asyncio.Queue() - semaphore = asyncio.Semaphore(max_concurrent) n_done = 0 budget_exceeded = False @@ -272,45 +276,57 @@ async def process_one(job: LLMJob) -> None: if budget_exceeded: return - async with semaphore: - try: - raw = "" - parsed = None - for attempt in range(_JSON_PARSE_RETRIES): - raw = await chat(job.prompt, job.key) - try: - parsed = json.loads(raw) - break - except json.JSONDecodeError: - if attempt == _JSON_PARSE_RETRIES - 1: - raise - logger.warning( - f"{job.key}: invalid JSON " - f"(attempt {attempt + 1}/{_JSON_PARSE_RETRIES}), retrying" - ) - assert parsed is not None - await queue.put(LLMResult(job=job, parsed=parsed, raw=raw)) - except _BudgetExceededError: - budget_exceeded = True - return - except Exception as e: - await queue.put(LLMError(job=job, error=e)) - n_done += 1 + try: + raw = "" + parsed = None + for attempt in range(_JSON_PARSE_RETRIES): + raw = await chat(job.prompt, job.key) + try: + parsed = json.loads(raw) + break + except json.JSONDecodeError: + if attempt == 
_JSON_PARSE_RETRIES - 1: + raise + logger.warning( + f"{job.key}: invalid JSON " + f"(attempt {attempt + 1}/{_JSON_PARSE_RETRIES}), retrying" + ) + assert parsed is not None + await queue.put(LLMResult(job=job, parsed=parsed, raw=raw)) + except _BudgetExceededError: + budget_exceeded = True + return + except Exception as e: + await queue.put(LLMError(job=job, error=e)) + n_done += 1 total_str = f"/{n_total}" if n_total is not None else "" - if n_done % 100 == 0 or n_done == n_total: + if n_done == 1 or n_done % 10 == 0 or n_done == n_total: logger.info( f"[{n_done}{total_str}] ${cost.cost_usd():.2f} " f"({cost.input_tokens:,} in, {cost.output_tokens:,} out)" ) async def run_all() -> None: - tasks = [asyncio.create_task(process_one(job)) for job in jobs] - if not tasks: + job_queue: asyncio.Queue[LLMJob | None] = asyncio.Queue(maxsize=max_concurrent) + + async def worker() -> None: + while (job := await job_queue.get()) is not None: + await process_one(job) + + workers = [asyncio.create_task(worker()) for _ in range(max_concurrent)] + try: + for n_queued, job in enumerate(jobs, 1): + if budget_exceeded: + break + await job_queue.put(job) + if n_queued % 500 == 0: + logger.info(f"Queued {n_queued} jobs") + for _ in workers: + await job_queue.put(None) + await asyncio.gather(*workers) + finally: await queue.put(None) - return - await asyncio.gather(*tasks) - await queue.put(None) task = asyncio.create_task(run_all()) try: diff --git a/spd/autointerp/prompt_helpers.py b/spd/autointerp/prompt_helpers.py new file mode 100644 index 000000000..c91f0fe67 --- /dev/null +++ b/spd/autointerp/prompt_helpers.py @@ -0,0 +1,156 @@ +"""Shared prompt-building helpers for autointerp and graph interpretation. + +Pure functions for formatting component data into LLM prompt sections. 
+""" + +import re + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.utils import delimit_tokens +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData + +DATASET_DESCRIPTIONS: dict[str, str] = { + "SimpleStories/SimpleStories": ( + "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " + "Simple vocabulary, common narrative elements." + ), +} + +WEIGHT_NAMES: dict[str, str] = { + "attn.q": "attention query projection", + "attn.k": "attention key projection", + "attn.v": "attention value projection", + "attn.o": "attention output projection", + "mlp.up": "MLP up-projection", + "mlp.down": "MLP down-projection", + "glu.up": "GLU up-projection", + "glu.down": "GLU down-projection", + "glu.gate": "GLU gate projection", +} + +_ORDINALS = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"] + + +def ordinal(n: int) -> str: + if 1 <= n <= len(_ORDINALS): + return _ORDINALS[n - 1] + return f"{n}th" + + +def human_layer_desc(canonical: str, n_blocks: int) -> str: + """Convert canonical layer string to human-readable description. + + '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks' + '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks' + """ + m = re.match(r"(\d+)\.(.*)", canonical) + if not m: + return canonical + layer_idx = int(m.group(1)) + weight_key = m.group(2) + weight_name = WEIGHT_NAMES.get(weight_key, weight_key) + return f"{weight_name} in the {ordinal(layer_idx + 1)} of {n_blocks} blocks" + + +def layer_position_note(canonical: str, n_blocks: int) -> str: + """Brief note about what layer position means for interpretation.""" + m = re.match(r"(\d+)\.", canonical) + if not m: + return "" + layer_idx = int(m.group(1)) + if layer_idx == n_blocks - 1: + return "This is in the final block, so its output directly influences token predictions." 
+ remaining = n_blocks - 1 - layer_idx + return ( + f"This is {remaining} block{'s' if remaining > 1 else ''} from the output, " + f"so its effect on token predictions is indirect — filtered through later layers." + ) + + +def density_note(firing_density: float) -> str: + if firing_density > 0.15: + return ( + "This is a high-density component (fires frequently). " + "High-density components often act as broad biases rather than selective features." + ) + if firing_density < 0.005: + return "This is a very sparse component, likely highly specific." + return "" + + +def build_output_section( + output_stats: TokenPRLift, + output_pmi: list[tuple[str, float]] | None, +) -> str: + section = "" + + if output_pmi: + section += ( + "**Output PMI (pointwise mutual information, in nats: how much more likely " + "a token is to be produced when this component fires, vs its base rate. " + "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n" + ) + for tok, pmi in output_pmi[:10]: + section += f"- {repr(tok)}: {pmi:.2f}\n" + + if output_stats.top_precision: + section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n" + for tok, prec in output_stats.top_precision[:10]: + section += f"- {repr(tok)}: {prec * 100:.0f}%\n" + + return section + + +def build_input_section( + input_stats: TokenPRLift, + input_pmi: list[tuple[str, float]] | None, +) -> str: + section = "" + + if input_pmi: + section += "**Input PMI (same metric as above, for input tokens):**\n" + for tok, pmi in input_pmi[:6]: + section += f"- {repr(tok)}: {pmi:.2f}\n" + + if input_stats.top_precision: + section += "\n**Input precision — probability the component fires given the current token is X:**\n" + for tok, prec in input_stats.top_precision[:8]: + section += f"- {repr(tok)}: {prec * 100:.0f}%\n" + + return section + + +def build_fires_on_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> 
str: + section = "" + examples = component.activation_examples[:max_examples] + + for i, ex in enumerate(examples): + if any(ex.firings): + spans = app_tok.get_spans(ex.token_ids) + tokens = list(zip(spans, ex.firings, strict=True)) + section += f"{i + 1}. {delimit_tokens(tokens)}\n" + + return section + + +def build_says_examples( + component: ComponentData, + app_tok: AppTokenizer, + max_examples: int, +) -> str: + section = "" + examples = component.activation_examples[:max_examples] + + for i, ex in enumerate(examples): + if any(ex.firings): + spans = app_tok.get_spans(ex.token_ids) + shifted_firings = [False] + ex.firings[:-1] + tokens = list(zip(spans, shifted_firings, strict=True)) + section += f"{i + 1}. {delimit_tokens(tokens)}\n" + + return section diff --git a/spd/autointerp/repo.py b/spd/autointerp/repo.py index cae089059..75c3fec29 100644 --- a/spd/autointerp/repo.py +++ b/spd/autointerp/repo.py @@ -12,8 +12,9 @@ import yaml -from spd.autointerp.db import InterpDB +from spd.autointerp.db import DONE_MARKER, InterpDB from spd.autointerp.schemas import InterpretationResult, get_autointerp_dir +from spd.log import logger class InterpRepo: @@ -29,25 +30,30 @@ def __init__(self, db: InterpDB, subrun_dir: Path, run_id: str) -> None: self.run_id = run_id @classmethod - def _find_latest_subrun_dir(cls, run_id: str) -> Path | None: + def _find_latest_done_subrun_dir(cls, run_id: str) -> Path | None: autointerp_dir = get_autointerp_dir(run_id) if not autointerp_dir.exists(): return None candidates = sorted( - [d for d in autointerp_dir.iterdir() if d.is_dir() and d.name.startswith("a-")], + [ + d + for d in autointerp_dir.iterdir() + if d.is_dir() and d.name.startswith("a-") and (d / DONE_MARKER).exists() + ], key=lambda d: d.name, ) return candidates[-1] if candidates else None @classmethod def open(cls, run_id: str) -> "InterpRepo | None": - """Open autointerp data for a run. 
Returns None if no autointerp data exists.""" - subrun_dir = cls._find_latest_subrun_dir(run_id) + """Open autointerp data for a run. Returns None if no completed autointerp data exists.""" + subrun_dir = cls._find_latest_done_subrun_dir(run_id) if subrun_dir is None: return None db_path = subrun_dir / "interp.db" if not db_path.exists(): return None + logger.info(f"Opening autointerp data for {run_id} from {subrun_dir}") return cls( db=InterpDB(db_path, readonly=True), subrun_dir=subrun_dir, diff --git a/spd/autointerp/scripts/run_interpret.py b/spd/autointerp/scripts/run_interpret.py index 060f27d74..da056dc35 100644 --- a/spd/autointerp/scripts/run_interpret.py +++ b/spd/autointerp/scripts/run_interpret.py @@ -14,7 +14,7 @@ from spd.adapters import adapter_from_id from spd.autointerp.config import AutointerpConfig from spd.autointerp.interpret import run_interpret -from spd.autointerp.schemas import get_autointerp_subrun_dir +from spd.autointerp.schemas import get_autointerp_dir, get_autointerp_subrun_dir from spd.harvest.repo import HarvestRepo from spd.log import logger @@ -23,6 +23,7 @@ def main( decomposition_id: str, config_json: dict[str, Any], harvest_subrun_id: str | None = None, + autointerp_subrun_id: str | None = None, ) -> None: assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" interp_config = AutointerpConfig.model_validate(config_json) @@ -38,10 +39,14 @@ def main( if harvest is None: raise ValueError(f"No harvest data found for {decomposition_id}") - autointerp_run_id = "a-" + datetime.now().strftime("%Y%m%d_%H%M%S") - - subrun_dir = get_autointerp_subrun_dir(decomposition_id, autointerp_run_id) - subrun_dir.mkdir(parents=True, exist_ok=True) + if autointerp_subrun_id is not None: + subrun_dir = get_autointerp_dir(decomposition_id) / autointerp_subrun_id + assert subrun_dir.exists(), f"Subrun dir not found: {subrun_dir}" + logger.info(f"Resuming existing subrun: {autointerp_subrun_id}") + else: + 
autointerp_subrun_id = "a-" + datetime.now().strftime("%Y%m%d_%H%M%S") + subrun_dir = get_autointerp_subrun_dir(decomposition_id, autointerp_subrun_id) + subrun_dir.mkdir(parents=True, exist_ok=True) # Save config for reproducibility interp_config.to_file(subrun_dir / "config.yaml") @@ -72,6 +77,7 @@ def get_command( decomposition_id: str, config: AutointerpConfig, harvest_subrun_id: str | None = None, + autointerp_subrun_id: str | None = None, ) -> str: config_json = config.model_dump_json(exclude_none=True) cmd = ( @@ -81,6 +87,8 @@ def get_command( ) if harvest_subrun_id is not None: cmd += f"--harvest_subrun_id {harvest_subrun_id} " + if autointerp_subrun_id is not None: + cmd += f"--autointerp_subrun_id {autointerp_subrun_id} " return cmd diff --git a/spd/autointerp/strategies/compact_skeptical.py b/spd/autointerp/strategies/compact_skeptical.py index 76d36ba8a..db826c342 100644 --- a/spd/autointerp/strategies/compact_skeptical.py +++ b/spd/autointerp/strategies/compact_skeptical.py @@ -16,6 +16,10 @@ "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " "Simple vocabulary, common narrative elements." ), + "danbraunai/pile-uncopyrighted-tok-shuffled": ( + "The Pile (uncopyrighted subset): diverse English text from books, " + "academic papers, code, web pages, and other sources." 
+ ), } SPD_CONTEXT = ( diff --git a/spd/autointerp/strategies/dual_view.py b/spd/autointerp/strategies/dual_view.py index 430405e41..1206c73f5 100644 --- a/spd/autointerp/strategies/dual_view.py +++ b/spd/autointerp/strategies/dual_view.py @@ -7,83 +7,22 @@ - Task framing asks for functional description, not detection label """ -import re - from spd.app.backend.app_tokenizer import AppTokenizer -from spd.app.backend.utils import delimit_tokens from spd.autointerp.config import DualViewConfig +from spd.autointerp.prompt_helpers import ( + DATASET_DESCRIPTIONS, + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) from spd.autointerp.schemas import ModelMetadata from spd.harvest.analysis import TokenPRLift from spd.harvest.schemas import ComponentData -DATASET_DESCRIPTIONS: dict[str, str] = { - "SimpleStories/SimpleStories": ( - "SimpleStories: 2M+ short stories (200-350 words), grade 1-8 reading level. " - "Simple vocabulary, common narrative elements." - ), -} - -WEIGHT_NAMES: dict[str, str] = { - "attn.q": "attention query projection", - "attn.k": "attention key projection", - "attn.v": "attention value projection", - "attn.o": "attention output projection", - "mlp.up": "MLP up-projection", - "mlp.down": "MLP down-projection", - "glu.up": "GLU up-projection", - "glu.down": "GLU down-projection", - "glu.gate": "GLU gate projection", -} - -_ORDINALS = ["1st", "2nd", "3rd", "4th", "5th", "6th", "7th", "8th"] - - -def _ordinal(n: int) -> str: - if 1 <= n <= len(_ORDINALS): - return _ORDINALS[n - 1] - return f"{n}th" - - -def _human_layer_desc(canonical: str, n_blocks: int) -> str: - """Convert canonical layer string to human-readable description. 
- - '0.mlp.up' -> 'MLP up-projection in the 1st of 4 blocks' - '1.attn.q' -> 'attention query projection in the 2nd of 4 blocks' - """ - m = re.match(r"(\d+)\.(.*)", canonical) - if not m: - return canonical - layer_idx = int(m.group(1)) - weight_key = m.group(2) - weight_name = WEIGHT_NAMES.get(weight_key, weight_key) - return f"{weight_name} in the {_ordinal(layer_idx + 1)} of {n_blocks} blocks" - - -def _layer_position_note(canonical: str, n_blocks: int) -> str: - """Brief note about what layer position means for interpretation.""" - m = re.match(r"(\d+)\.", canonical) - if not m: - return "" - layer_idx = int(m.group(1)) - if layer_idx == n_blocks - 1: - return "This is in the final block, so its output directly influences token predictions." - remaining = n_blocks - 1 - layer_idx - return ( - f"This is {remaining} block{'s' if remaining > 1 else ''} from the output, " - f"so its effect on token predictions is indirect — filtered through later layers." - ) - - -def _density_note(firing_density: float) -> str: - if firing_density > 0.15: - return ( - "This is a high-density component (fires frequently). " - "High-density components often act as broad biases rather than selective features." - ) - if firing_density < 0.005: - return "This is a very sparse component, likely highly specific." 
- return "" - def format_prompt( config: DualViewConfig, @@ -108,10 +47,10 @@ def format_prompt( else None ) - output_section = _build_output_section(output_token_stats, output_pmi) - input_section = _build_input_section(input_token_stats, input_pmi) - fires_on_examples = _build_fires_on_examples(component, app_tok, config.max_examples) - says_examples = _build_says_examples(component, app_tok, config.max_examples) + output_section = build_output_section(output_token_stats, output_pmi) + input_section = build_input_section(input_token_stats, input_pmi) + fires_on_examples = build_fires_on_examples(component, app_tok, config.max_examples) + says_examples = build_says_examples(component, app_tok, config.max_examples) if component.firing_density > 0.0: rate_str = f"~1 in {int(1 / component.firing_density)} tokens" @@ -119,11 +58,11 @@ def format_prompt( rate_str = "extremely rare" canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) - layer_desc = _human_layer_desc(canonical, model_metadata.n_blocks) - position_note = _layer_position_note(canonical, model_metadata.n_blocks) - density_note = _density_note(component.firing_density) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) - context_notes = " ".join(filter(None, [position_note, density_note])) + context_notes = " ".join(filter(None, [position_note, dens_note])) dataset_line = "" if config.include_dataset_description: @@ -183,85 +122,3 @@ def format_prompt( Say "unclear" if the evidence is too weak or diffuse. {forbidden_sentence}Lowercase only. 
""" - - -def _build_output_section( - output_stats: TokenPRLift, - output_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - - if output_pmi: - section += ( - "**Output PMI (pointwise mutual information, in nats: how much more likely " - "a token is to be produced when this component fires, vs its base rate. " - "0 = no association, 1 = ~3x more likely, 2 = ~7x, 3 = ~20x):**\n" - ) - for tok, pmi in output_pmi[:10]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - if output_stats.top_precision: - section += "\n**Output precision — of all probability mass for token X, what fraction is at positions where this component fires?**\n" - for tok, prec in output_stats.top_precision[:10]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section - - -def _build_input_section( - input_stats: TokenPRLift, - input_pmi: list[tuple[str, float]] | None, -) -> str: - section = "" - - if input_pmi: - section += "**Input PMI (same metric as above, for input tokens):**\n" - for tok, pmi in input_pmi[:6]: - section += f"- {repr(tok)}: {pmi:.2f}\n" - - if input_stats.top_recall: - section += "\n**Input recall — most common tokens when the component fires:**\n" - for tok, recall in input_stats.top_recall[:8]: - section += f"- {repr(tok)}: {recall * 100:.0f}%\n" - - if input_stats.top_precision: - section += "\n**Input precision — probability the component fires given the current token is X:**\n" - for tok, prec in input_stats.top_precision[:8]: - section += f"- {repr(tok)}: {prec * 100:.0f}%\n" - - return section - - -def _build_fires_on_examples( - component: ComponentData, - app_tok: AppTokenizer, - max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - tokens = list(zip(spans, ex.firings, strict=True)) - section += f"{i + 1}. 
{delimit_tokens(tokens)}\n" - - return section - - -def _build_says_examples( - component: ComponentData, - app_tok: AppTokenizer, - max_examples: int, -) -> str: - section = "" - examples = component.activation_examples[:max_examples] - - for i, ex in enumerate(examples): - if any(ex.firings): - spans = app_tok.get_spans(ex.token_ids) - shifted_firings = [False] + ex.firings[:-1] - tokens = list(zip(spans, shifted_firings, strict=True)) - section += f"{i + 1}. {delimit_tokens(tokens)}\n" - - return section diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md index a063596f5..f502f8785 100644 --- a/spd/clustering/CLAUDE.md +++ b/spd/clustering/CLAUDE.md @@ -108,6 +108,18 @@ DistancesArray # Float[np.ndarray, "n_iters n_ens n_ens"] - `matching_dist.py` - Optimal matching distance via Hungarian algorithm - `merge_pair_samplers.py` - Strategies for selecting which pair to merge +## Utility Scripts + +**`get_cluster_mapping.py`**: Extracts cluster assignments at a specific iteration from a clustering run, outputs JSON mapping component labels to cluster indices (singletons mapped to `null`). + +```bash +python -m spd.clustering.scripts.get_cluster_mapping /path/to/clustering_run --iteration 299 +``` + +## App Integration + +To make a cluster mapping available in the app's dropdown for a run, add its path to `CANONICAL_RUNS` in `spd/app/frontend/src/lib/registry.ts` under the corresponding run's `clusterMappings` array. 
+ ## Config Files Configs live in `spd/clustering/configs/`: diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index eccda019f..80e9c63bc 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,4 +6,4 @@ slurm_partition: null wandb_project: "spd" wandb_entity: "goodfire" create_git_snapshot: false -clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file +clustering_run_config_path: "spd/clustering/configs/crc/ss_llama_simple_mlp.json" \ No newline at end of file diff --git a/spd/configs.py b/spd/configs.py index efbfb3bb4..2635ef36c 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -567,6 +567,7 @@ class _PersistentPGDBaseConfig(LossMetricConfig): "source refinement iterations on the same batch in an inner loop beforehand." ), ] = 0 + n_samples: PositiveInt = 1 @model_validator(mode="before") @classmethod diff --git a/spd/dataset_attributions/CLAUDE.md b/spd/dataset_attributions/CLAUDE.md index faf3a5373..2e2c0bd4c 100644 --- a/spd/dataset_attributions/CLAUDE.md +++ b/spd/dataset_attributions/CLAUDE.md @@ -5,150 +5,111 @@ Multi-GPU pipeline for computing component-to-component attribution strengths ag ## Usage (SLURM) ```bash -# Process specific number of batches spd-attributions --n_batches 1000 --n_gpus 8 - -# Process entire training dataset (omit --n_batches) -spd-attributions --n_gpus 24 - -# With optional parameters -spd-attributions --n_batches 1000 --n_gpus 8 \ - --batch_size 64 --ci_threshold 1e-6 --time 48:00:00 +spd-attributions --n_gpus 24 # whole dataset ``` The command: -1. Creates a git snapshot branch for reproducibility (jobs may be queued) -2. Submits a SLURM job array with N tasks (one per GPU) +1. Creates a git snapshot branch for reproducibility +2. Submits a SLURM job array (one per GPU) 3. 
Each task processes batches where `batch_idx % world_size == rank` -4. Submits a merge job (depends on array completion) that combines all worker results - -**Note**: `--n_batches` is optional. If omitted, the pipeline processes the entire training dataset. +4. Submits a merge job (depends on array completion) ## Usage (non-SLURM) -For environments without SLURM, run the worker script directly: - ```bash -# Single GPU (defaults from DatasetAttributionConfig, auto-generates subrun ID) -python -m spd.dataset_attributions.scripts.run - -# Single GPU with config file -python -m spd.dataset_attributions.scripts.run --config_path path/to/config.yaml +# Single GPU +python -m spd.dataset_attributions.scripts.run_worker -# Multi-GPU (run in parallel via shell, tmux, etc.) -# All workers and the merge step must share the same --subrun_id +# Multi-GPU SUBRUN="da-$(date +%Y%m%d_%H%M%S)" -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 2 --world_size 4 --subrun_id $SUBRUN & -python -m spd.dataset_attributions.scripts.run --config_json '{"n_batches": 1000}' --rank 3 --world_size 4 --subrun_id $SUBRUN & +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 0 --world_size 4 --subrun_id $SUBRUN & +python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 1000}' --rank 1 --world_size 4 --subrun_id $SUBRUN & +# ... 
wait - -# Merge results after all workers complete -python -m spd.dataset_attributions.scripts.run --merge --subrun_id $SUBRUN +python -m spd.dataset_attributions.scripts.run_merge --wandb_path --subrun_id $SUBRUN ``` -Each worker processes batches where `batch_idx % world_size == rank`, then the merge step combines all partial results. - ## Data Storage -Each attribution invocation creates a timestamped sub-run directory. `AttributionRepo` automatically loads from the latest sub-run. - ``` SPD_OUT_DIR/dataset_attributions// -├── da-20260211_120000/ # sub-run 1 -│ ├── dataset_attributions.pt # Final merged attributions -│ └── worker_states/ # cleaned up after merge +├── da-20260223_183250/ # sub-run (latest picked by repo) +│ ├── dataset_attributions.pt # merged result +│ └── worker_states/ │ └── dataset_attributions_rank_*.pt -├── da-20260211_140000/ # sub-run 2 -│ └── ... ``` -Legacy layout (pre sub-run) is still supported as a fallback by `AttributionRepo`: +`AttributionRepo.open(run_id)` loads the latest `da-*` subrun that has a `dataset_attributions.pt`. -``` -SPD_OUT_DIR/dataset_attributions// -└── dataset_attributions.pt -``` +## Attribution Metrics -## Architecture +Two metrics: `AttrMetric = Literal["attr", "attr_abs"]` -### SLURM Launcher (`scripts/run_slurm.py`, `scripts/run_slurm_cli.py`) +| Metric | Formula | Description | +|--------|---------|-------------| +| `attr` | E[∂y/∂x · x] | Signed mean attribution | +| `attr_abs` | E[∂\|y\|/∂x · x] | Attribution to absolute value of target (2 backward passes) | -Entry point via `spd-attributions`. Submits array job + dependent merge job. +Naming convention: modifier *before* `attr` applies to the target (e.g. `attr_abs` = attribution to |target|). -### Worker Script (`scripts/run.py`) +## Architecture -Internal script called by SLURM jobs. Accepts config via `--config_path` (file) or `--config_json` (inline JSON). 
Supports: -- `--config_path`/`--config_json`: Provide `DatasetAttributionConfig` (defaults used if neither given) -- `--rank R --world_size N`: Process subset of batches -- `--merge`: Combine per-rank results into final file -- `--subrun_id`: Sub-run identifier (auto-generated if not provided) +### Storage (`storage.py`) -### Config (`config.py`) +`DatasetAttributionStorage` stores four structurally distinct edge types: -`DatasetAttributionConfig` (tuning params) and `AttributionsSlurmConfig` (DatasetAttributionConfig + SLURM params). `wandb_path` is a runtime arg, not part of config. +| Edge type | Fields | Shape | Has abs? | +|-----------|--------|-------|----------| +| component → component | `regular_attr`, `regular_attr_abs` | `dict[target, dict[source, (tgt_c, src_c)]]` | yes | +| embed → component | `embed_attr`, `embed_attr_abs` | `dict[target, (tgt_c, vocab)]` | yes | +| component → unembed | `unembed_attr` | `dict[source, (d_model, src_c)]` | no | +| embed → unembed | `embed_unembed_attr` | `(d_model, vocab)` | no | -### Harvest Logic (`harvest.py`) +All layer names use **canonical addressing** (`"embed"`, `"0.glu.up"`, `"output"`). -Main harvesting functions: -- `harvest_attributions(wandb_path, config, output_dir, ...)`: Process batches for a single rank -- `merge_attributions(output_dir)`: Combine worker results from `output_dir/worker_states/` into `output_dir` +Unembed edges are stored in residual space (d_model dimensions). `w_unembed` is stored alongside the attribution data, so output token attributions are computed on-the-fly internally — callers never need to provide the projection matrix. No abs variant for unembed edges because abs is a nonlinear operation incompatible with residual-space storage. -### Attribution Harvester (`harvester.py`) +**Normalization**: `normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t]`. Component sources use `ci_sum[s]` as denominator, embed sources use `embed_token_count[s]` (per-token occurrence count). 
This puts both source types on comparable per-occurrence scales. -Core class that accumulates attribution strengths using gradient × activation formula: +Key methods: `get_top_sources(key, k, sign, metric)`, `get_top_targets(key, k, sign, metric)`. Both return `[]` for nonexistent components. `merge(paths)` classmethod for combining worker results via weighted average by n_tokens. -``` -attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] -``` +### Harvester (`harvester.py`) -Key optimizations: +Accumulates attributions using gradient × activation. Uses **concrete module paths** internally (talks to model cache/CI). Four accumulator groups mirror the storage edge types. Key optimizations: 1. Sum outputs over positions before gradients (reduces backward passes) -2. For output targets, store attributions to output residual stream instead of vocab tokens (reduces storage from O((V+C)²) to O((V+C)×(C+d_model))) - -### Storage (`storage.py`) +2. Output-residual storage (O(d_model) instead of O(vocab)) +3. `scatter_add_` for embed sources, vectorized `.add_()` for components (>14x faster than per-element loops) -`DatasetAttributionStorage` class using output-residual-based storage for scalability. +### Harvest (`harvest.py`) -**Storage structure:** -- `source_to_component`: (n_sources, n_components) - direct attributions to component targets -- `source_to_out_residual`: (n_sources, d_model) - attributions to output residual stream for output queries +Orchestrates the pipeline: loads model, builds gradient connectivity, runs batches, translates concrete→canonical at storage boundary via `topology.target_to_canon()`. 
-**Source indexing (rows):** -- `[0, vocab_size)`: wte tokens -- `[vocab_size, vocab_size + n_components)`: component layers +### Scripts -**Target handling:** -- Component targets: direct lookup in `source_to_component` -- Output targets: computed on-the-fly via `source_to_out_residual @ w_unembed[:, token_id]` +- `scripts/run_worker.py` — worker entrypoint (single GPU) +- `scripts/run_merge.py` — merge entrypoint (CPU only, needs ~200G RAM) +- `scripts/run_slurm.py` — SLURM launcher (array + merge jobs) +- `scripts/run_slurm_cli.py` — CLI wrapper for `spd-attributions` -**Why output-residual-based storage?** +### Config (`config.py`) -For large vocab models (V=32K), the naive approach would require O((V+C)²) storage (~4 GB). -The output-residual-based approach requires only O((V+C)×(C+d)) storage (~670 MB for Llama-scale), -a 6.5x reduction. Output attributions are computed on-the-fly at query time with negligible latency. +- `DatasetAttributionConfig`: n_batches, batch_size, ci_threshold +- `AttributionsSlurmConfig`: adds n_gpus, partition, time, merge_time, merge_mem (default 200G) ### Repository (`repo.py`) -`AttributionRepo` provides read access via `AttributionRepo.open(run_id)`. Returns `None` if no data exists. Storage is loaded eagerly at construction. +`AttributionRepo.open(run_id)` → loads latest subrun. Returns `None` if no data. + +## Query Methods -## Key Types +All query methods take `metric: AttrMetric` (`"attr"` or `"attr_abs"`). 
-```python -DatasetAttributionStorage # Main storage class with split matrices -DatasetAttributionEntry # Single entry: component_key, layer, component_idx, value -DatasetAttributionConfig # Config (BaseConfig): n_batches, batch_size, ci_threshold -``` +| Method | Description | +|--------|-------------| +| `get_top_sources(target_key, k, sign, metric)` | Top sources → target | +| `get_top_targets(source_key, k, sign, metric)` | Top targets ← source | -## Query Methods +Key format: `"embed:{token_id}"`, `"0.glu.up:{c_idx}"`, `"output:{token_id}"`. -| Method | w_unembed required? | Description | -|--------|---------------------|-------------| -| `get_top_sources(component_key, k, sign)` | No | Top sources → component target | -| `get_top_sources(output_key, k, sign, w_unembed)` | Yes | Top sources → output token | -| `get_top_component_targets(source_key, k, sign)` | No | Top component targets | -| `get_top_output_targets(source_key, k, sign, w_unembed)` | Yes | Top output token targets | -| `get_top_targets(source_key, k, sign, w_unembed)` | Yes | All targets (components + outputs) | -| `get_attribution(source_key, component_key)` | No | Single component attribution | -| `get_attribution(source_key, output_key, w_unembed)` | Yes | Single output attribution | +Note: `attr_abs` returns empty for output targets (unembed edges have no abs variant). 
diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py index a1de165fb..6f02df0f9 100644 --- a/spd/dataset_attributions/config.py +++ b/spd/dataset_attributions/config.py @@ -13,6 +13,8 @@ class DatasetAttributionConfig(BaseConfig): + spd_run_wandb_path: str + harvest_subrun_id: str | None = None n_batches: int | Literal["whole_dataset"] = 10_000 batch_size: int = 32 ci_threshold: float = 0.0 @@ -26,3 +28,4 @@ class AttributionsSlurmConfig(BaseConfig): partition: str = DEFAULT_PARTITION_NAME time: str = "48:00:00" merge_time: str = "01:00:00" + merge_mem: str = "200G" diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 15e4f5b19..9b02b0d3e 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -33,27 +33,13 @@ from spd.utils.wandb_utils import parse_wandb_run_path -def _build_component_layer_keys(model: ComponentModel) -> list[str]: - """Build list of component layer keys in canonical order. - - Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. - wte and output keys are not included - they're constructed from vocab_size. - """ - component_layer_keys = [] - for layer in model.target_module_paths: - n_components = model.module_to_c[layer] - for c_idx in range(n_components): - component_layer_keys.append(f"{layer}:{c_idx}") - return component_layer_keys - - def _build_alive_masks( model: ComponentModel, run_id: str, harvest_subrun_id: str | None, - n_components: int, + embed_path: str, vocab_size: int, -) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: +) -> dict[str, Bool[Tensor, " n_components"]]: """Build masks of alive components (mean_activation > threshold) for sources and targets. Falls back to all-alive if harvest summary not available. 
@@ -63,47 +49,34 @@ def _build_alive_masks( - Targets: [0, n_components) = component layers (output handled via out_residual) """ - n_sources = vocab_size + n_components - - source_alive = torch.zeros(n_sources, dtype=torch.bool) - target_alive = torch.zeros(n_components, dtype=torch.bool) - - # All wte tokens are always alive (source indices [0, vocab_size)) - source_alive[:vocab_size] = True + component_alive = { + embed_path: torch.ones(vocab_size, dtype=torch.bool), # TODO(oli): maybe remove this + **{ + layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) + for layer in model.target_module_paths + }, + } if harvest_subrun_id is not None: harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) else: harvest = HarvestRepo.open_most_recent(run_id, readonly=True) assert harvest is not None, f"No harvest data for {run_id}" + summary = harvest.get_summary() assert summary is not None, "Harvest summary not available" - # Build masks for component layers - source_idx = vocab_size # Start after wte tokens - target_idx = 0 - for layer in model.target_module_paths: n_layer_components = model.module_to_c[layer] for c_idx in range(n_layer_components): component_key = f"{layer}:{c_idx}" is_alive = component_key in summary and summary[component_key].firing_density > 0.0 - source_alive[source_idx] = is_alive - target_alive[target_idx] = is_alive - source_idx += 1 - target_idx += 1 - - n_source_alive = int(source_alive.sum().item()) - n_target_alive = int(target_alive.sum().item()) - logger.info( - f"Alive components: {n_source_alive}/{n_sources} sources, " - f"{n_target_alive}/{n_components} component targets (firing density > 0.0)" - ) - return source_alive, target_alive + component_alive[layer][c_idx] = is_alive + + return component_alive def harvest_attributions( - wandb_path: str, config: DatasetAttributionConfig, output_dir: Path, harvest_subrun_id: str | None = None, @@ -127,43 +100,32 @@ def harvest_attributions( device = 
torch.device(get_device()) logger.info(f"Loading model on {device}") - _, _, run_id = parse_wandb_run_path(wandb_path) + _, _, run_id = parse_wandb_run_path(config.spd_run_wandb_path) - run_info = SPDRunInfo.from_path(wandb_path) + run_info = SPDRunInfo.from_path(config.spd_run_wandb_path) model = ComponentModel.from_run_info(run_info).to(device) model.eval() spd_config = run_info.config - train_loader, tokenizer = train_loader_and_tokenizer(spd_config, config.batch_size) - vocab_size = tokenizer.vocab_size - assert isinstance(vocab_size, int), f"vocab_size must be int, got {type(vocab_size)}" - logger.info(f"Vocab size: {vocab_size}") - - # Build component keys and alive masks - component_layer_keys = _build_component_layer_keys(model) - n_components = len(component_layer_keys) - source_alive, target_alive = _build_alive_masks( - model, run_id, harvest_subrun_id, n_components, vocab_size - ) - source_alive = source_alive.to(device) - target_alive = target_alive.to(device) - - n_sources = vocab_size + n_components - logger.info(f"Component layers: {n_components}, Sources: {n_sources}") + train_loader, _ = train_loader_and_tokenizer(spd_config, config.batch_size) # Get gradient connectivity logger.info("Computing sources_by_target...") topology = TransformerTopology(model.target_model) + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path + vocab_size = topology.embedding_module.num_embeddings + logger.info(f"Vocab size: {vocab_size}") sources_by_target_raw = get_sources_by_target(model, topology, str(device), spd_config.sampling) - # Filter sources_by_target: - # - Valid targets: component layers + output - # - Valid sources: wte + component layers + # Filter to valid source/target pairs: + # - Valid sources: embedding + component layers + # - Valid targets: component layers + unembed component_layers = set(model.target_module_paths) - valid_sources = component_layers | {"wte"} - valid_targets = component_layers | 
{"output"} + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} - sources_by_target = {} + sources_by_target: dict[str, list[str]] = {} for target, sources in sources_by_target_raw.items(): if target not in valid_targets: continue @@ -172,19 +134,21 @@ def harvest_attributions( sources_by_target[target] = filtered_sources logger.info(f"Found {len(sources_by_target)} target layers with gradient connections") - # Create harvester + # Build alive masks + component_alive = _build_alive_masks(model, run_id, harvest_subrun_id, embed_path, vocab_size) + + # Create harvester (all concrete paths internally) harvester = AttributionHarvester( model=model, sources_by_target=sources_by_target, - n_components=n_components, vocab_size=vocab_size, - source_alive=source_alive, - target_alive=target_alive, + component_alive=component_alive, sampling=spd_config.sampling, + embed_path=embed_path, embedding_module=topology.embedding_module, + unembed_path=unembed_path, unembed_module=topology.unembed_module, device=device, - show_progress=True, ) # Process batches @@ -194,37 +158,24 @@ def harvest_attributions( batch_range = range(n_batches) case "whole_dataset": batch_range = itertools.count() + for batch_idx in tqdm.tqdm(batch_range, desc="Attribution batches"): try: batch_data = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break + # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue + batch = extract_batch_data(batch_data).to(device) harvester.process_batch(batch) - logger.info( - f"Processing complete. 
Tokens: {harvester.n_tokens:,}, Batches: {harvester.n_batches}" - ) - - # Normalize by n_tokens to get per-token average attribution - normalized_comp = harvester.comp_accumulator / harvester.n_tokens - normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens + logger.info(f"Processing complete. Tokens: {harvester.n_tokens:,}") - # Build and save storage - storage = DatasetAttributionStorage( - component_layer_keys=component_layer_keys, - vocab_size=vocab_size, - d_model=harvester.d_model, - source_to_component=normalized_comp.cpu(), - source_to_out_residual=normalized_out_residual.cpu(), - n_batches_processed=harvester.n_batches, - n_tokens_processed=harvester.n_tokens, - ci_threshold=config.ci_threshold, - ) + storage = harvester.finalize(topology, config.ci_threshold) if rank is not None: worker_dir = output_dir / "worker_states" @@ -234,72 +185,24 @@ def harvest_attributions( output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / "dataset_attributions.pt" storage.save(output_path) - logger.info(f"Saved dataset attributions to {output_path}") def merge_attributions(output_dir: Path) -> None: - """Merge partial attribution files from parallel workers. - - Looks for worker_states/dataset_attributions_rank_*.pt files and merges them - into dataset_attributions.pt in the output_dir. - - Uses streaming merge to avoid OOM - loads one file at a time instead of all at once. 
- """ + """Merge partial attribution files from parallel workers.""" worker_dir = output_dir / "worker_states" rank_files = sorted(worker_dir.glob("dataset_attributions_rank_*.pt")) assert rank_files, f"No rank files found in {worker_dir}" logger.info(f"Found {len(rank_files)} rank files to merge") - # Load first file to get metadata and initialize accumulators - # Use double precision for accumulation to prevent precision loss with billions of tokens - first = DatasetAttributionStorage.load(rank_files[0]) - total_comp = (first.source_to_component * first.n_tokens_processed).double() - total_out_residual = (first.source_to_out_residual * first.n_tokens_processed).double() - total_tokens = first.n_tokens_processed - total_batches = first.n_batches_processed - logger.info(f"Loaded rank 0: {first.n_tokens_processed:,} tokens") - - # Stream remaining files one at a time - for rank_file in tqdm.tqdm(rank_files[1:], desc="Merging rank files"): - storage = DatasetAttributionStorage.load(rank_file) - - # Validate consistency - assert storage.component_layer_keys == first.component_layer_keys, ( - "Component layer keys mismatch" - ) - assert storage.vocab_size == first.vocab_size, "Vocab size mismatch" - assert storage.d_model == first.d_model, "d_model mismatch" - assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" - - # Accumulate de-normalized values - total_comp += storage.source_to_component * storage.n_tokens_processed - total_out_residual += storage.source_to_out_residual * storage.n_tokens_processed - total_tokens += storage.n_tokens_processed - total_batches += storage.n_batches_processed - - # Normalize by total tokens and convert back to float32 for storage - merged_comp = (total_comp / total_tokens).float() - merged_out_residual = (total_out_residual / total_tokens).float() - - # Save merged result - merged = DatasetAttributionStorage( - component_layer_keys=first.component_layer_keys, - vocab_size=first.vocab_size, - 
d_model=first.d_model, - source_to_component=merged_comp, - source_to_out_residual=merged_out_residual, - n_batches_processed=total_batches, - n_tokens_processed=total_tokens, - ci_threshold=first.ci_threshold, - ) + merged = DatasetAttributionStorage.merge(rank_files) output_path = output_dir / "dataset_attributions.pt" merged.save(output_path) - assert output_path.stat().st_size > 0, f"Merge output is empty: {output_path}" - logger.info(f"Merged {len(rank_files)} files -> {output_path}") - logger.info(f"Total: {total_batches} batches, {total_tokens:,} tokens") - - for rank_file in rank_files: - rank_file.unlink() - worker_dir.rmdir() - logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") + logger.info(f"Total: {merged.n_tokens_processed:,} tokens") + + # TODO(oli): reenable this + # disabled deletion for testing, posterity and retries + # for rank_file in rank_files: + # rank_file.unlink() + # worker_dir.rmdir() + # logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 5bef0af63..0bdac9af0 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -4,27 +4,33 @@ training dataset using gradient x activation formula, summed over all positions and batches. -Uses residual-based storage for scalability: -- Component targets: accumulated directly to comp_accumulator -- Output targets: accumulated as attributions to output residual stream (source_to_out_residual) - Output attributions computed on-the-fly at query time via w_unembed +Two metrics are accumulated: +- attr: E[∂y/∂x · x] (signed mean attribution) +- attr_abs: E[∂|y|/∂x · x] (attribution to absolute value of target) + +Output (pseudo-) component attributions are handled differently: We accumulate attributions +to the output residual stream, then later project this into token space.
+ +All layer keys are concrete module paths (e.g. "wte", "h.0.attn.q_proj", "lm_head"). +Translation to canonical names happens at the storage boundary in harvest.py. """ from typing import Any import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Int from torch import Tensor, nn -from tqdm.auto import tqdm from spd.configs import SamplingType +from spd.dataset_attributions.storage import DatasetAttributionStorage from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import make_mask_infos +from spd.topology import TransformerTopology from spd.utils.general_utils import bf16_autocast class AttributionHarvester: - """Accumulates attribution strengths across batches. + """Accumulates attribution strengths across batches using concrete module paths. The attribution formula is: attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] @@ -35,11 +41,6 @@ class AttributionHarvester: 2. For output targets, store attributions to the pre-unembed residual (d_model dimensions) instead of vocab tokens. This eliminates the expensive O((V+C) × d_model × V) matmul during harvesting and reduces storage. - - Index structure: - - Sources: wte tokens [0, vocab_size) + component layers [vocab_size, ...) 
- - Component targets: [0, n_components) in comp_accumulator - - Output targets: via out_residual_accumulator (computed on-the-fly at query time) """ sampling: SamplingType @@ -48,95 +49,115 @@ def __init__( self, model: ComponentModel, sources_by_target: dict[str, list[str]], - n_components: int, vocab_size: int, - source_alive: Bool[Tensor, " n_sources"], - target_alive: Bool[Tensor, " n_components"], + component_alive: dict[str, Bool[Tensor, " n_components"]], sampling: SamplingType, + embed_path: str, embedding_module: nn.Embedding, + unembed_path: str, unembed_module: nn.Linear, device: torch.device, - show_progress: bool = False, ): self.model = model self.sources_by_target = sources_by_target - self.n_components = n_components - self.vocab_size = vocab_size - self.source_alive = source_alive - self.target_alive = target_alive + self.component_alive = component_alive self.sampling = sampling + self.embed_path = embed_path self.embedding_module = embedding_module + self.unembed_path = unembed_path self.unembed_module = unembed_module + self.output_d_model = unembed_module.in_features self.device = device - self.show_progress = show_progress - - self.n_sources = vocab_size + n_components - self.n_batches = 0 - self.n_tokens = 0 - - # Split accumulators for component and output targets - self.comp_accumulator = torch.zeros(self.n_sources, n_components, device=device) - # For output targets: store attributions to output residual dimensions - self.d_model = unembed_module.in_features - self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) - - # Build per-layer index ranges for sources - self.component_layer_names = list(model.target_module_paths) - self.source_layer_to_idx_range = self._build_source_layer_index_ranges() - self.target_layer_to_idx_range = self._build_target_layer_index_ranges() - - # Pre-compute alive indices per layer - self.alive_source_idxs_per_layer = self._build_alive_indices( - 
self.source_layer_to_idx_range, source_alive + # attribution accumulators + self._straight_through_attr_acc = torch.zeros( + (self.output_d_model, self.embedding_module.num_embeddings), device=self.device ) - self.alive_target_idxs_per_layer = self._build_alive_indices( - self.target_layer_to_idx_range, target_alive + self._embed_tgts_acc = self._get_embed_targets_attr_accumulator(sources_by_target) + self._embed_tgts_acc_abs = self._get_embed_targets_attr_accumulator(sources_by_target) + self._unembed_srcs_acc = self._get_unembed_sources_attr_accumulator(sources_by_target) + self._regular_layers_acc = self._get_regular_layer_attr_accumulator(sources_by_target) + self._regular_layers_acc_abs = self._get_regular_layer_attr_accumulator(sources_by_target) + + # embed token occurrence counts for normalization (analogous to ci_sum for components) + self._embed_token_count = torch.zeros( + (self.embedding_module.num_embeddings,), dtype=torch.long, device=self.device ) - def _build_source_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Source order: wte tokens [0, vocab_size), then component layers.""" - ranges: dict[str, tuple[int, int]] = {"wte": (0, self.vocab_size)} - idx = self.vocab_size - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - return ranges - - def _build_target_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Target order: component layers [0, n_components). 
Output handled separately.""" - ranges: dict[str, tuple[int, int]] = {} - idx = 0 - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - # Note: "output" not included - handled via out_residual_accumulator - return ranges - - def _build_alive_indices( - self, layer_ranges: dict[str, tuple[int, int]], alive_mask: Bool[Tensor, " n"] - ) -> dict[str, list[int]]: - """Get alive local indices for each layer.""" - return { - layer: torch.where(alive_mask[start:end])[0].tolist() - for layer, (start, end) in layer_ranges.items() + # rms normalization accumulators + self.n_tokens = 0 + self._ci_sum_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() } + self._square_component_act_accumulator = { + layer: torch.zeros((c,), device=self.device) + for layer, c in self.model.module_to_c.items() + } + self._logit_sq_sum = torch.zeros((self.unembed_module.out_features,), device=self.device) + + def _get_embed_targets_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract targets whose sources include the embedding + embed_targets_attr_accumulators: dict[str, Tensor] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + # ignore straight-through edge + continue + if self.embed_path in sources: + embed_targets_attr_accumulators[target] = torch.zeros( + (self.model.module_to_c[target], self.embedding_module.num_embeddings), + device=self.device, + ) + return embed_targets_attr_accumulators + + def _get_unembed_sources_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, Tensor]: + # extract the unembed's sources + unembed_sources_attr_accumulators: dict[str, Tensor] = {} + for source in sources_by_target[self.unembed_path]: + if source == self.embed_path: + # ignore straight-through edge + continue +
unembed_sources_attr_accumulators[source] = torch.zeros( + (self.output_d_model, self.model.module_to_c[source]), device=self.device + ) + return unembed_sources_attr_accumulators + + def _get_regular_layer_attr_accumulator( + self, sources_by_target: dict[str, list[str]] + ) -> dict[str, dict[str, Tensor]]: + regular_layers_shapes: dict[str, dict[str, Tensor]] = {} + for target, sources in sources_by_target.items(): + if target == self.unembed_path: + continue + regular_layers_shapes[target] = {} + for source in sources: + if source == self.embed_path: + continue + regular_layers_shapes[target][source] = torch.zeros( + (self.model.module_to_c[target], self.model.module_to_c[source]), + device=self.device, + ) + return regular_layers_shapes def process_batch(self, tokens: Int[Tensor, "batch seq"]) -> None: """Accumulate attributions from one batch.""" - self.n_batches += 1 self.n_tokens += tokens.numel() + self._embed_token_count.add_( + torch.bincount(tokens.flatten(), minlength=self.embedding_module.num_embeddings) + ) - # Setup hooks to capture wte output and pre-unembed residual - wte_out: list[Tensor] = [] + # Setup hooks to capture embedding output and pre-unembed residual + embed_out: list[Tensor] = [] pre_unembed: list[Tensor] = [] - def wte_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: out.requires_grad_(True) - wte_out.clear() - wte_out.append(out) + embed_out.clear() + embed_out.append(out) return out def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: @@ -144,7 +165,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - h1 = self.embedding_module.register_forward_hook(wte_hook, with_kwargs=True) + h1 = self.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) h2 = 
self.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) # Get masks with all components active @@ -153,6 +174,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No ci = self.model.calc_causal_importances( pre_weight_acts=out.cache, sampling=self.sampling, detach_inputs=False ) + mask_infos = make_mask_infos( component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", @@ -160,100 +182,146 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No # Forward pass with gradients with torch.enable_grad(), bf16_autocast(): - comp_output: OutputWithCache = self.model( + model_output: OutputWithCache = self.model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) h1.remove() h2.remove() - cache = comp_output.cache - cache["wte_post_detach"] = wte_out[0] - cache["pre_unembed"] = pre_unembed[0] - cache["tokens"] = tokens - - # Process each target layer - layers = list(self.sources_by_target.items()) - pbar = tqdm(layers, desc="Targets", disable=not self.show_progress, leave=False) - for target_layer, source_layers in pbar: - if target_layer == "output": - self._process_output_targets(source_layers, cache) + cache = model_output.cache + cache[f"{self.embed_path}_post_detach"] = embed_out[0] + cache[f"{self.unembed_path}_pre_detach"] = pre_unembed[0] + + with torch.no_grad(): + for real_layer, ci_vals in ci.lower_leaky.items(): + self._ci_sum_accumulator[real_layer].add_(ci_vals.sum(dim=(0, 1))) + self._logit_sq_sum.add_(model_output.output.detach().square().sum(dim=(0, 1))) + + for target_layer in self.sources_by_target: + if target_layer == self.unembed_path: + self._process_output_targets(cache, tokens, ci.lower_leaky) else: - self._process_component_targets(target_layer, source_layers, cache) + with torch.no_grad(): + sum_sq_acts = cache[f"{target_layer}_post_detach"].square().sum(dim=(0, 1)) + 
self._square_component_act_accumulator[target_layer].add_(sum_sq_acts) + self._process_component_targets(cache, tokens, ci.lower_leaky, target_layer) - def _process_component_targets( + def _process_output_targets( self, - target_layer: str, - source_layers: list[str], cache: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], ) -> None: - """Process attributions to a component layer.""" - target_start, _ = self.target_layer_to_idx_range[target_layer] - alive_targets = self.alive_target_idxs_per_layer[target_layer] - if not alive_targets: - return + """Process output attributions via output-residual-space storage.""" + out_residual = cache[f"{self.unembed_path}_pre_detach"] + + out_residual_sum = out_residual.sum(dim=(0, 1)) + + source_layers = self.sources_by_target[self.unembed_path] + assert self.embed_path in source_layers, "remove me when passed" - # Sum over batch and sequence - target_acts = cache[f"{target_layer}_pre_detach"].sum(dim=(0, 1)) source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for t_idx in alive_targets: - grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.comp_accumulator[:, target_start + t_idx], - source_layers, - grads, - source_acts, - cache["tokens"], - ) + for d_idx in range(self.output_d_model): + grads = torch.autograd.grad(out_residual_sum[d_idx], source_acts, retain_graph=True) + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + self._straight_through_attr_acc[d_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted_attr = (grad * act * ci[source_layer]).sum(dim=(0, 1)) + self._unembed_srcs_acc[source_layer][d_idx].add_(ci_weighted_attr) - def _process_output_targets( + def _process_component_targets( self, - source_layers: list[str], cache: 
dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], + target_layer: str, ) -> None: - """Process output attributions via output-residual-space storage. - - Instead of computing and storing attributions to vocab tokens directly, - we store attributions to output residual dimensions. Output attributions are - computed on-the-fly at query time via: attr[src, token] = out_residual[src] @ w_unembed[:, token] - """ - # Sum output residual over batch and sequence -> [d_model] - out_residual = cache["pre_unembed"].sum(dim=(0, 1)) + """Process attributions to a component layer.""" + alive_targets = self.component_alive[target_layer] + if not alive_targets.any(): + return + + target_acts_raw = cache[f"{target_layer}_pre_detach"] + + target_acts = target_acts_raw.sum(dim=(0, 1)) + # abs() before sum — needs its own backward pass because each element has a different + # sign, so sign·grad can't be factored out of the sum. (In the app backend's per-prompt + # computation the target is a single scalar, so sign·grad works as an analytical shortcut + # and avoids a second backward. See app/backend/compute.py::_compute_edges_for_target.) 
+ target_acts_abs = target_acts_raw.abs().sum(dim=(0, 1)) + + source_layers = self.sources_by_target[target_layer] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for d_idx in range(self.d_model): - grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.out_residual_accumulator[:, d_idx], - source_layers, - grads, - source_acts, - cache["tokens"], + def _accumulate_grads( + grads: tuple[Tensor, ...], + t_idx: int, + embed_acc: dict[str, Tensor], + regular_acc: dict[str, dict[str, Tensor]], + ) -> None: + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == self.embed_path: + token_attr = (grad * act).sum(dim=-1) # (B S) + embed_acc[target_layer][t_idx].scatter_add_( + 0, tokens.flatten(), token_attr.flatten() + ) + else: + ci_weighted = (grad * act * ci[source_layer]).sum(dim=(0, 1)) # (C,) + regular_acc[target_layer][source_layer][t_idx].add_(ci_weighted) + + for t_idx in torch.where(alive_targets)[0].tolist(): + grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + grads=grads, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc, + regular_acc=self._regular_layers_acc, ) - def _accumulate_attributions( - self, - target_col: Float[Tensor, " n_sources"], - source_layers: list[str], - grads: tuple[Tensor, ...], - source_acts: list[Tensor], - tokens: Int[Tensor, "batch seq"], - ) -> None: - """Accumulate grad*act attributions from sources to a target column.""" - with torch.no_grad(): - for layer, grad, act in zip(source_layers, grads, source_acts, strict=True): - alive = self.alive_source_idxs_per_layer[layer] - if not alive: - continue + grads_abs = torch.autograd.grad(target_acts_abs[t_idx], source_acts, retain_graph=True) + _accumulate_grads( + grads=grads_abs, + t_idx=t_idx, + embed_acc=self._embed_tgts_acc_abs, + regular_acc=self._regular_layers_acc_abs, + ) 
- if layer == "wte": - # Per-token: sum grad*act over d_model, scatter by token id - attr = (grad * act).sum(dim=-1).flatten() - target_col.scatter_add_(0, tokens.flatten(), attr) - else: - # Per-component: sum grad*act over batch and sequence - start, _ = self.source_layer_to_idx_range[layer] - attr = (grad * act).sum(dim=(0, 1)) - for c in alive: - target_col[start + c] += attr[c] + def finalize( + self, topology: TransformerTopology, ci_threshold: float + ) -> DatasetAttributionStorage: + """Package raw accumulators into storage. No normalization — that happens at query time.""" + assert self.n_tokens > 0, "No batches processed" + + to_canon = topology.target_to_canon + + def _canon_nested(acc: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + to_canon(t): {to_canon(s): v for s, v in srcs.items()} for t, srcs in acc.items() + } + + def _canon(acc: dict[str, Tensor]) -> dict[str, Tensor]: + return {to_canon(k): v for k, v in acc.items()} + + return DatasetAttributionStorage( + regular_attr=_canon_nested(self._regular_layers_acc), + regular_attr_abs=_canon_nested(self._regular_layers_acc_abs), + embed_attr=_canon(self._embed_tgts_acc), + embed_attr_abs=_canon(self._embed_tgts_acc_abs), + unembed_attr=_canon(self._unembed_srcs_acc), + embed_unembed_attr=self._straight_through_attr_acc, + w_unembed=topology.get_unembed_weight(), + ci_sum=_canon(self._ci_sum_accumulator), + component_act_sq_sum=_canon(self._square_component_act_accumulator), + logit_sq_sum=self._logit_sq_sum, + embed_token_count=self._embed_token_count, + ci_threshold=ci_threshold, + n_tokens_processed=self.n_tokens, + ) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py index 697036ba3..1175d584e 100644 --- a/spd/dataset_attributions/repo.py +++ b/spd/dataset_attributions/repo.py @@ -42,14 +42,13 @@ def open(cls, run_id: str) -> "AttributionRepo | None": candidates = sorted( [d for d in base_dir.iterdir() if d.is_dir() and 
d.name.startswith("da-")], key=lambda d: d.name, + reverse=True, ) - if not candidates: - return None - subrun_dir = candidates[-1] - path = subrun_dir / "dataset_attributions.pt" - if not path.exists(): - return None - return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + for subrun_dir in candidates: + path = subrun_dir / "dataset_attributions.pt" + if path.exists(): + return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return None def get_attributions(self) -> DatasetAttributionStorage: return self._storage diff --git a/spd/dataset_attributions/scripts/diagnose_cancellation.py b/spd/dataset_attributions/scripts/diagnose_cancellation.py new file mode 100644 index 000000000..09bb85cbb --- /dev/null +++ b/spd/dataset_attributions/scripts/diagnose_cancellation.py @@ -0,0 +1,529 @@ +"""Diagnostic: does |mean(grad×act)| preserve L2(grad×act) rankings? + +The harvester accumulates signed sums of grad×act across positions. This script +checks whether that signed mean gives the same top-K source ranking as the +magnitude-preserving L2 = sqrt(mean((grad×act)²)) alternative. + +Methodology: + For each target component, iterate through data, find positions where the + target's CI > threshold (i.e. it's actually firing), then compute per-position + grad×act for all source components at those positions. Reduce to |mean| and L2 + per source component, rank them, and compare rankings via top-K overlap and + mean rank displacement. 
+ + The per-position grad×act computation matches the harvester exactly: + - Component sources: grad × act × ci (CI-weighted, per the harvester) + - Embed sources: (grad × act).sum(embed_dim), grouped by token ID + +Usage: + python -m spd.dataset_attributions.scripts.diagnose_cancellation \ + "wandb:goodfire/spd/s-892f140b" \ + --n_targets_per_layer 20 --n_active 100 +""" + +import random +from dataclasses import dataclass +from typing import Any + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch import Tensor, nn + +from spd.configs import LMTaskConfig, SamplingType +from spd.data import train_loader_and_tokenizer +from spd.harvest.repo import HarvestRepo +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.settings import SPD_OUT_DIR +from spd.topology import TransformerTopology, get_sources_by_target +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import bf16_autocast, extract_batch_data +from spd.utils.wandb_utils import parse_wandb_run_path + +matplotlib.use("Agg") + + +# --------------------------------------------------------------------------- +# Setup +# --------------------------------------------------------------------------- + + +@dataclass +class ModelContext: + model: ComponentModel + topology: TransformerTopology + sampling: SamplingType + sources_by_target: dict[str, list[str]] + device: torch.device + embed_path: str + unembed_path: str + vocab_size: int + + +def setup(wandb_path: str) -> ModelContext: + device = torch.device(get_device()) + run_info = SPDRunInfo.from_path(wandb_path) + model = ComponentModel.from_run_info(run_info).to(device) + model.eval() + topology = TransformerTopology(model.target_model) + + sources_by_target_raw = get_sources_by_target( + model, topology, str(device), run_info.config.sampling + ) + embed_path = topology.path_schema.embedding_path + unembed_path = 
topology.path_schema.unembed_path + component_layers = set(model.target_module_paths) + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} + sources_by_target: dict[str, list[str]] = {} + for target, sources in sources_by_target_raw.items(): + if target not in valid_targets: + continue + filtered = [s for s in sources if s in valid_sources] + if filtered: + sources_by_target[target] = filtered + + return ModelContext( + model=model, + topology=topology, + sampling=run_info.config.sampling, + sources_by_target=sources_by_target, + device=device, + embed_path=embed_path, + unembed_path=unembed_path, + vocab_size=topology.embedding_module.num_embeddings, + ) + + +# --------------------------------------------------------------------------- +# Forward pass (matches harvester.process_batch exactly) +# --------------------------------------------------------------------------- + + +def forward_with_caches( + ctx: ModelContext, + tokens: Tensor, +) -> tuple[dict[str, Tensor], dict[str, Tensor]]: + """One forward pass → (cache, ci). 
Reuse across all (target, source) pairs.""" + embed_out: list[Tensor] = [] + pre_unembed: list[Tensor] = [] + + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + out.requires_grad_(True) + embed_out.clear() + embed_out.append(out) + return out + + def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: + args[0].requires_grad_(True) + pre_unembed.clear() + pre_unembed.append(args[0]) + + h1 = ctx.topology.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) + h2 = ctx.topology.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) + + with torch.no_grad(), bf16_autocast(): + out = ctx.model(tokens, cache_type="input") + ci = ctx.model.calc_causal_importances( + pre_weight_acts=out.cache, sampling=ctx.sampling, detach_inputs=False + ) + + mask_infos = make_mask_infos( + component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, + routing_masks="all", + ) + + with torch.enable_grad(), bf16_autocast(): + model_output = ctx.model(tokens, mask_infos=mask_infos, cache_type="component_acts") + + h1.remove() + h2.remove() + + cache = model_output.cache + cache[f"{ctx.embed_path}_post_detach"] = embed_out[0] + cache[f"{ctx.unembed_path}_pre_detach"] = pre_unembed[0] + + return cache, ci.lower_leaky + + +# --------------------------------------------------------------------------- +# Per-position attribution (matches harvester._process_component_targets) +# --------------------------------------------------------------------------- + + +def per_position_grads_at( + ctx: ModelContext, + cache: dict[str, Tensor], + ci: dict[str, Tensor], + target_concrete: str, + t_idx: int, + s: int, +) -> dict[str, Tensor]: + """Compute grad×act for all source layers at a single position (b=0, s=s). 
+ + Returns {source_concrete: value_tensor} where: + - Component source: grad × act × ci, shape (C_source,) + - Embed source: (grad × act).sum(embed_dim), scalar + Matches the harvester's _accumulate_grads exactly, just without the sum. + """ + target_acts_raw = cache[f"{target_concrete}_pre_detach"] + scalar = target_acts_raw[0, s, t_idx] + + source_layers = ctx.sources_by_target[target_concrete] + source_acts = [cache[f"{sc}_post_detach"] for sc in source_layers] + grads = torch.autograd.grad(scalar, source_acts, retain_graph=True) + + result: dict[str, Tensor] = {} + with torch.no_grad(): + for source_layer, act, grad in zip(source_layers, source_acts, grads, strict=True): + if source_layer == ctx.embed_path: + result[source_layer] = (grad[0, s] * act[0, s]).sum().cpu() + else: + result[source_layer] = (grad[0, s] * act[0, s] * ci[source_layer][0, s]).cpu() + return result + + +# --------------------------------------------------------------------------- +# Collect active positions for a target component +# --------------------------------------------------------------------------- + + +def collect_active_attrs( + ctx: ModelContext, + loader_iter: Any, + target_concrete: str, + t_idx: int, + n_active: int, + ci_threshold: float, + max_sequences: int, +) -> tuple[dict[str, list[Tensor]], int, int]: + """Iterate sequences, backward only at positions where target CI > threshold. + + Returns (per_source_vals, n_found, n_sequences_checked) where + per_source_vals[source_concrete] is a list of tensors, one per active position. 
+ """ + source_layers = ctx.sources_by_target[target_concrete] + per_source: dict[str, list[Tensor]] = {sc: [] for sc in source_layers} + n_found = 0 + n_checked = 0 + + for _ in range(max_sequences): + try: + batch_data = next(loader_iter) + except StopIteration: + break + tokens = extract_batch_data(batch_data).to(ctx.device) + n_checked += 1 + + # Cheap CI check (no grad graph needed) + with torch.no_grad(), bf16_autocast(): + out = ctx.model(tokens, cache_type="input") + ci_check = ctx.model.calc_causal_importances( + pre_weight_acts=out.cache, sampling=ctx.sampling, detach_inputs=False + ) + + ci_vals = ci_check.lower_leaky[target_concrete][0, :, t_idx] + active_positions = (ci_vals > ci_threshold).nonzero(as_tuple=True)[0] + if len(active_positions) == 0: + continue + + # Full forward with grad graph + cache, ci = forward_with_caches(ctx, tokens) + + for s in active_positions.tolist(): + if n_found >= n_active: + break + grads = per_position_grads_at(ctx, cache, ci, target_concrete, t_idx, s) + for sc, val in grads.items(): + per_source[sc].append(val) + n_found += 1 + + if n_found >= n_active: + break + + return per_source, n_found, n_checked + + +# --------------------------------------------------------------------------- +# Reduce per-position values to |mean| and L2 per source component +# --------------------------------------------------------------------------- + + +def reduce_to_rankings( + per_source: dict[str, list[Tensor]], + embed_path: str, + tokens_per_pos: list[Tensor] | None, + vocab_size: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Reduce per-position attrs to per-source-component |mean| and L2. + + Component sources: each position gives a (C,) vector. |mean| and L2 over positions. + Embed sources: each position gives a scalar. Group by token ID via scatter_add, + then |mean| and L2 per token. + + Returns (abs_means, l2s, is_embed) arrays pooled across all source layers. 
+ """ + all_abs_means: list[np.ndarray] = [] + all_l2s: list[np.ndarray] = [] + all_is_embed: list[np.ndarray] = [] + + for source_layer, vals in per_source.items(): + if not vals: + continue + + if source_layer == embed_path: + assert tokens_per_pos is not None + all_vals = torch.stack(vals).float() + all_toks = torch.cat(tokens_per_pos) + token_sum = torch.zeros(vocab_size) + token_sq_sum = torch.zeros(vocab_size) + token_count = torch.zeros(vocab_size) + token_sum.scatter_add_(0, all_toks, all_vals) + token_sq_sum.scatter_add_(0, all_toks, all_vals.square()) + token_count.scatter_add_(0, all_toks, torch.ones_like(all_vals)) + safe_count = token_count.clamp(min=1) + all_abs_means.append((token_sum / safe_count).abs().numpy()) + all_l2s.append((token_sq_sum / safe_count).sqrt().numpy()) + all_is_embed.append(np.ones(vocab_size, dtype=bool)) + else: + stacked = torch.stack(vals).float() # (N, C) + all_abs_means.append(stacked.mean(dim=0).abs().numpy()) + all_l2s.append(stacked.square().mean(dim=0).sqrt().numpy()) + all_is_embed.append(np.zeros(stacked.shape[1], dtype=bool)) + + return np.concatenate(all_abs_means), np.concatenate(all_l2s), np.concatenate(all_is_embed) + + +# --------------------------------------------------------------------------- +# Target selection +# --------------------------------------------------------------------------- + + +def select_targets( + ctx: ModelContext, + run_id: str, + n_per_layer: int, + fd_range: tuple[float, float], + seed: int, + comp_only: bool, +) -> list[tuple[str, int, float]]: + """Select target components with firing density in range. + + Returns [(concrete_path, c_idx, firing_density), ...]. 
+ """ + harvest = HarvestRepo.open_most_recent(run_id, readonly=True) + assert harvest is not None + summary = harvest.get_summary() + assert summary is not None + + rng = random.Random(seed) + targets: list[tuple[str, int, float]] = [] + + for target_concrete in ctx.sources_by_target: + if target_concrete == ctx.unembed_path: + continue + if comp_only and ctx.embed_path in ctx.sources_by_target[target_concrete]: + continue + + candidates: list[tuple[int, float]] = [] + for c_idx in range(ctx.model.module_to_c[target_concrete]): + key = f"{target_concrete}:{c_idx}" + if key not in summary: + continue + fd = summary[key].firing_density + if fd_range[0] < fd < fd_range[1]: + candidates.append((c_idx, fd)) + + rng.shuffle(candidates) + for c_idx, fd in candidates[:n_per_layer]: + targets.append((target_concrete, c_idx, fd)) + + return targets + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +@dataclass +class TargetResult: + target_layer: str + target_idx: int + firing_density: float + n_active: int + n_sources: int + has_embed: bool + top5_mrd: float + top5_overlap: int + top10_mrd: float + top10_overlap: int + + +def main( + wandb_path: str, + n_targets_per_layer: int = 20, + n_active: int = 100, + ci_threshold: float = 0.01, + fd_min: float = 1e-4, + fd_max: float = 1e-1, + max_sequences: int = 5000, + seed: int = 42, + comp_only: bool = False, +) -> None: + import time + + ctx = setup(wandb_path) + _, _, run_id = parse_wandb_run_path(wandb_path) + to_canon = ctx.topology.target_to_canon + + targets = select_targets(ctx, run_id, n_targets_per_layer, (fd_min, fd_max), seed, comp_only) + print(f"Selected {len(targets)} targets (fd in ({fd_min}, {fd_max}), comp_only={comp_only})") + + spd_config = SPDRunInfo.from_path(wandb_path).config + assert isinstance(spd_config.task_config, LMTaskConfig) + # frozen + # spd_config.task_config.dataset_name = 
"danbraunai/pile-uncopyrighted-tok-shuffled" + train_loader, _ = train_loader_and_tokenizer(spd_config, 1) + loader_iter = iter(train_loader) + + results: list[TargetResult] = [] + t0 = time.time() + + for ti, (tgt_concrete, t_idx, fd) in enumerate(targets): + tgt_canon = to_canon(tgt_concrete) + source_list = ctx.sources_by_target[tgt_concrete] + has_embed = ctx.embed_path in source_list + + per_source, n_found, _ = collect_active_attrs( + ctx, loader_iter, tgt_concrete, t_idx, n_active, ci_threshold, max_sequences + ) + + if n_found < 10: + print( + f"[{ti + 1}/{len(targets)}] {tgt_canon}:{t_idx} fd={fd:.4f} " + f"— only {n_found} active, skipping" + ) + continue + + # Embed sources excluded: collect_active_attrs doesn't store token IDs + # needed for scatter_add grouping. This is fine — embed rankings are + # near-perfect anyway (confirmed in notebook analysis). + abs_means, l2s, _ = reduce_to_rankings( + {sc: v for sc, v in per_source.items() if sc != ctx.embed_path}, + ctx.embed_path, + None, + ctx.vocab_size, + ) + + n = len(abs_means) + rank_mean = np.argsort(np.argsort(-abs_means)) + rank_l2 = np.argsort(np.argsort(-l2s)) + + top5 = np.argsort(-l2s)[:5] + top10 = np.argsort(-l2s)[:10] + + results.append( + TargetResult( + target_layer=tgt_canon, + target_idx=t_idx, + firing_density=fd, + n_active=n_found, + n_sources=n, + has_embed=has_embed, + top5_mrd=np.abs(rank_mean[top5] - rank_l2[top5]).mean(), + top5_overlap=len(set(top5) & set(np.argsort(-abs_means)[:5])), + top10_mrd=np.abs(rank_mean[top10] - rank_l2[top10]).mean(), + top10_overlap=len(set(top10) & set(np.argsort(-abs_means)[:10])), + ) + ) + + if (ti + 1) % 10 == 0: + elapsed = time.time() - t0 + rate = elapsed / (ti + 1) + print( + f"[{ti + 1}/{len(targets)}] {elapsed:.0f}s, ~{rate * (len(targets) - ti - 1):.0f}s left", + flush=True, + ) + + elapsed = time.time() - t0 + print(f"\nDone: {len(results)} targets in {elapsed:.0f}s") + + _print_results(results) + _plot_results(results) + + +def 
_print_results(results: list[TargetResult]) -> None: + print(f"\n{'=' * 70}") + print("CANCELLATION DIAGNOSTIC: |mean| vs L2 ranking agreement") + print(f"{'=' * 70}") + print(f" {len(results)} targets, active positions only (CI > threshold)") + print() + + for label, metric, K in [ + ("Top-5 mean rank displacement", "top5_mrd", 5), + ("Top-5 overlap", "top5_overlap", 5), + ("Top-10 mean rank displacement", "top10_mrd", 10), + ("Top-10 overlap", "top10_overlap", 10), + ]: + vals = [getattr(r, metric) for r in results] + print( + f" {label}: {np.mean(vals):.1f} ± {np.std(vals):.1f}" + f" (median {np.median(vals):.1f})" + (f"/{K}" if "overlap" in metric else "") + ) + + print("\n By target layer:") + layers = sorted(set(r.target_layer for r in results)) + print( + f" {'layer':<18} {'n':>3} {'top5 mrd':>10} {'top5 olap':>10} " + f"{'top10 mrd':>10} {'top10 olap':>10}" + ) + print(f" {'-' * 65}") + for layer in layers: + lr = [r for r in results if r.target_layer == layer] + print( + f" {layer:<18} {len(lr):>3} " + f"{np.mean([r.top5_mrd for r in lr]):>6.1f}±{np.std([r.top5_mrd for r in lr]):<3.1f}" + f"{np.mean([r.top5_overlap for r in lr]):>7.1f}/5 " + f"{np.mean([r.top10_mrd for r in lr]):>6.1f}±{np.std([r.top10_mrd for r in lr]):<3.1f}" + f"{np.mean([r.top10_overlap for r in lr]):>7.1f}/10" + ) + + +def _plot_results(results: list[TargetResult]) -> None: + _, axes = plt.subplots(1, 2, figsize=(14, 5)) + + layers = sorted(set(r.target_layer for r in results)) + colors = {layer: f"C{i}" for i, layer in enumerate(layers)} + + for ax, K in [(axes[0], 5), (axes[1], 10)]: + for layer in layers: + vals = [r for r in results if r.target_layer == layer] + ax.hist( + [getattr(r, f"top{K}_mrd") for r in vals], + bins=np.arange(-0.5, 25.5, 1), + alpha=0.4, + color=colors[layer], + label=f"{layer} (μ={np.mean([getattr(r, f'top{K}_mrd') for r in vals]):.1f})", + ) + ax.set_xlabel(f"Top-{K} mean rank displacement") + ax.set_ylabel("# targets") + ax.set_title( + f"Top-{K}: |mean| 
vs L2 ranking agreement\n{len(results)} targets, active positions only" + ) + ax.legend(fontsize=8) + + plt.tight_layout() + out_path = SPD_OUT_DIR / "www" / "attr_cancellation_diagnostic.png" + out_path.parent.mkdir(parents=True, exist_ok=True) + plt.savefig(out_path, dpi=150) + print(f"\nSaved to {out_path}") + plt.close() + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run.py b/spd/dataset_attributions/scripts/run.py deleted file mode 100644 index 5d060767e..000000000 --- a/spd/dataset_attributions/scripts/run.py +++ /dev/null @@ -1,106 +0,0 @@ -"""Worker script for dataset attribution computation. - -Called by SLURM jobs submitted via spd-attributions, or run directly for non-SLURM environments. - -Usage: - # Single GPU - python -m spd.dataset_attributions.scripts.run --config_json '...' - - # Multi-GPU (run in parallel) - python -m spd.dataset_attributions.scripts.run --config_json '...' --rank 0 --world_size 4 --subrun_id da-20260211_120000 - ... 
- python -m spd.dataset_attributions.scripts.run --merge --subrun_id da-20260211_120000 -""" - -from datetime import datetime -from typing import Any - -from spd.dataset_attributions.config import DatasetAttributionConfig -from spd.dataset_attributions.harvest import harvest_attributions, merge_attributions -from spd.dataset_attributions.repo import get_attributions_subrun_dir -from spd.log import logger -from spd.utils.wandb_utils import parse_wandb_run_path - - -def main( - wandb_path: str, - config_json: dict[str, Any], - rank: int | None = None, - world_size: int | None = None, - merge: bool = False, - subrun_id: str | None = None, - harvest_subrun_id: str | None = None, -) -> None: - assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" - _, _, run_id = parse_wandb_run_path(wandb_path) - - if subrun_id is None: - subrun_id = "da-" + datetime.now().strftime("%Y%m%d_%H%M%S") - - output_dir = get_attributions_subrun_dir(run_id, subrun_id) - - if merge: - assert rank is None and world_size is None, "Cannot specify rank/world_size with --merge" - logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") - merge_attributions(output_dir) - return - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - - config = DatasetAttributionConfig.model_validate(config_json) - - if world_size is not None: - logger.info( - f"Distributed harvest: {wandb_path} (rank {rank}/{world_size}, subrun {subrun_id})" - ) - else: - logger.info(f"Single-GPU harvest: {wandb_path} (subrun {subrun_id})") - - harvest_attributions( - wandb_path=wandb_path, - config=config, - output_dir=output_dir, - harvest_subrun_id=harvest_subrun_id, - rank=rank, - world_size=world_size, - ) - - -def get_worker_command( - wandb_path: str, - config_json: str, - rank: int, - world_size: int, - subrun_id: str, - harvest_subrun_id: str | None = None, -) -> str: - cmd = ( - f"python -m 
spd.dataset_attributions.scripts.run " - f'"{wandb_path}" ' - f"--config_json '{config_json}' " - f"--rank {rank} " - f"--world_size {world_size} " - f"--subrun_id {subrun_id}" - ) - if harvest_subrun_id is not None: - cmd += f" --harvest_subrun_id {harvest_subrun_id}" - return cmd - - -def get_merge_command(wandb_path: str, subrun_id: str) -> str: - return ( - f"python -m spd.dataset_attributions.scripts.run " - f'"{wandb_path}" ' - "--merge " - f"--subrun_id {subrun_id}" - ) - - -def cli() -> None: - import fire - - fire.Fire(main) - - -if __name__ == "__main__": - cli() diff --git a/spd/dataset_attributions/scripts/run_merge.py b/spd/dataset_attributions/scripts/run_merge.py new file mode 100644 index 000000000..913ea5374 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_merge.py @@ -0,0 +1,37 @@ +"""Merge script for dataset attribution rank files. + +Combines per-rank attribution files into a single merged result. + +Usage: + python -m spd.dataset_attributions.scripts.run_merge --wandb_path --subrun_id da-xxx +""" + +from spd.dataset_attributions.harvest import merge_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.log import logger +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + *, + wandb_path: str, + subrun_id: str, +) -> None: + _, _, run_id = parse_wandb_run_path(wandb_path) + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") + merge_attributions(output_dir) + + +def get_command(wandb_path: str, subrun_id: str) -> str: + return ( + f"python -m spd.dataset_attributions.scripts.run_merge " + f'--wandb_path "{wandb_path}" ' + f"--subrun_id {subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 3fdba505e..6adc4bd52 100644 --- 
a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -14,7 +14,7 @@ from datetime import datetime from spd.dataset_attributions.config import AttributionsSlurmConfig -from spd.dataset_attributions.scripts import run as attribution_run +from spd.dataset_attributions.scripts import run_merge, run_worker from spd.log import logger from spd.utils.git_utils import create_git_snapshot from spd.utils.slurm import ( @@ -85,7 +85,7 @@ def submit_attributions( # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc. worker_commands = [] for rank in range(n_gpus): - cmd = attribution_run.get_worker_command( + cmd = run_worker.get_command( wandb_path, config_json, rank=rank, @@ -115,12 +115,13 @@ def submit_attributions( ) # Submit merge job with dependency on array completion - merge_cmd = attribution_run.get_merge_command(wandb_path, subrun_id) + merge_cmd = run_merge.get_command(wandb_path, subrun_id) merge_config = SlurmConfig( job_name="spd-attr-merge", partition=partition, - n_gpus=0, # No GPU needed for merge + n_gpus=0, time=config.merge_time, + mem=config.merge_mem, snapshot_branch=snapshot_branch, dependency_job_id=array_result.job_id, comment=wandb_url, diff --git a/spd/dataset_attributions/scripts/run_worker.py b/spd/dataset_attributions/scripts/run_worker.py new file mode 100644 index 000000000..4944b4160 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_worker.py @@ -0,0 +1,77 @@ +"""Worker script for dataset attribution computation. + +Called by SLURM jobs submitted via spd-attributions, or run directly for non-SLURM environments. 
+ +Usage: + # Single GPU + python -m spd.dataset_attributions.scripts.run_worker + + # Single GPU with config + python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 500}' + + # Multi-GPU (run in parallel) + python -m spd.dataset_attributions.scripts.run_worker --rank 0 --world_size 4 --subrun_id da-xxx +""" + +from datetime import datetime +from typing import Any + +from spd.dataset_attributions.config import DatasetAttributionConfig +from spd.dataset_attributions.harvest import harvest_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + wandb_path: str, + config_json: dict[str, Any], + rank: int, + world_size: int, + subrun_id: str | None = None, + harvest_subrun_id: str | None = None, +) -> None: + _, _, run_id = parse_wandb_run_path(wandb_path) + + if subrun_id is None: + subrun_id = "da-" + datetime.now().strftime("%Y%m%d_%H%M%S") + + base = config_json or {} + base.setdefault("spd_run_wandb_path", wandb_path) + if harvest_subrun_id is not None: + base.setdefault("harvest_subrun_id", harvest_subrun_id) + config = DatasetAttributionConfig.model_validate(base) + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + + harvest_attributions( + config=config, + output_dir=output_dir, + rank=rank, + world_size=world_size, + ) + + +def get_command( + wandb_path: str, + config_json: str, + rank: int, + world_size: int, + subrun_id: str, + harvest_subrun_id: str | None = None, +) -> str: + cmd = ( + f"python -m spd.dataset_attributions.scripts.run_worker " + f'"{wandb_path}" ' + f"--config_json '{config_json}' " + f"--rank {rank} " + f"--world_size {world_size} " + f"--subrun_id {subrun_id}" + ) + if harvest_subrun_id is not None: + cmd += f" --harvest_subrun_id {harvest_subrun_id}" + return cmd + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/storage.py 
b/spd/dataset_attributions/storage.py index 16181201d..3d62a1b4a 100644 --- a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -1,22 +1,37 @@ """Storage classes for dataset attributions. -Uses a residual-based storage approach for scalability: -- Component targets: stored directly in source_to_component matrix -- Output targets: stored as attributions to residual stream, computed on-the-fly via w_unembed +Stores raw (unnormalized) attribution sums. Normalization happens at query time using +stored metadata (CI sums, activation RMS, logit RMS). + +Four edge types, each with its own shape: +- regular: component → component [tgt_c, src_c] (signed + abs) +- embed: embed → component [tgt_c, vocab] (signed + abs) +- unembed: component → unembed [d_model, src_c] (signed only, residual space) +- embed_unembed: embed → unembed [d_model, vocab] (signed only, residual space) + +Abs variants are unavailable for unembed edges because abs is a nonlinear operation +incompatible with the residual-space storage trick. + +Normalization formula: + normed[t, s] = raw[t, s] / source_denom[s] / target_rms[t] +- source_denom is ci_sum[s] for component sources, embed_token_count[s] for embed sources +- target_rms is component activation RMS for component targets, logit RMS for output targets """ -import dataclasses -from collections.abc import Callable +import bisect from dataclasses import dataclass from pathlib import Path from typing import Literal import torch -from jaxtyping import Float from torch import Tensor from spd.log import logger +AttrMetric = Literal["attr", "attr_abs"] + +EPS = 1e-10 + @dataclass class DatasetAttributionEntry: @@ -28,318 +43,339 @@ class DatasetAttributionEntry: value: float -@dataclass class DatasetAttributionStorage: """Dataset-aggregated attribution strengths between components. 
- Uses residual-based storage for scalability with large vocabularies: - - source_to_component: direct attributions to component targets - - source_to_out_residual: attributions to output residual stream (for computing output attributions) - - Output attributions are computed on-the-fly: attr[src, output_token] = out_residual[src] @ w_unembed[:, token] + All layer names use canonical addressing (e.g., "embed", "0.glu.up", "output"). - Source indexing (rows): - - [0, vocab_size): wte tokens - - [vocab_size, vocab_size + n_components): component layers - - Target indexing: - - Component targets: [0, n_components) in source_to_component - - Output targets: computed via source_to_out_residual @ w_unembed + Internally stores raw sums — normalization applied at query time. + Public interface: get_top_sources(), get_top_targets(), save/load/merge. Key formats: - - wte tokens: "wte:{token_id}" - - component layers: "layer:c_idx" (e.g., "h.0.attn.q_proj:5") + - embed tokens: "embed:{token_id}" + - component layers: "canonical_layer:c_idx" (e.g., "0.glu.up:5") - output tokens: "output:{token_id}" """ - component_layer_keys: list[str] - """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" + def __init__( + self, + regular_attr: dict[str, dict[str, Tensor]], + regular_attr_abs: dict[str, dict[str, Tensor]], + embed_attr: dict[str, Tensor], + embed_attr_abs: dict[str, Tensor], + unembed_attr: dict[str, Tensor], + embed_unembed_attr: Tensor, + w_unembed: Tensor, + ci_sum: dict[str, Tensor], + component_act_sq_sum: dict[str, Tensor], + logit_sq_sum: Tensor, + embed_token_count: Tensor, + ci_threshold: float, + n_tokens_processed: int, + ): + self._regular_attr = regular_attr + self._regular_attr_abs = regular_attr_abs + self._embed_attr = embed_attr + self._embed_attr_abs = embed_attr_abs + self._unembed_attr = unembed_attr + self._embed_unembed_attr = embed_unembed_attr + self._w_unembed = w_unembed + self._ci_sum = ci_sum + 
self._component_act_sq_sum = component_act_sq_sum + self._logit_sq_sum = logit_sq_sum + self._embed_token_count = embed_token_count + self.ci_threshold = ci_threshold + self.n_tokens_processed = n_tokens_processed - vocab_size: int - """Vocabulary size (number of wte and output tokens)""" + @property + def target_layers(self) -> set[str]: + return self._regular_attr.keys() | self._embed_attr.keys() - d_model: int - """Model hidden dimension (residual stream size)""" + def _target_n_components(self, layer: str) -> int | None: + if layer in self._embed_attr: + return self._embed_attr[layer].shape[0] + if layer in self._regular_attr: + first_source = next(iter(self._regular_attr[layer].values())) + return first_source.shape[0] + return None - source_to_component: Float[Tensor, "n_sources n_components"] - """Attributions from sources to component targets. Shape: (vocab_size + n_components, n_components)""" + @property + def n_components(self) -> int: + total = 0 + for layer in self.target_layers: + n = self._target_n_components(layer) + assert n is not None + total += n + return total + + @staticmethod + def _parse_key(key: str) -> tuple[str, int]: + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) - source_to_out_residual: Float[Tensor, "n_sources d_model"] - """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, d_model)""" + def _select_metric( + self, metric: AttrMetric + ) -> tuple[dict[str, dict[str, Tensor]], dict[str, Tensor]]: + match metric: + case "attr": + return self._regular_attr, self._embed_attr + case "attr_abs": + return self._regular_attr_abs, self._embed_attr_abs - n_batches_processed: int - n_tokens_processed: int - ci_threshold: float + def _component_activation_rms(self, layer: str) -> Tensor: + """RMS activation for a component layer. 
Shape (n_components,).""" + return (self._component_act_sq_sum[layer] / self.n_tokens_processed).sqrt().clamp(min=EPS) - _component_key_to_idx: dict[str, int] = dataclasses.field( - default_factory=dict, repr=False, init=False - ) + def _logit_activation_rms(self) -> Tensor: + """RMS logit per token. Shape (vocab,).""" + return (self._logit_sq_sum / self.n_tokens_processed).sqrt().clamp(min=EPS) - def __post_init__(self) -> None: - self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} + def _layer_ci_sum(self, layer: str) -> Tensor: + """CI sum for a source layer, clamped. Shape (n_components,).""" + return self._ci_sum[layer].clamp(min=EPS) - n_components = len(self.component_layer_keys) - n_sources = self.vocab_size + n_components + def _embed_count(self) -> Tensor: + """Per-token occurrence count, clamped. Shape (vocab,).""" + return self._embed_token_count.float().clamp(min=EPS) - expected_comp_shape = (n_sources, n_components) - assert self.source_to_component.shape == expected_comp_shape, ( - f"source_to_component shape {self.source_to_component.shape} " - f"doesn't match expected {expected_comp_shape}" - ) + def get_top_sources( + self, + target_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + ) -> list[DatasetAttributionEntry]: + target_layer, target_idx = self._parse_key(target_key) + + value_segments: list[Tensor] = [] + layer_names: list[str] = [] + if target_layer == "embed": + return [] + + if target_layer == "output": + if metric == "attr_abs": + return [] + w = self._w_unembed[:, target_idx].to(self._embed_unembed_attr.device) + target_act_rms = self._logit_activation_rms()[target_idx] + + for source_layer, attr_matrix in self._unembed_attr.items(): + raw = w @ attr_matrix # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) + + raw = w @ self._embed_unembed_attr # (vocab,) + value_segments.append(raw / 
self._embed_count() / target_act_rms) + layer_names.append("embed") + else: + regular_attr, embed_target_attr = self._select_metric(metric) + target_act_rms = self._component_activation_rms(target_layer)[target_idx] - expected_resid_shape = (n_sources, self.d_model) - assert self.source_to_out_residual.shape == expected_resid_shape, ( - f"source_to_out_residual shape {self.source_to_out_residual.shape} " - f"doesn't match expected {expected_resid_shape}" - ) + if target_layer in regular_attr: + for source_layer, attr_matrix in regular_attr[target_layer].items(): + raw = attr_matrix[target_idx, :] # (src_c,) + value_segments.append(raw / self._layer_ci_sum(source_layer) / target_act_rms) + layer_names.append(source_layer) - @property - def n_components(self) -> int: - return len(self.component_layer_keys) + if target_layer in embed_target_attr: + raw = embed_target_attr[target_layer][target_idx, :] # (vocab,) + value_segments.append(raw / self._embed_count() / target_act_rms) + layer_names.append("embed") - @property - def n_sources(self) -> int: - return self.vocab_size + self.n_components + return self._top_k_from_segments(value_segments, layer_names, k, sign) - def _parse_key(self, key: str) -> tuple[str, int]: - """Parse a key into (layer, idx).""" - layer, idx_str = key.rsplit(":", 1) - return layer, int(idx_str) + def get_top_targets( + self, + source_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + include_outputs: bool = True, + ) -> list[DatasetAttributionEntry]: + source_layer, source_idx = self._parse_key(source_key) - def _source_idx(self, key: str) -> int: - """Get source (row) index for a key. 
Raises KeyError if not a valid source.""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - assert 0 <= idx < self.vocab_size, ( - f"wte index {idx} out of range [0, {self.vocab_size})" - ) - return idx - case "output": - raise KeyError(f"output tokens cannot be sources: {key}") - case _: - return self.vocab_size + self._component_key_to_idx[key] - - def _component_target_idx(self, key: str) -> int: - """Get target index for a component key. Raises KeyError if output or invalid.""" - if key.startswith(("wte:", "output:")): - raise KeyError(f"Not a component target: {key}") - return self._component_key_to_idx[key] - - def _source_idx_to_key(self, idx: int) -> str: - """Convert source (row) index to key.""" - if idx < self.vocab_size: - return f"wte:{idx}" - return self.component_layer_keys[idx - self.vocab_size] - - def _component_target_idx_to_key(self, idx: int) -> str: - """Convert component target index to key.""" - return self.component_layer_keys[idx] - - def _output_target_idx_to_key(self, idx: int) -> str: - """Convert output token index to key.""" - return f"output:{idx}" - - def _is_output_target(self, key: str) -> bool: - """Check if key is an output target.""" - return key.startswith("output:") - - def _output_token_id(self, key: str) -> int: - """Extract token_id from an output key like 'output:123'. 
Asserts valid range.""" - _, token_id = self._parse_key(key) - assert 0 <= token_id < self.vocab_size, f"output index {token_id} out of range" - return token_id - - def has_source(self, key: str) -> bool: - """Check if a key can be a source (wte token or component layer).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return 0 <= idx < self.vocab_size - case "output": - return False - case _: - return key in self._component_key_to_idx - - def has_target(self, key: str) -> bool: - """Check if a key can be a target (component layer or output token).""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - return False - case "output": - return 0 <= idx < self.vocab_size - case _: - return key in self._component_key_to_idx + value_segments: list[Tensor] = [] + layer_names: list[str] = [] - def save(self, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "component_layer_keys": self.component_layer_keys, - "vocab_size": self.vocab_size, - "d_model": self.d_model, - "source_to_component": self.source_to_component.cpu(), - "source_to_out_residual": self.source_to_out_residual.cpu(), - "n_batches_processed": self.n_batches_processed, - "n_tokens_processed": self.n_tokens_processed, - "ci_threshold": self.ci_threshold, - }, - path, - ) - size_mb = path.stat().st_size / (1024 * 1024) - logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") + if source_layer == "output": + return [] + elif source_layer == "embed": + regular, embed = self._select_metric(metric) + embed_count = self._embed_count()[source_idx] - @classmethod - def load(cls, path: Path) -> "DatasetAttributionStorage": - data = torch.load(path, weights_only=True, mmap=True) - return cls( - component_layer_keys=data["component_layer_keys"], - vocab_size=data["vocab_size"], - d_model=data["d_model"], - source_to_component=data["source_to_component"], - source_to_out_residual=data["source_to_out_residual"], - 
n_batches_processed=data["n_batches_processed"], - n_tokens_processed=data["n_tokens_processed"], - ci_threshold=data["ci_threshold"], - ) + for target_layer, attr_matrix in embed.items(): + raw = attr_matrix[:, source_idx] # (tgt_c,) + value_segments.append( + raw / embed_count / self._component_activation_rms(target_layer) + ) + layer_names.append(target_layer) - def get_attribution( - self, - source_key: str, - target_key: str, - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> float: - """Get attribution strength from source to target. - - Args: - source_key: Source component key (wte or component layer) - target_key: Target component key (component layer or output token) - w_unembed: Unembedding matrix, required if target is an output token - """ - src_idx = self._source_idx(source_key) + if include_outputs and metric == "attr": + residual = self._embed_unembed_attr[:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / embed_count / self._logit_activation_rms()) + layer_names.append("output") + else: + regular, embed = self._select_metric(metric) + ci = self._layer_ci_sum(source_layer)[source_idx] - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - return (self.source_to_out_residual[src_idx] @ w_unembed[:, token_id]).item() + for target_layer, sources in regular.items(): + if source_layer not in sources: + continue + raw = sources[source_layer][:, source_idx] # (tgt_c,) + value_segments.append(raw / ci / self._component_activation_rms(target_layer)) + layer_names.append(target_layer) - tgt_idx = self._component_target_idx(target_key) - return self.source_to_component[src_idx, tgt_idx].item() + if include_outputs and metric == "attr" and source_layer in self._unembed_attr: + residual = 
self._unembed_attr[source_layer][:, source_idx] # (d_model,) + raw = residual @ self._w_unembed # (vocab,) + value_segments.append(raw / ci / self._logit_activation_rms()) + layer_names.append("output") - def _get_top_k( + return self._top_k_from_segments(value_segments, layer_names, k, sign) + + def _top_k_from_segments( self, - values: Tensor, + value_segments: list[Tensor], + layer_names: list[str], k: int, sign: Literal["positive", "negative"], - idx_to_key: Callable[[int], str], ) -> list[DatasetAttributionEntry]: - """Get top-k entries from a 1D tensor of attribution values.""" + if not value_segments: + return [] + + all_values = torch.cat(value_segments) + offsets = [0] + for seg in value_segments: + offsets.append(offsets[-1] + len(seg)) + is_positive = sign == "positive" - top_vals, top_idxs = torch.topk(values, min(k, len(values)), largest=is_positive) + top_vals, top_idxs = torch.topk(all_values, min(k, len(all_values)), largest=is_positive) - # Filter to only values matching the requested sign mask = top_vals > 0 if is_positive else top_vals < 0 top_vals, top_idxs = top_vals[mask], top_idxs[mask] results = [] - for idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): - key = idx_to_key(idx) - layer, c_idx = self._parse_key(key) + for flat_idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): + seg_idx = bisect.bisect_right(offsets, flat_idx) - 1 + local_idx = flat_idx - offsets[seg_idx] + layer = layer_names[seg_idx] results.append( DatasetAttributionEntry( - component_key=key, + component_key=f"{layer}:{local_idx}", layer=layer, - component_idx=c_idx, + component_idx=local_idx, value=val, ) ) return results - def get_top_sources( - self, - target_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target. 
- - Args: - target_key: Target component key (component layer or output token) - k: Number of top sources to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if target is an output token - """ - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - values = self.source_to_out_residual @ w_unembed[:, token_id] # (n_sources,) + def get_attribution(self, source_key: str, target_key: str) -> float: + source_layer, source_idx = self._parse_key(source_key) + target_layer, target_idx = self._parse_key(target_key) + + if target_layer == "output" and source_layer == "embed": + return (self._embed_unembed_attr[:, source_idx] @ self._w_unembed[:, target_idx]).item() + elif target_layer == "output" and source_layer != "embed": + return ( + self._unembed_attr[source_layer][:, source_idx] @ self._w_unembed[:, target_idx] + ).item() + elif target_layer != "output" and source_layer == "embed": + return (self._embed_attr[target_layer][target_idx, source_idx]).item() else: - tgt_idx = self._component_target_idx(target_key) - values = self.source_to_component[:, tgt_idx] + assert target_layer != "output" and source_layer != "embed" + return (self._regular_attr[target_layer][source_layer][target_idx, source_idx]).item() - return self._get_top_k(values, k, sign, self._source_idx_to_key) + def save(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + torch.save( + { + "regular_attr": _to_cpu_nested(self._regular_attr), + "regular_attr_abs": _to_cpu_nested(self._regular_attr_abs), + "embed_attr": _to_cpu(self._embed_attr), + "embed_attr_abs": _to_cpu(self._embed_attr_abs), + "unembed_attr": _to_cpu(self._unembed_attr), + "embed_unembed_attr": self._embed_unembed_attr.detach().cpu(), + "w_unembed": 
self._w_unembed.detach().cpu(), + "ci_sum": _to_cpu(self._ci_sum), + "component_act_sq_sum": _to_cpu(self._component_act_sq_sum), + "logit_sq_sum": self._logit_sq_sum.detach().cpu(), + "embed_token_count": self._embed_token_count.detach().cpu(), + "ci_threshold": self.ci_threshold, + "n_tokens_processed": self.n_tokens_processed, + }, + path, + ) + size_mb = path.stat().st_size / (1024 * 1024) + logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") - def get_top_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - include_outputs: bool = True, - ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO. - - Args: - source_key: Source component key (wte or component layer) - k: Number of top targets to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if include_outputs=True - include_outputs: Whether to include output tokens in results + @classmethod + def load(cls, path: Path) -> "DatasetAttributionStorage": + data = torch.load(path, weights_only=True) + return cls( + regular_attr=data["regular_attr"], + regular_attr_abs=data["regular_attr_abs"], + embed_attr=data["embed_attr"], + embed_attr_abs=data["embed_attr_abs"], + unembed_attr=data["unembed_attr"], + embed_unembed_attr=data["embed_unembed_attr"], + w_unembed=data["w_unembed"], + ci_sum=data["ci_sum"], + component_act_sq_sum=data["component_act_sq_sum"], + logit_sq_sum=data["logit_sq_sum"], + embed_token_count=data["embed_token_count"], + ci_threshold=data["ci_threshold"], + n_tokens_processed=data["n_tokens_processed"], + ) + + @classmethod + def merge(cls, paths: list[Path]) -> "DatasetAttributionStorage": + """Merge partial attribution files from parallel workers. + + All stored values are raw sums — merge is element-wise addition. 
""" - src_idx = self._source_idx(source_key) - comp_values = self.source_to_component[src_idx, :] # (n_components,) + assert paths, "No files to merge" - if include_outputs: - assert w_unembed is not None, "w_unembed required when include_outputs=True" - # Compute attributions to all output tokens - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - all_values = torch.cat([comp_values, output_values]) + merged = cls.load(paths[0]) - def combined_idx_to_key(idx: int) -> str: - if idx < self.n_components: - return self._component_target_idx_to_key(idx) - return self._output_target_idx_to_key(idx - self.n_components) + for path in paths[1:]: + other = cls.load(path) + assert other.ci_threshold == merged.ci_threshold, "CI threshold mismatch" - return self._get_top_k(all_values, k, sign, combined_idx_to_key) + for target, sources in other._regular_attr.items(): + for source, tensor in sources.items(): + merged._regular_attr[target][source] += tensor + merged._regular_attr_abs[target][source] += other._regular_attr_abs[target][ + source + ] - return self._get_top_k(comp_values, k, sign, self._component_target_idx_to_key) + for target, tensor in other._embed_attr.items(): + merged._embed_attr[target] += tensor + merged._embed_attr_abs[target] += other._embed_attr_abs[target] - def get_top_component_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - ) -> list[DatasetAttributionEntry]: - """Get top-k component targets (excluding outputs) this source attributes TO. + for source, tensor in other._unembed_attr.items(): + merged._unembed_attr[source] += tensor - Convenience method that doesn't require w_unembed. 
- """ - return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) + merged._embed_unembed_attr += other._embed_unembed_attr - def get_top_output_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"], - ) -> list[DatasetAttributionEntry]: - """Get top-k output token targets this source attributes TO.""" - src_idx = self._source_idx(source_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) + for layer in other._ci_sum: + merged._ci_sum[layer] += other._ci_sum[layer] + + for layer in other._component_act_sq_sum: + merged._component_act_sq_sum[layer] += other._component_act_sq_sum[layer] + + merged._logit_sq_sum += other._logit_sq_sum + merged._embed_token_count += other._embed_token_count + merged.n_tokens_processed += other.n_tokens_processed + + return merged + + +def _to_cpu_nested(d: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + target: {source: v.detach().cpu() for source, v in sources.items()} + for target, sources in d.items() + } + + +def _to_cpu(d: dict[str, Tensor]) -> dict[str, Tensor]: + return {k: v.detach().cpu() for k, v in d.items()} diff --git a/spd/editing/README.md b/spd/editing/README.md new file mode 100644 index 000000000..2860a7650 --- /dev/null +++ b/spd/editing/README.md @@ -0,0 +1,95 @@ +# spd.editing + +Component-level model editing for VPD decompositions. 
+ +## Setup + +```python +from spd.editing import EditableModel, generate, measure_kl, measure_token_probs +from spd.harvest.repo import HarvestRepo +from spd.autointerp.repo import InterpRepo + +em, tok = EditableModel.from_wandb("wandb:goodfire/spd/s-892f140b") +harvest = HarvestRepo("s-892f140b") +interp = InterpRepo("s-892f140b") +``` + +## Finding components + +By autointerp label: +```python +from spd.editing import search_interpretations +matches = search_interpretations(harvest, interp, r"male pronoun") +# -> [ComponentMatch(key='h.1.attn.v_proj:52', label='male pronouns', ...)] +``` + +By output token PMI (best for ablation targets): +```python +from spd.editing import search_by_token_pmi +he_id = tok.encode("he") +matches = search_by_token_pmi(harvest, he_id, side="output", min_pmi=1.0) +``` + +By circuit optimization across examples: +```python +examples = [(tokens1, target_pos1), (tokens2, target_pos2), ...] +components = em.find_components_by_examples(examples, optim_steps=100) +# -> [('h.1.attn.v_proj:52', 0.9), ('h.1.mlp.down_proj:798', 0.8), ...] 
+``` + +## Inspecting components + +```python +from spd.editing import inspect_component +data = inspect_component(harvest, interp, "h.1.mlp.down_proj:798", tok) +# Prints: label, input/output PMI tokens, activation examples +``` + +Component geometry: +```python +vecs = em.get_component_vectors("h.1.mlp.down_proj:798") # read (V) and write (U) vectors +alignment = em.component_alignment("h.1.attn.o_proj:82", "h.1.mlp.c_fc:144") # cosine, percentile +boosted, suppressed = em.unembed_alignment("h.1.mlp.down_proj:798", tok) # top logit-lens tokens +``` + +## Editing (runtime masks) + +```python +# 0.0 = ablate, 2.0 = boost +edit_fn = em.make_edit_fn({"h.1.mlp.down_proj:798": 0.0, "h.1.attn.v_proj:52": 0.0}) + +# Generate with edits +text = generate(edit_fn, tokens, tok) + +# Measure effect +effect = measure_kl(em, edit_fn, eval_seqs) +print(f"KL={effect.mean_kl:.3f}, PPL: {effect.baseline_ppl:.1f} -> {effect.edited_ppl:.1f}") + +# Token group probability shifts +shifts = measure_token_probs(em, edit_fn, eval_seqs, { + "he": tok.encode("he"), + "she": tok.encode("she"), +}) +print(f"P(he) change: {shifts['he'].change_pct:+.1f}%") +``` + +CI-conditional editing (only edit where component is active): +```python +edit_fn = em.make_edit_fn({"h.1.mlp.down_proj:798": 0.0}, ci_threshold=0.1) +``` + +## Permanent weight editing + +```python +clean_em = em.without_components(["h.1.mlp.down_proj:798"]) +# Returns a new EditableModel with rank-1 subtraction baked into weights +text = generate(clean_em, tokens, tok) +``` + +## Circuit analysis + +```python +circuit = em.optimize_circuit(tokens, target_position=15, target_token=tok.encode("he")[0]) +em.print_circuit(circuit, tokens, tok, interp=interp) +# Prints: edges, node CI, component labels +``` diff --git a/spd/editing/__init__.py b/spd/editing/__init__.py new file mode 100644 index 000000000..4c71de2f4 --- /dev/null +++ b/spd/editing/__init__.py @@ -0,0 +1,46 @@ +"""Component-level model editing for VPD decompositions.""" + 
+# Re-export everything from the main module so `from spd.editing import ...` still works +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.compute import OptimizedPromptAttributionResult +from spd.editing._editing import ( + AblationEffect, + AlignmentResult, + ComponentMatch, + ComponentVectors, + EditableModel, + ForwardFn, + TokenGroupShift, + TokenPMIMatch, + UnembedMatch, + generate, + inspect_component, + measure_kl, + measure_token_probs, + parse_component_key, + search_by_token_pmi, + search_interpretations, +) +from spd.editing.component_trainer import ComponentTrainer + +__all__ = [ + "AblationEffect", + "AlignmentResult", + "AppTokenizer", + "ComponentTrainer", + "OptimizedPromptAttributionResult", + "ComponentMatch", + "ComponentVectors", + "EditableModel", + "ForwardFn", + "TokenGroupShift", + "TokenPMIMatch", + "UnembedMatch", + "generate", + "inspect_component", + "measure_kl", + "measure_token_probs", + "parse_component_key", + "search_by_token_pmi", + "search_interpretations", +] diff --git a/spd/editing/_editing.py b/spd/editing/_editing.py new file mode 100644 index 000000000..f34e384a4 --- /dev/null +++ b/spd/editing/_editing.py @@ -0,0 +1,808 @@ +"""Component-level model editing for VPD decompositions. + +Core class: EditableModel wraps ComponentModel + TransformerTopology and provides +methods for component analysis, editing, and measurement. It's callable +(tokens → logits) so it works as a ForwardFn anywhere. 
+ +Usage: + from spd.editing import EditableModel, search_interpretations, generate, measure_kl + + em, tokenizer = EditableModel.from_wandb("wandb:goodfire/spd/s-892f140b") + matches = search_interpretations(harvest, interp, r"male pronoun") + + edit_fn = em.make_edit_fn({m.key: 0.0 for m in matches[:3]}) + text = generate(edit_fn, tokens, tokenizer) + effect = measure_kl(em, edit_fn, token_seqs) +""" + +import copy +import re +import sqlite3 +from collections.abc import Callable +from dataclasses import dataclass + +import orjson +import torch +import torch.nn.functional as F +from jaxtyping import Float, Int +from torch import Tensor + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.app.backend.compute import OptimizedPromptAttributionResult +from spd.autointerp.repo import InterpRepo +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ComponentData +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.components import make_mask_infos +from spd.topology.topology import TransformerTopology + +ForwardFn = Callable[[Int[Tensor, " seq"]], Float[Tensor, "seq vocab"]] + + +# -- Component key utilities --------------------------------------------------- + + +def parse_component_key(key: str) -> tuple[str, int]: + """'h.1.mlp.c_fc:802' -> ('h.1.mlp.c_fc', 802).""" + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +# -- Search (free functions, don't need the model) ----------------------------- + + +@dataclass +class ComponentMatch: + key: str + label: str + confidence: str + firing_density: float + mean_activations: dict[str, float] + + +def search_interpretations( + harvest: HarvestRepo, + interp: InterpRepo, + pattern: str, + min_firing_density: float = 0.0, +) -> list[ComponentMatch]: + """Search component interpretations by regex on label. 
Sorted by firing density desc.""" + all_interps = interp.get_all_interpretations() + summary = harvest.get_summary() + + matches = [] + for key, result in all_interps.items(): + if key not in summary: + continue + if not re.search(pattern, result.label, re.IGNORECASE): + continue + s = summary[key] + if s.firing_density < min_firing_density: + continue + matches.append( + ComponentMatch( + key=key, + label=result.label, + confidence=result.confidence, + firing_density=s.firing_density, + mean_activations=s.mean_activations, + ) + ) + + matches.sort(key=lambda m: -m.firing_density) + return matches + + +@dataclass +class TokenPMIMatch: + key: str + pmi: float + firing_density: float + + +def search_by_token_pmi( + harvest: HarvestRepo, + token_ids: list[int], + side: str, + min_pmi: float = 0.5, + min_firing_density: float = 0.01, + top_k: int = 20, +) -> list[TokenPMIMatch]: + """Find components by input or output token PMI. + + side="output" finds components that PREDICT the given tokens. + side="input" finds components that RESPOND TO (fire on) the given tokens. + + For ablation, you almost always want side="output" — ablating output-side + components suppresses token production with far less collateral damage than + ablating input-side components. 
+ """ + assert side in ("input", "output") + column = "output_token_pmi" if side == "output" else "input_token_pmi" + target_set = set(token_ids) + summary = harvest.get_summary() + + db_path = harvest._dir / "harvest.db" + conn = sqlite3.connect(f"file:{db_path}?immutable=1", uri=True) + + results = [] + for row in conn.execute(f"SELECT component_key, {column} FROM components"): + key: str = row[0] + if key not in summary or summary[key].firing_density < min_firing_density: + continue + pmi_data: dict[str, list[list[float]]] = orjson.loads(row[1]) + max_pmi = 0.0 + for tok_id, pmi in pmi_data.get("top", []): + if int(tok_id) in target_set and pmi > max_pmi: + max_pmi = pmi + if max_pmi >= min_pmi: + results.append( + TokenPMIMatch( + key=key, + pmi=max_pmi, + firing_density=summary[key].firing_density, + ) + ) + + conn.close() + results.sort(key=lambda r: -r.pmi) + return results[:top_k] + + +def inspect_component( + harvest: HarvestRepo, + interp: InterpRepo, + key: str, + tokenizer: AppTokenizer, + n_examples: int = 5, + n_pmi_tokens: int = 10, +) -> ComponentData: + """Print a detailed inspection of a component and return its data.""" + comp = harvest.get_component(key) + assert comp is not None, f"No harvest data for {key}" + interp_result = interp.get_interpretation(key) + + ci = comp.mean_activations.get("causal_importance", None) + ci_str = f", ci={ci:.4f}" if ci is not None else "" + print(f"{'=' * 70}") + print(f"{key} (density={comp.firing_density:.4f}{ci_str})") + if interp_result: + print(f"Label: [{interp_result.confidence}] {interp_result.label}") + print() + + decode = tokenizer.decode + + print("INPUT tokens (what makes it fire):") + for tok_id, pmi in comp.input_token_pmi.top[:n_pmi_tokens]: + print(f" {decode([tok_id]):15s} PMI={pmi:.2f}") + + print("\nOUTPUT tokens (what it predicts):") + for tok_id, pmi in comp.output_token_pmi.top[:n_pmi_tokens]: + print(f" {decode([tok_id]):15s} PMI={pmi:.2f}") + + print(f"\nActivation examples 
({n_examples}):") + for ex in comp.activation_examples[:n_examples]: + parts = [] + for tid, firing in zip(ex.token_ids, ex.firings, strict=True): + tok_str = decode([tid]) + parts.append(f">>>{tok_str}<<<" if firing else tok_str) + act_vals = ex.activations.get("causal_importance", ex.activations.get("activation", [])) + max_act = max(act_vals) if act_vals else 0 + print(f" [max_act={max_act:.3f}] {''.join(parts)}") + print() + + return comp + + +# -- Result types -------------------------------------------------------------- + + +@dataclass +class ComponentVectors: + """Read (V) and write (U) vectors for a single rank-1 component. + + The component forward is: act = x @ read, out = act * write. + So `read` is the input direction (d_in) and `write` is the output direction (d_out). + """ + + key: str + read: Tensor + write: Tensor + d_in: int + d_out: int + + +@dataclass +class AlignmentResult: + cosine: float + dot: float + norm_a: float + norm_b: float + percentile: float + space_dim: int + space_name: str + + +@dataclass +class UnembedMatch: + token_id: int + token_str: str + cosine: float + dot: float + + +@dataclass +class AblationEffect: + mean_kl: float + baseline_ppl: float + edited_ppl: float + n_tokens: int + + @property + def ppl_increase_pct(self) -> float: + return (self.edited_ppl / self.baseline_ppl - 1) * 100 + + +@dataclass +class TokenGroupShift: + group_name: str + baseline_mean_prob: float + edited_mean_prob: float + n_positions: int + + @property + def change_pct(self) -> float: + if self.baseline_mean_prob == 0: + return float("inf") if self.edited_mean_prob > 0 else 0.0 + return (self.edited_mean_prob / self.baseline_mean_prob - 1) * 100 + + +# -- EditableModel ------------------------------------------------------------- + + +class EditableModel: + """ComponentModel + TransformerTopology with methods for editing and analysis. + + Callable: em(tokens) returns logits, so it works as a ForwardFn. 
+ """ + + def __init__(self, model: ComponentModel) -> None: + self.model = model + self.topology = TransformerTopology(model.target_model) + + @classmethod + def from_wandb( + cls, wandb_path: str, device: str = "cuda" + ) -> tuple["EditableModel", AppTokenizer]: + """Load from wandb path. Returns (editable_model, tokenizer).""" + run_info = SPDRunInfo.from_path(wandb_path) + model = ComponentModel.from_run_info(run_info).to(device).eval() + assert run_info.config.tokenizer_name is not None + tokenizer = AppTokenizer.from_pretrained(run_info.config.tokenizer_name) + return cls(model), tokenizer + + def __call__(self, tokens: Int[Tensor, " seq"]) -> Float[Tensor, "seq vocab"]: + return self.model(tokens.unsqueeze(0)).squeeze(0) + + # -- Component geometry ---------------------------------------------------- + + def get_component_vectors(self, key: str) -> ComponentVectors: + """Get the read (V[:, c]) and write (U[c, :]) vectors for a component.""" + layer, idx = parse_component_key(key) + comp = self.model.components[layer] + return ComponentVectors( + key=key, + read=comp.V[:, idx], + write=comp.U[idx, :], + d_in=int(comp.d_in), # pyright: ignore[reportArgumentType] + d_out=int(comp.d_out), # pyright: ignore[reportArgumentType] + ) + + def component_alignment(self, key_a: str, key_b: str) -> AlignmentResult: + """Cosine/dot between key_a's write direction and key_b's read direction. + + Asserts they share a space (key_a's d_out == key_b's d_in). + Percentile: % of (layer_a write, layer_b read) pairs with |cosine| below this pair's. 
+ """ + a = self.get_component_vectors(key_a) + b = self.get_component_vectors(key_b) + assert a.d_out == b.d_in, ( + f"{key_a} writes d={a.d_out}, {key_b} reads d={b.d_in} — no shared space" + ) + + cos = F.cosine_similarity(a.write.unsqueeze(0), b.read.unsqueeze(0)).item() + dot = (a.write * b.read).sum().item() + + layer_a, _ = parse_component_key(key_a) + layer_b, _ = parse_component_key(key_b) + all_writes = self.model.components[layer_a].U + all_reads = self.model.components[layer_b].V + all_cos = F.normalize(all_writes, dim=1) @ F.normalize(all_reads, dim=0) + percentile = (all_cos.abs() < abs(cos)).float().mean().item() * 100 + + resid_dim = self.topology.unembed_module.in_features + space_name = "residual" if a.d_out == resid_dim else "neuron" + + return AlignmentResult( + cosine=cos, + dot=dot, + norm_a=a.write.norm().item(), + norm_b=b.read.norm().item(), + percentile=percentile, + space_dim=a.d_out, + space_name=space_name, + ) + + def unembed_alignment( + self, + key: str, + tokenizer: AppTokenizer, + top_k: int = 10, + ) -> tuple[list[UnembedMatch], list[UnembedMatch]]: + """Top boosted and suppressed tokens by alignment with write direction. + + Only works for components that write to the residual stream. + Returns (top_boosted, top_suppressed). 
def get_component_activations(
    self,
    tokens: Int[Tensor, " seq"],
    key: str,
) -> Float[Tensor, " seq"]:
    """Project the layer input onto the component's read vector (v_c^T @ x) per position."""
    layer_name, comp_idx = parse_component_key(key)
    batched_tokens = tokens.unsqueeze(0)

    # Only the forward pass that fills the input cache runs without autograd.
    with torch.no_grad():
        cached = self.model(batched_tokens, cache_type="input")

    read_vec = self.model.components[layer_name].V[:, comp_idx]
    layer_input = cached.cache[layer_name]  # [1, seq, d_in]
    return torch.matmul(layer_input, read_vec).squeeze(0)  # [seq]
def find_components_by_examples(
    self,
    examples: list[tuple[Int[Tensor, " seq"], int]],
    optim_steps: int = 100,
    context_window: int = 10,
    ci_alive_threshold: float = 0.0,
    min_frequency: float = 0.7,
    top_k: int = 20,
) -> list[tuple[str, float]]:
    """Find components needed for a behavior by optimizing sparse CI on examples.

    For each (token_sequence, target_position) pair, runs CI optimization
    to find a sparse set of components needed to predict the token at
    target_position. Components that are active in that set for a fraction
    of examples >= min_frequency are returned.

    Args:
        examples: List of (token_sequence, target_position) pairs.
            target_position is the sequence index of the token whose
            prediction we want to explain; it must be > 0.
        optim_steps: Number of optimization steps per example.
        context_window: Maximum number of tokens kept per example, ending at
            (and including) target_position. Must be >= 2 so that the window
            always contains the position that predicts the target token.
        ci_alive_threshold: CI threshold for considering a component "active"
            in the optimized mask.
        min_frequency: Fraction of examples where a component must be active.
        top_k: Number of components to return.

    Returns:
        List of (component_key, frequency) sorted by frequency descending.
    """
    from spd.app.backend.optim_cis import (
        CELossConfig,
        OptimCIConfig,
        optimize_ci_values,
    )
    from spd.configs import ImportanceMinimalityLossConfig

    # With a window of one token the target sits at window index 0 and the
    # predicting position would be -1, which silently wraps to the last index.
    assert context_window >= 2, "context_window must be >= 2 (target plus a previous position)"

    counts: dict[str, int] = {}
    n_examples = len(examples)

    for i, (tokens, target_pos) in enumerate(examples):
        assert target_pos > 0, "target_position must be > 0 (need a previous position)"

        # Truncate to context window ending at target_pos (inclusive)
        start = max(0, target_pos - context_window + 1)
        window = tokens[start : target_pos + 1]
        window_target_pos = target_pos - start
        target_token = window[window_target_pos].item()
        # Logits at position t-1 predict the token at position t.
        pred_pos = window_target_pos - 1
        device = str(tokens.device)

        config = OptimCIConfig(
            seed=42,
            lr=0.1,
            steps=optim_steps,
            weight_decay=0.0,
            lr_schedule="cosine",
            lr_exponential_halflife=None,
            lr_warmup_pct=0.1,
            log_freq=optim_steps + 1,  # suppress logging
            imp_min_config=ImportanceMinimalityLossConfig(coeff=0.1, pnorm=0.5, beta=1.0),
            loss_config=CELossConfig(
                coeff=20.0,
                position=pred_pos,
                label_token=int(target_token),
            ),
            sampling="continuous",
            ce_kl_rounding_threshold=0.5,
            mask_type="ci",
            adv_pgd=None,
        )

        result = optimize_ci_values(
            model=self.model,
            tokens=window.unsqueeze(0),
            config=config,
            device=device,
        )

        # Extract active components from optimized CI
        ci_outputs = result.params.create_ci_outputs(self.model, device)
        for layer_name, ci_vals in ci_outputs.lower_leaky.items():
            # ci_vals: [1, window_len, C]
            active = ci_vals[0, pred_pos, :] > ci_alive_threshold
            for c in active.nonzero(as_tuple=True)[0]:
                key = f"{layer_name}:{c.item()}"
                counts[key] = counts.get(key, 0) + 1

        print(f" Example {i + 1}/{n_examples}: L0={result.metrics.l0_total:.0f}")

    # Compare the frequency itself: int(min_frequency * n_examples) rounds the
    # cut-off down and lets through components below min_frequency.
    # `counts` is empty when there are no examples, so no division by zero occurs.
    freq_results = [
        (key, count / n_examples)
        for key, count in counts.items()
        if count / n_examples >= min_frequency
    ]
    freq_results.sort(key=lambda x: -x[1])
    return freq_results[:top_k]
def print_circuit(
    self,
    circuit: OptimizedPromptAttributionResult,
    tokens: Int[Tensor, " seq"],
    tok: AppTokenizer,
    interp: "InterpRepo | None" = None,
    top_edges: int = 5,
    min_ci: float = 0.0,
) -> None:
    """Print a readable summary of an optimized circuit to stdout.

    Edges are grouped by target node. A target is listed only when its CI
    value is strictly above ``min_ci``; for each listed target the
    ``top_edges`` incoming edges with the largest |strength| are shown.
    """
    spans = tok.get_spans(tokens.tolist())

    def describe(node_key: str) -> str:
        # The last two ":"-separated fields are the sequence position and the
        # component index; everything before them is the layer name.
        pieces = node_key.split(":")
        layer = ":".join(pieces[:-2])
        seq = int(pieces[-2])
        cidx = int(pieces[-1])

        suffix = ""
        if interp is not None:
            found = interp.get_interpretation(f"{layer}:{cidx}")
            if found:
                suffix = f" [{found.label[:35]}]"
        return f"{layer}:{cidx}@{spans[seq].strip()}(p{seq}){suffix}"

    # target key -> [(source key, strength, is_cross_seq), ...]
    incoming_by_target: dict[str, list[tuple[str, float, bool]]] = {}
    for edge in circuit.edges:
        incoming_by_target.setdefault(str(edge.target), []).append(
            (str(edge.source), edge.strength, edge.is_cross_seq)
        )

    print(f"Circuit: {len(circuit.edges)} edges, L0={circuit.metrics.l0_total:.0f}")
    print(f"Tokens: {list(enumerate(spans))}\n")

    for target_key in sorted(incoming_by_target):
        target_ci = circuit.node_ci_vals.get(target_key, 0)
        if target_ci <= min_ci:
            continue

        strongest_first = sorted(incoming_by_target[target_key], key=lambda item: -abs(item[1]))

        print(f"{describe(target_key)} ci={target_ci:.3f}")
        for source_key, strength, is_cross in strongest_first[:top_edges]:
            marker = " [x-seq]" if is_cross else ""
            print(f" <- {describe(source_key)} attr={strength:+.4f}{marker}")
        print()
def make_edit_fn(
    self,
    edits: dict[str, float],
    ci_threshold: float | None = None,
) -> ForwardFn:
    """Build a reusable unbatched forward function: tokens [seq] → logits [seq, vocab].

    Without ``ci_threshold`` the edits are applied at every position. With a
    threshold (any value other than None, including 0.0) the CI-guided
    forward is used instead.
    """

    def edited_everywhere(tokens):
        batched_logits = self._edited_forward_batched(tokens.unsqueeze(0), edits)
        return batched_logits.squeeze(0)

    def edited_where_ci_high(tokens):
        batched_logits = self._ci_guided_forward_batched(tokens.unsqueeze(0), edits, ci_threshold)
        return batched_logits.squeeze(0)

    if ci_threshold is None:
        return edited_everywhere
    return edited_where_ci_high
def generate(
    forward_fn: ForwardFn,
    tokens: Int[Tensor, " seq"],
    tokenizer: AppTokenizer,
    max_new_tokens: int = 30,
    temperature: float = 0.0,
) -> str:
    """Greedy (temperature=0) or sampled generation from an arbitrary forward function.

    Takes unbatched tokens [seq]. Strips trailing EOS to avoid the model
    treating the prompt as complete. Generation stops after ``max_new_tokens``
    tokens or as soon as EOS is produced (the EOS token is kept in the output).

    Args:
        forward_fn: Maps tokens [seq] to logits [seq, vocab].
        tokens: Prompt token ids. May be empty, in which case ``forward_fn``
            is called on an empty sequence and must cope with that.
        tokenizer: Provides ``eos_token_id`` and ``decode``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: 0 for argmax decoding, > 0 for softmax sampling.

    Returns:
        The decoded prompt plus continuation.
    """
    assert temperature >= 0, f"temperature must be >= 0, got {temperature}"

    eos_id = tokenizer.eos_token_id
    # Guard the index: an empty prompt has no last token to inspect.
    if tokens.numel() > 0 and tokens[-1].item() == eos_id:
        tokens = tokens[:-1]

    generated = tokens.clone()
    # Generation never needs gradients; without this every step would keep an
    # autograd graph alive for the whole forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            next_logits = forward_fn(generated)[-1]
            if temperature == 0:
                next_id = next_logits.argmax()
            else:
                probs = F.softmax(next_logits / temperature, dim=-1)
                next_id = torch.multinomial(probs, 1).squeeze()
            generated = torch.cat([generated, next_id.unsqueeze(0)])
            if next_id.item() == eos_id:
                break
    return tokenizer.decode(generated.tolist())
def measure_token_probs(
    baseline_fn: ForwardFn,
    edited_fn: ForwardFn,
    token_seqs: list[Int[Tensor, " seq"]],
    token_groups: dict[str, list[int]],
) -> dict[str, TokenGroupShift]:
    """Compare the probability mass of named token-id groups under two forward functions.

    Each sequence in ``token_seqs`` is unbatched ([seq]). For every group the
    softmax probability of its token ids is summed over all positions of all
    sequences and divided by the total number of positions.
    """
    group_names = list(token_groups)
    base_mass = dict.fromkeys(group_names, 0.0)
    edit_mass = dict.fromkeys(group_names, 0.0)
    n_positions = 0

    for seq in token_seqs:
        with torch.no_grad():
            base_logits = baseline_fn(seq)
            edit_logits = edited_fn(seq)

        base_probs = F.softmax(base_logits, dim=-1)
        edit_probs = F.softmax(edit_logits, dim=-1)
        n_positions += base_probs.shape[0]

        for group, member_ids in token_groups.items():
            base_mass[group] += base_probs[:, member_ids].sum().item()
            edit_mass[group] += edit_probs[:, member_ids].sum().item()

    assert n_positions > 0

    shifts = {}
    for group in group_names:
        shifts[group] = TokenGroupShift(
            group_name=group,
            baseline_mean_prob=base_mass[group] / n_positions,
            edited_mean_prob=edit_mass[group] / n_positions,
            n_positions=n_positions,
        )
    return shifts
class ComponentTrainer:
    """Trains specific component U/V vectors while the rest of the model is frozen.

    Forward pass runs through all components with ones masks + frozen weight delta,
    so the model starts from target-model behavior. Only the specified V columns
    (read vectors) and/or U rows (write vectors) are updated.

    Weight decay is applied by this class, restricted to the selected
    columns/rows, and the optimizer itself is created with ``weight_decay=0``.
    AdamW's own decoupled decay multiplies the whole parameter tensor, which
    would also shrink the components that are supposed to stay frozen.
    """

    def __init__(
        self,
        model: ComponentModel,
        targets: dict[str, TrainMode],
        lr: float,
        weight_decay: float = 0.0,
    ):
        """Freeze ``model`` and make the targeted read/write vectors trainable.

        Args:
            model: Component model to train in place; put into train mode.
            targets: Maps component keys (parsed by ``parse_component_key``)
                to "read", "write" or "both".
            lr: Learning rate for AdamW.
            weight_decay: Decoupled weight decay, applied only to the
                targeted columns/rows.
        """
        self.model = model
        self.model.train()

        # Snapshot weight deltas BEFORE changing requires_grad
        self._frozen_weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] = {
            k: v.detach().clone() for k, v in model.calc_weight_deltas().items()
        }

        # Freeze everything
        model.requires_grad_(False)

        # Parse targets into per-layer specs
        layer_specs: dict[str, dict[int, TrainMode]] = {}
        for key, mode in targets.items():
            layer, idx = parse_component_key(key)
            assert layer in model.components, f"Unknown layer: {layer}"
            n_components = model.components[layer].C
            # A negative index would silently address a component from the end.
            assert 0 <= idx < n_components, (
                f"Component index {idx} out of range [0, {n_components}) for {layer}"
            )
            layer_specs.setdefault(layer, {})[idx] = mode

        # Unfreeze relevant V/U params and register gradient masks
        self._grad_hooks: list[RemovableHandle] = []
        # (parameter, mask broadcastable to the parameter) pairs used for masked decay
        self._decay_targets: list[tuple[Tensor, Tensor]] = []
        self._weight_decay = weight_decay
        trainable_params: list[Tensor] = []

        for layer, specs in layer_specs.items():
            comp = model.components[layer]

            read_idxs = [i for i, m in specs.items() if m in ("read", "both")]
            write_idxs = [i for i, m in specs.items() if m in ("write", "both")]

            if read_idxs:
                comp.V.requires_grad = True
                # Mask in the parameter's dtype so the masked gradient keeps that dtype.
                v_mask = torch.zeros(comp.C, device=comp.V.device, dtype=comp.V.dtype)
                v_mask[read_idxs] = 1.0
                # Mask: [C] broadcast to [d_in, C] — zeros out gradients for non-target columns
                v_mask = v_mask.unsqueeze(0)
                hook = comp.V.register_hook(lambda g, m=v_mask: g * m)
                self._grad_hooks.append(hook)
                self._decay_targets.append((comp.V, v_mask))
                trainable_params.append(comp.V)

            if write_idxs:
                comp.U.requires_grad = True
                u_mask = torch.zeros(comp.C, device=comp.U.device, dtype=comp.U.dtype)
                u_mask[write_idxs] = 1.0
                # Mask: [C] broadcast to [C, d_out] — zeros out gradients for non-target rows
                u_mask = u_mask.unsqueeze(1)
                hook = comp.U.register_hook(lambda g, m=u_mask: g * m)
                self._grad_hooks.append(hook)
                self._decay_targets.append((comp.U, u_mask))
                trainable_params.append(comp.U)

        assert trainable_params, "No trainable parameters"
        # weight_decay=0 here on purpose; see _apply_masked_weight_decay.
        self.optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=0.0)

    def __call__(self, *args: Any, **kwargs: Any) -> Tensor:
        return self.forward(*args, **kwargs)

    def forward(self, *args: Any, **kwargs: Any) -> Tensor:
        """Forward pass with all-ones masks and frozen weight deltas.

        Accepts the same arguments as the target model. Hooks intercept each
        decomposed layer and route through components + weight delta.
        """
        hooks = {
            module_name: partial(self._component_hook, module_name=module_name)
            for module_name in self.model.target_module_paths
        }
        with self.model._attach_forward_hooks(hooks):
            raw_out = self.model.target_model(*args, **kwargs)
        return self.model._extract_output(raw_out)

    def _component_hook(
        self,
        _module: Any,
        args: list[Any],
        kwargs: dict[Any, Any],
        _output: Any,
        module_name: str,
    ) -> Tensor:
        """Replace a decomposed layer's output with components + frozen weight delta."""
        assert len(args) == 1 and len(kwargs) == 0
        x = args[0]
        components = self.model.components[module_name]

        # Embedding inputs carry no trailing feature dimension; other inputs do.
        batch_shape = x.shape if isinstance(components, EmbeddingComponents) else x.shape[:-1]

        weight_delta = self._frozen_weight_deltas[module_name].to(x.device)
        weight_delta_mask = torch.ones(batch_shape, device=x.device)

        return components(
            x,
            mask=None,
            weight_delta_and_mask=(weight_delta, weight_delta_mask),
        )

    def _apply_masked_weight_decay(self) -> None:
        """Decoupled weight decay restricted to the trainable columns/rows.

        Mirrors AdamW's ``param *= 1 - lr * weight_decay`` for masked-in
        entries and leaves masked-out entries untouched.
        """
        lr = self.optimizer.param_groups[0]["lr"]
        with torch.no_grad():
            for param, mask in self._decay_targets:
                if param.grad is None:
                    # AdamW also skips parameters that received no gradient.
                    continue
                param.mul_(1.0 - lr * self._weight_decay * mask)

    def step(self, loss: Tensor) -> None:
        """Backward + (masked) weight decay + optimizer step."""
        self.optimizer.zero_grad()
        loss.backward()
        if self._weight_decay != 0.0:
            self._apply_masked_weight_decay()
        self.optimizer.step()

    def cleanup(self) -> None:
        """Remove gradient hooks and re-freeze parameters."""
        for hook in self._grad_hooks:
            hook.remove()
        self._grad_hooks.clear()
        self.model.requires_grad_(False)
        self.model.eval()
+"""Generate per-token divergence data for the token divergence visualisation. + +Runs forward passes on dataset text under named component ablations, +computes KL, reverse KL, JSD, and CE diff per token, writes JSON. + +Usage: + python -m spd.editing.generate_token_divergence \\ + wandb:goodfire/spd/s-892f140b \\ + --edits edits.yaml \\ + --n_tokens 1500 \\ + --out_path /path/to/www/data/kl_tokens.json + +edits.yaml format: + Male pronouns: + - h.1.mlp.down_proj:798 + - h.1.mlp.c_fc:144 + - h.1.attn.o_proj:82 + Question marks: + - h.1.mlp.down_proj:534 +""" + +import json +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +import yaml +from datasets import load_dataset + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.editing import EditableModel, ForwardFn +from spd.settings import SPD_OUT_DIR + +TokenData = dict[str, Any] + + +def compute_token_divergence( + em: EditableModel, + edit_fn: ForwardFn, + token_ids: list[int], + tok: AppTokenizer, + top_k: int = 5, +) -> list[TokenData]: + tokens = torch.tensor(token_ids, device="cuda") + spans = tok.get_spans(token_ids) + + with torch.no_grad(): + bl_logits = em(tokens) + ed_logits = edit_fn(tokens) + + bl_probs = F.softmax(bl_logits, dim=-1) + ed_probs = F.softmax(ed_logits, dim=-1) + bl_lp = F.log_softmax(bl_logits, dim=-1) + ed_lp = F.log_softmax(ed_logits, dim=-1) + + # All metrics at positions [0..seq-2], predicting tokens [1..seq-1] + fwd_kl_per_vocab = bl_probs[:-1] * (bl_lp[:-1] - ed_lp[:-1]) + fwd_kl = fwd_kl_per_vocab.sum(dim=-1) + rev_kl = (ed_probs[:-1] * (ed_lp[:-1] - bl_lp[:-1])).sum(dim=-1) + + m_probs = 0.5 * (bl_probs[:-1] + ed_probs[:-1]) + m_lp = m_probs.log() + jsd = 0.5 * (bl_probs[:-1] * (bl_lp[:-1] - m_lp)).sum(-1) + 0.5 * ( + ed_probs[:-1] * (ed_lp[:-1] - m_lp) + ).sum(-1) + + targets = tokens[1:] + ce_diff = -ed_lp[:-1][range(len(targets)), targets] - ( + -bl_lp[:-1][range(len(targets)), targets] + ) + + result: 
list[TokenData] = [] + for i in range(len(tokens)): + if i == 0: + result.append( + {"s": spans[i], "kl": 0, "rkl": 0, "jsd": 0, "ce": 0, "bl": [], "ed": [], "kc": []} + ) + continue + + prev = i - 1 + bl_top_v, bl_top_i = bl_probs[prev].topk(top_k) + ed_top_v, ed_top_i = ed_probs[prev].topk(top_k) + + bl_top = [ + [tok.decode([int(t)]), round(v.item(), 4)] + for v, t in zip(bl_top_v, bl_top_i, strict=True) + ] + ed_top = [ + [tok.decode([int(t)]), round(v.item(), 4)] + for v, t in zip(ed_top_v, ed_top_i, strict=True) + ] + + kl_contribs = fwd_kl_per_vocab[prev] + _, kl_top_i = kl_contribs.abs().topk(top_k) + kl_top = [ + [ + tok.decode([int(idx)]), + round(bl_probs[prev, idx].item(), 4), + round(ed_probs[prev, idx].item(), 4), + round(kl_contribs[idx].item(), 5), + ] + for idx in kl_top_i + ] + + result.append( + { + "s": spans[i], + "kl": round(fwd_kl[prev].item(), 5), + "rkl": round(rev_kl[prev].item(), 5), + "jsd": round(jsd[prev].item(), 5), + "ce": round(ce_diff[prev].item(), 5), + "bl": bl_top, + "ed": ed_top, + "kc": kl_top, + } + ) + + return result + + +def load_stories(n_tokens: int, max_seq_len: int = 300) -> list[list[int]]: + """Load stories from SimpleStories until we have >= n_tokens.""" + ds = load_dataset("SimpleStories/SimpleStories", split="train", streaming=True) + tok = AppTokenizer.from_pretrained("goodfire/SimpleStories-Llama-tokenizer") + stories = [] + total = 0 + for item in ds: + token_ids = tok.encode(item["story"]) + if len(token_ids) > max_seq_len: + token_ids = token_ids[:max_seq_len] + stories.append(token_ids) + total += len(token_ids) + if total >= n_tokens: + break + return stories + + +def main( + wandb_path: str, + edits: str, + n_tokens: int = 1500, + out_path: str | None = None, +) -> None: + edits_path = Path(edits) + assert edits_path.exists(), f"Edits file not found: {edits_path}" + with open(edits_path) as f: + edits_config: dict[str, list[str]] = yaml.safe_load(f) + + if out_path is None: + out_path = str(SPD_OUT_DIR / 
def p99(vals: list[float]) -> float:
    """Return the value at index ``int(0.99 * n)`` of the sorted input.

    Returns 0.0 for an empty input so that writing the output file does not
    fail with an IndexError after all forward passes have been computed.
    NOTE(review): confirm that consumers of the "*_max" scales tolerate 0.
    """
    if not vals:
        return 0.0
    s = sorted(vals)
    return s[int(0.99 * len(s))]
new file mode 100644 index 000000000..327db0e7c --- /dev/null +++ b/spd/graph_interp/CLAUDE.md @@ -0,0 +1,71 @@ +# Graph Interpretation Module + +Context-aware component labeling using network graph structure. Unlike standard autointerp (one-shot per component), this module uses dataset attributions to provide graph context: each component's prompt includes labels from already-labeled components connected via the attribution graph. + +## Usage + +```bash +# Via SLURM (standalone) +spd-graph-interp --config config.yaml + +# Direct execution +python -m spd.graph_interp.scripts.run --config_json '{...}' +``` + +Requires `OPENROUTER_API_KEY` env var. Requires both harvest data and dataset attributions to exist. + +## Three-Phase Pipeline + +1. **Output pass** (late → early): "What does this component DO?" Each component's prompt includes top-K downstream components (by attribution) with their labels. Late layers labeled first so earlier layers see labeled downstream context. + +2. **Input pass** (early → late): "What TRIGGERS this component?" Each component's prompt includes top-K upstream components (by attribution) + co-firing components (Jaccard/PMI). Early layers labeled first so later layers see labeled upstream context. Independent of the output pass. + +3. **Unification** (parallel): Synthesizes output + input labels into a single unified label per component. + +All three phases run in a single invocation. Resume is per-phase via completed key sets in the DB. 
+ +## Data Storage + +``` +SPD_OUT_DIR/graph_interp// +└── ti-YYYYMMDD_HHMMSS/ + ├── interp.db # SQLite: output_labels, input_labels, unified_labels, prompt_edges + └── config.yaml +``` + +## Database Schema + +- `output_labels`: component_key → label, confidence, reasoning, raw_response, prompt +- `input_labels`: same schema as output_labels +- `unified_labels`: same schema as output_labels +- `prompt_edges`: directed filtered graph of (component, related_key, pass, attribution, related_label) +- `config`: key-value store + +## Architecture + +| File | Purpose | +|------|---------| +| `config.py` | `GraphInterpConfig`, `GraphInterpSlurmConfig` | +| `schemas.py` | `LabelResult`, `PromptEdge`, path helpers | +| `db.py` | `GraphInterpDB` — SQLite with WAL mode | +| `ordering.py` | Topological sort via `CanonicalWeight` from topology module | +| `graph_context.py` | `RelatedComponent`, gather attributed + co-firing components | +| `prompts.py` | Three prompt formatters (output, input, unification) | +| `interpret.py` | Main three-phase execution loop | +| `repo.py` | `GraphInterpRepo` — read-only access to results | +| `scripts/run.py` | CLI entry point (called by SLURM) | +| `scripts/run_slurm.py` | SLURM submission | +| `scripts/run_slurm_cli.py` | Thin CLI wrapper for `spd-graph-interp` | + +## Dependencies + +- Harvest data (component stats, correlations, token stats) +- Dataset attributions (component-to-component attribution strengths) +- Reuses `map_llm_calls` from `spd/autointerp/llm_api.py` +- Reuses prompt helpers from `spd/autointerp/prompt_helpers.py` + +## SLURM Integration + +- 0 GPUs, 16 CPUs, 240GB memory (CPU-only, LLM API calls) +- Depends on both harvest merge AND attribution merge jobs +- Entry point: `spd-graph-interp` diff --git a/spd/graph_interp/__init__.py b/spd/graph_interp/__init__.py new file mode 100644 index 000000000..61e182fda --- /dev/null +++ b/spd/graph_interp/__init__.py @@ -0,0 +1 @@ +"""Graph interpretation: context-aware 
class GraphInterpConfig(BaseConfig):
    # Settings for one graph-interpretation labeling run.

    # Identifier of the LLM used for labeling.
    model: str = "google/gemini-3-flash-preview"
    # Reasoning effort requested from the LLM (type comes from openrouter.components).
    reasoning_effort: Effort = "low"
    # Which dataset-attribution metric to use (type comes from spd.dataset_attributions.storage).
    attr_metric: AttrMetric = "attr_abs"
    # NOTE(review): presumably the number of top attributed components included per prompt — confirm.
    top_k_attributed: int = 8
    # NOTE(review): presumably the maximum number of examples shown per component — confirm.
    max_examples: int = 20
    # NOTE(review): presumably the word limit requested for each label — confirm.
    label_max_words: int = 8
    # Optional spending cap in USD; None means no cap is configured.
    cost_limit_usd: float | None = None
    # Request rate and concurrency limits for the LLM calls.
    max_requests_per_minute: int = 500
    max_concurrent: int = 50
    # NOTE(review): presumably caps how many components are processed; None = no cap — confirm.
    limit: int | None = None
+CREATE TABLE IF NOT EXISTS prompt_edges ( + component_key TEXT NOT NULL, + related_key TEXT NOT NULL, + pass TEXT NOT NULL, + attribution REAL NOT NULL, + related_label TEXT, + related_confidence TEXT, + PRIMARY KEY (component_key, related_key, pass) +); + +CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL +); +""" + + +class GraphInterpDB: + def __init__(self, db_path: Path, readonly: bool = False) -> None: + if readonly: + self._conn = sqlite3.connect( + f"file:{db_path}?immutable=1", uri=True, check_same_thread=False + ) + else: + self._conn = sqlite3.connect(str(db_path), check_same_thread=False) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.executescript(_SCHEMA) + self._db_path = db_path + self._conn.row_factory = sqlite3.Row + + def mark_done(self) -> None: + (self._db_path.parent / DONE_MARKER).touch() + + # -- Output labels --------------------------------------------------------- + + def save_output_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO output_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_output_label(self, component_key: str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM output_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_output_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM output_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_output_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM output_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Input labels ---------------------------------------------------------- + + def 
save_input_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO input_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_input_label(self, component_key: str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM input_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_input_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM input_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_input_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM input_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- Unified labels -------------------------------------------------------- + + def save_unified_label(self, result: LabelResult) -> None: + self._conn.execute( + "INSERT OR REPLACE INTO unified_labels VALUES (?, ?, ?, ?, ?, ?)", + ( + result.component_key, + result.label, + result.confidence, + result.reasoning, + result.raw_response, + result.prompt, + ), + ) + self._conn.commit() + + def get_unified_label(self, component_key: str) -> LabelResult | None: + row = self._conn.execute( + "SELECT * FROM unified_labels WHERE component_key = ?", (component_key,) + ).fetchone() + if row is None: + return None + return _row_to_label_result(row) + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + rows = self._conn.execute("SELECT * FROM unified_labels").fetchall() + return {row["component_key"]: _row_to_label_result(row) for row in rows} + + def get_completed_unified_keys(self) -> set[str]: + rows = self._conn.execute("SELECT component_key FROM unified_labels").fetchall() + return {row["component_key"] for row in rows} + + # -- 
Prompt edges ---------------------------------------------------------- + + def save_prompt_edges(self, edges: list[PromptEdge]) -> None: + rows = [ + ( + e.component_key, + e.related_key, + e.pass_name, + e.attribution, + e.related_label, + e.related_confidence, + ) + for e in edges + ] + self._conn.executemany( + "INSERT OR REPLACE INTO prompt_edges VALUES (?, ?, ?, ?, ?, ?)", + rows, + ) + self._conn.commit() + + def get_prompt_edges(self, component_key: str) -> list[PromptEdge]: + rows = self._conn.execute( + "SELECT * FROM prompt_edges WHERE component_key = ?", (component_key,) + ).fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + def get_all_prompt_edges(self) -> list[PromptEdge]: + rows = self._conn.execute("SELECT * FROM prompt_edges").fetchall() + return [_row_to_prompt_edge(row) for row in rows] + + # -- Config ---------------------------------------------------------------- + + def save_config(self, key: str, value: str) -> None: + self._conn.execute("INSERT OR REPLACE INTO config VALUES (?, ?)", (key, value)) + self._conn.commit() + + # -- Stats ----------------------------------------------------------------- + + def get_label_count(self, table: str) -> int: + assert table in ("output_labels", "input_labels", "unified_labels") + row = self._conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone() + assert row is not None + return row[0] + + def close(self) -> None: + self._conn.close() + + +def _row_to_label_result(row: sqlite3.Row) -> LabelResult: + return LabelResult( + component_key=row["component_key"], + label=row["label"], + confidence=row["confidence"], + reasoning=row["reasoning"], + raw_response=row["raw_response"], + prompt=row["prompt"], + ) + + +def _row_to_prompt_edge(row: sqlite3.Row) -> PromptEdge: + return PromptEdge( + component_key=row["component_key"], + related_key=row["related_key"], + pass_name=row["pass"], + attribution=row["attribution"], + related_label=row["related_label"], + 
related_confidence=row["related_confidence"], + ) diff --git a/spd/graph_interp/graph_context.py b/spd/graph_interp/graph_context.py new file mode 100644 index 000000000..9ac08ad73 --- /dev/null +++ b/spd/graph_interp/graph_context.py @@ -0,0 +1,98 @@ +"""Gather related components from attribution graph.""" + +from collections.abc import Callable +from dataclasses import dataclass +from typing import Literal + +from spd.dataset_attributions.storage import DatasetAttributionEntry +from spd.graph_interp.ordering import parse_component_key +from spd.graph_interp.schemas import LabelResult +from spd.harvest.analysis import get_correlated_components +from spd.harvest.storage import CorrelationStorage + + +@dataclass +class RelatedComponent: + component_key: str + attribution: float + label: str | None + confidence: str | None + jaccard: float | None + pmi: float | None + + +GetAttributed = Callable[[str, int, Literal["positive", "negative"]], list[DatasetAttributionEntry]] + + +def get_related_components( + component_key: str, + get_attributed: GetAttributed, + correlation_storage: CorrelationStorage, + labels_so_far: dict[str, LabelResult], + k: int, +) -> list[RelatedComponent]: + """Top-K components connected via attribution, enriched with co-firing stats and labels.""" + my_layer, _ = parse_component_key(component_key) + + pos = get_attributed(component_key, k * 2, "positive") + neg = get_attributed(component_key, k * 2, "negative") + + candidates = pos + neg + candidates.sort(key=lambda e: abs(e.value), reverse=True) + candidates = candidates[:k] + + cofiring = _build_cofiring_lookup(component_key, correlation_storage, k * 3) + result = [_build_related(e.component_key, e.value, cofiring, labels_so_far) for e in candidates] + + for r in result: + r_layer, _ = parse_component_key(r.component_key) + assert r_layer != my_layer, ( + f"Same-layer component {r.component_key} in related list for {component_key}" + ) + + return result + + +def _build_cofiring_lookup( + 
component_key: str, + correlation_storage: CorrelationStorage, + k: int, +) -> dict[str, tuple[float, float | None]]: + lookup: dict[str, tuple[float, float | None]] = {} + + jaccard_results = get_correlated_components( + correlation_storage, component_key, metric="jaccard", top_k=k + ) + for c in jaccard_results: + lookup[c.component_key] = (c.score, None) + + pmi_results = get_correlated_components( + correlation_storage, component_key, metric="pmi", top_k=k + ) + for c in pmi_results: + if c.component_key in lookup: + jaccard_val = lookup[c.component_key][0] + lookup[c.component_key] = (jaccard_val, c.score) + else: + lookup[c.component_key] = (0.0, c.score) + + return lookup + + +def _build_related( + related_key: str, + attribution: float, + cofiring: dict[str, tuple[float, float | None]], + labels_so_far: dict[str, LabelResult], +) -> RelatedComponent: + label = labels_so_far.get(related_key) + jaccard, pmi = cofiring.get(related_key, (None, None)) + + return RelatedComponent( + component_key=related_key, + attribution=attribution, + label=label.label if label else None, + confidence=label.confidence if label else None, + jaccard=jaccard, + pmi=pmi, + ) diff --git a/spd/graph_interp/interpret.py b/spd/graph_interp/interpret.py new file mode 100644 index 000000000..e10b1ccd2 --- /dev/null +++ b/spd/graph_interp/interpret.py @@ -0,0 +1,381 @@ +"""Main three-phase graph interpretation execution. + +Structure: + output_labels = scan(layers_reversed, step) + input_labels = scan(layers_forward, step) + unified = map(output_labels + input_labels, unify) + +Each scan folds over layers. Within a layer, components are labeled in parallel +via async LLM calls. The fold accumulator (labels_so_far) lets each component's +prompt include labels from previously-processed layers. 
+""" + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable +from functools import partial +from pathlib import Path +from typing import Literal + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.llm_api import CostTracker, LLMError, LLMJob, LLMResult, map_llm_calls +from spd.autointerp.schemas import ModelMetadata +from spd.dataset_attributions.storage import ( + AttrMetric, + DatasetAttributionEntry, + DatasetAttributionStorage, +) +from spd.graph_interp import graph_context +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.db import GraphInterpDB +from spd.graph_interp.graph_context import RelatedComponent, get_related_components +from spd.graph_interp.ordering import group_and_sort_by_layer +from spd.graph_interp.prompts import ( + LABEL_SCHEMA, + format_input_prompt, + format_output_prompt, + format_unification_prompt, +) +from spd.graph_interp.schemas import LabelResult, PromptEdge +from spd.harvest.analysis import get_input_token_stats, get_output_token_stats +from spd.harvest.repo import HarvestRepo +from spd.harvest.storage import CorrelationStorage, TokenStatsStorage +from spd.log import logger + +GetRelated = Callable[[str, dict[str, LabelResult]], list[RelatedComponent]] +Step = Callable[[list[str], dict[str, LabelResult]], Awaitable[dict[str, LabelResult]]] + + +def run_graph_interp( + openrouter_api_key: str, + config: GraphInterpConfig, + harvest: HarvestRepo, + attribution_storage: DatasetAttributionStorage, + correlation_storage: CorrelationStorage, + token_stats: TokenStatsStorage, + model_metadata: ModelMetadata, + db_path: Path, + tokenizer_name: str, +) -> None: + logger.info("Loading tokenizer...") + app_tok = AppTokenizer.from_pretrained(tokenizer_name) + + logger.info("Loading component summaries...") + summaries = harvest.get_summary() + alive = {k: s for k, s in summaries.items() if s.firing_density > 0.0} + all_keys = sorted(alive, key=lambda k: 
alive[k].firing_density, reverse=True) + if config.limit is not None: + all_keys = all_keys[: config.limit] + + layers = group_and_sort_by_layer(all_keys, model_metadata.layer_descriptions) + total = len(all_keys) + logger.info(f"Graph interp: {total} components across {len(layers)} layers") + + # -- Injected behaviours --------------------------------------------------- + + shared_cost = CostTracker(limit_usd=config.cost_limit_usd) + + async def llm_map( + jobs: Iterable[LLMJob], n_total: int | None = None + ) -> AsyncGenerator[LLMResult | LLMError]: + async for result in map_llm_calls( + openrouter_api_key=openrouter_api_key, + model=config.model, + reasoning_effort=config.reasoning_effort, + jobs=jobs, + max_tokens=8000, + max_concurrent=config.max_concurrent, + max_requests_per_minute=config.max_requests_per_minute, + cost_limit_usd=None, + response_schema=LABEL_SCHEMA, + n_total=n_total, + cost_tracker=shared_cost, + ): + yield result + + concrete_to_canon = model_metadata.layer_descriptions + canon_to_concrete = {v: k for k, v in concrete_to_canon.items()} + + def _translate_entries(entries: list[DatasetAttributionEntry]) -> list[DatasetAttributionEntry]: + for e in entries: + if e.layer in canon_to_concrete: + e.layer = canon_to_concrete[e.layer] + e.component_key = f"{e.layer}:{e.component_idx}" + return entries + + def _to_canon(concrete_key: str) -> str: + layer, idx = concrete_key.rsplit(":", 1) + return f"{concrete_to_canon[layer]}:{idx}" + + def _make_get_targets(metric: AttrMetric) -> "graph_context.GetAttributed": + def get( + key: str, k: int, sign: Literal["positive", "negative"] + ) -> list[DatasetAttributionEntry]: + return _translate_entries( + attribution_storage.get_top_targets(_to_canon(key), k=k, sign=sign, metric=metric) + ) + + return get + + def _make_get_sources(metric: AttrMetric) -> "graph_context.GetAttributed": + def get( + key: str, k: int, sign: Literal["positive", "negative"] + ) -> list[DatasetAttributionEntry]: + return 
_translate_entries( + attribution_storage.get_top_sources(_to_canon(key), k=k, sign=sign, metric=metric) + ) + + return get + + def _get_related(get_attributed: "graph_context.GetAttributed") -> GetRelated: + def get(key: str, labels_so_far: dict[str, LabelResult]) -> list[RelatedComponent]: + return get_related_components( + key, + get_attributed, + correlation_storage, + labels_so_far, + config.top_k_attributed, + ) + + return get + + # -- Layer processors ------------------------------------------------------ + + async def process_output_layer( + get_related: GetRelated, + save_label: Callable[[LabelResult], None], + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + o_stats = get_output_token_stats(token_stats, key, app_tok, top_k=50) + assert o_stats is not None, f"No output token stats for {key}" + + related = get_related(key, labels_so_far) + _save_edges(db, key, related, "output") + prompt = format_output_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + output_token_stats=o_stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + async def process_input_layer( + get_related: GetRelated, + save_label: Callable[[LabelResult], None], + pending: list[str], + labels_so_far: dict[str, LabelResult], + ) -> dict[str, LabelResult]: + def jobs() -> Iterable[LLMJob]: + for key in pending: + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + i_stats = get_input_token_stats(token_stats, key, app_tok, top_k=20) + assert i_stats is not None, f"No input token stats for {key}" 
+ + related = get_related(key, labels_so_far) + _save_edges(db, key, related, "input") + prompt = format_input_prompt( + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + input_token_stats=i_stats, + related=related, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + return await _collect_labels(llm_map, jobs(), len(pending), save_label) + + # -- Scan (fold over layers) ----------------------------------------------- + + async def scan( + layer_order: list[tuple[str, list[str]]], + initial: dict[str, LabelResult], + step: Step, + ) -> dict[str, LabelResult]: + labels = dict(initial) + if labels: + logger.info(f"Resuming, {len(labels)} already completed") + + completed_so_far = 0 + for layer, keys in layer_order: + pending = [k for k in keys if k not in labels] + if not pending: + completed_so_far += len(keys) + continue + + new_labels = await step(pending, labels) + labels.update(new_labels) + + completed_so_far += len(keys) + logger.info(f"Completed layer {layer} ({completed_so_far}/{total})") + + return labels + + # -- Map (parallel over all components) ------------------------------------ + + async def map_unify( + output_labels: dict[str, LabelResult], + input_labels: dict[str, LabelResult], + ) -> None: + completed = db.get_completed_unified_keys() + keys = [k for k in all_keys if k not in completed] + if not keys: + logger.info("Unification: all labels already completed") + return + if completed: + logger.info(f"Unification: resuming, {len(completed)} already completed") + + n_skipped = 0 + + def jobs() -> Iterable[LLMJob]: + nonlocal n_skipped + for key in keys: + out = output_labels.get(key) + inp = input_labels.get(key) + if out is None or inp is None: + n_skipped += 1 + continue + component = harvest.get_component(key) + assert component is not None, f"Component {key} not found in harvest DB" + prompt = format_unification_prompt( + 
output_label=out, + input_label=inp, + component=component, + model_metadata=model_metadata, + app_tok=app_tok, + label_max_words=config.label_max_words, + max_examples=config.max_examples, + ) + yield LLMJob(prompt=prompt, schema=LABEL_SCHEMA, key=key) + + logger.info(f"Unifying {len(keys)} components") + new_labels = await _collect_labels(llm_map, jobs(), len(keys), db.save_unified_label) + + if n_skipped: + logger.warning(f"Skipped {n_skipped} components missing output or input labels") + logger.info(f"Unification: completed {len(new_labels)}/{len(keys)}") + + # -- Run ------------------------------------------------------------------- + + logger.info("Initializing DB and building scan steps...") + db = GraphInterpDB(db_path) + + metric = config.attr_metric + get_targets = _make_get_targets(metric) + get_sources = _make_get_sources(metric) + + label_output = partial( + process_output_layer, + _get_related(get_targets), + db.save_output_label, + ) + label_input = partial( + process_input_layer, + _get_related(get_sources), + db.save_input_label, + ) + + async def _run() -> None: + logger.section("Phase 1: Output pass (late → early)") + output_labels = await scan(list(reversed(layers)), db.get_all_output_labels(), label_output) + + logger.section("Phase 2: Input pass (early → late)") + input_labels = await scan(list(layers), db.get_all_input_labels(), label_input) + + logger.section("Phase 3: Unification") + await map_unify(output_labels, input_labels) + + logger.info( + f"Completed: {db.get_label_count('output_labels')} output, " + f"{db.get_label_count('input_labels')} input, " + f"{db.get_label_count('unified_labels')} unified labels -> {db_path}" + ) + db.mark_done() + + try: + asyncio.run(_run()) + finally: + db.close() + + +# -- Shared LLM call machinery ------------------------------------------------ + + +async def _collect_labels( + llm_map: Callable[[Iterable[LLMJob], int | None], AsyncGenerator[LLMResult | LLMError]], + jobs: Iterable[LLMJob], + 
n_total: int, + save_label: Callable[[LabelResult], None], +) -> dict[str, LabelResult]: + """Run LLM jobs, parse results, save to DB, return new labels.""" + new_labels: dict[str, LabelResult] = {} + n_errors = 0 + + async for outcome in llm_map(jobs, n_total): + match outcome: + case LLMResult(job=job, parsed=parsed, raw=raw): + result = _parse_label(job.key, parsed, raw, job.prompt) + save_label(result) + new_labels[job.key] = result + case LLMError(job=job, error=e): + n_errors += 1 + logger.error(f"Skipping {job.key}: {type(e).__name__}: {e}") + _check_error_rate(n_errors, len(new_labels)) + + return new_labels + + +def _parse_label(key: str, parsed: dict[str, object], raw: str, prompt: str) -> LabelResult: + assert len(parsed) == 3, f"Expected 3 fields, got {len(parsed)}" + label = parsed["label"] + confidence = parsed["confidence"] + reasoning = parsed["reasoning"] + assert isinstance(label, str) and isinstance(confidence, str) and isinstance(reasoning, str) + return LabelResult( + component_key=key, + label=label, + confidence=confidence, + reasoning=reasoning, + raw_response=raw, + prompt=prompt, + ) + + +def _check_error_rate(n_errors: int, n_done: int) -> None: + total = n_errors + n_done + if total > 10 and n_errors / total > 0.05: + raise RuntimeError( + f"Error rate {n_errors / total:.0%} ({n_errors}/{total}) exceeds 5% threshold" + ) + + +def _save_edges( + db: GraphInterpDB, + component_key: str, + related: list[RelatedComponent], + pass_name: Literal["output", "input"], +) -> None: + edges = [ + PromptEdge( + component_key=component_key, + related_key=r.component_key, + pass_name=pass_name, + attribution=r.attribution, + related_label=r.label, + related_confidence=r.confidence, + ) + for r in related + ] + if edges: + db.save_prompt_edges(edges) diff --git a/spd/graph_interp/ordering.py b/spd/graph_interp/ordering.py new file mode 100644 index 000000000..9ef4d5afa --- /dev/null +++ b/spd/graph_interp/ordering.py @@ -0,0 +1,88 @@ +"""Layer ordering 
for graph interpretation. + +Uses the topology module's CanonicalWeight system for correct ordering +across all model architectures. Canonical addresses are provided by +ModelMetadata.layer_descriptions (concrete path → canonical string). +""" + +from spd.topology.canonical import ( + CanonicalWeight, + FusedAttnWeight, + GLUWeight, + LayerWeight, + MLPWeight, + SeparateAttnWeight, +) + +_SUBLAYER_ORDER = {"attn": 0, "attn_fused": 0, "glu": 1, "mlp": 1} + +_PROJECTION_ORDER: dict[type, dict[str, int]] = { + SeparateAttnWeight: {"q": 0, "k": 1, "v": 2, "o": 3}, + FusedAttnWeight: {"qkv": 0, "o": 1}, + GLUWeight: {"gate": 0, "up": 1, "down": 2}, + MLPWeight: {"up": 0, "down": 1}, +} + + +def canonical_sort_key(canonical: str) -> tuple[int, int, int]: + """Sort key for a canonical address string like '0.attn.q' or '1.mlp.down'.""" + weight = CanonicalWeight.parse(canonical) + assert isinstance(weight, LayerWeight), f"Expected LayerWeight, got {type(weight).__name__}" + + match weight.name: + case SeparateAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn"] + proj_idx = _PROJECTION_ORDER[SeparateAttnWeight][p] + case FusedAttnWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["attn_fused"] + proj_idx = _PROJECTION_ORDER[FusedAttnWeight][p] + case GLUWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["glu"] + proj_idx = _PROJECTION_ORDER[GLUWeight][p] + case MLPWeight(weight=p): + sublayer_idx = _SUBLAYER_ORDER["mlp"] + proj_idx = _PROJECTION_ORDER[MLPWeight][p] + + return weight.layer_idx, sublayer_idx, proj_idx + + +def parse_component_key(key: str) -> tuple[str, int]: + """Split 'h.1.mlp.c_fc:42' into ('h.1.mlp.c_fc', 42).""" + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +def group_and_sort_by_layer( + component_keys: list[str], + layer_descriptions: dict[str, str], +) -> list[tuple[str, list[str]]]: + """Group component keys by layer, return [(layer, [keys])] in topological order. 
+ + Args: + component_keys: Component keys like 'h.0.attn.q_proj:42'. + layer_descriptions: Mapping from concrete layer path to canonical address + (from ModelMetadata.layer_descriptions). + """ + by_layer: dict[str, list[str]] = {} + for key in component_keys: + layer, _ = parse_component_key(key) + by_layer.setdefault(layer, []).append(key) + + def sort_key(layer: str) -> tuple[int, int, int]: + canonical = layer_descriptions[layer] + return canonical_sort_key(canonical) + + sorted_layers = sorted(by_layer.keys(), key=sort_key) + + result: list[tuple[str, list[str]]] = [] + for layer in sorted_layers: + keys = sorted(by_layer[layer], key=lambda k: parse_component_key(k)[1]) + result.append((layer, keys)) + return result + + +def is_later_layer(earlier: str, later: str, layer_descriptions: dict[str, str]) -> bool: + """Check if `later` is topologically after `earlier`.""" + return canonical_sort_key(layer_descriptions[earlier]) < canonical_sort_key( + layer_descriptions[later] + ) diff --git a/spd/graph_interp/prompts.py b/spd/graph_interp/prompts.py new file mode 100644 index 000000000..01874a745 --- /dev/null +++ b/spd/graph_interp/prompts.py @@ -0,0 +1,227 @@ +"""Prompt formatters for graph interpretation. + +Three prompts: +1. Output pass (late→early): "What does this component DO?" — output tokens, says examples, downstream +2. Input pass (early→late): "What TRIGGERS this component?" — input tokens, fires-on examples, upstream +3. Unification: Synthesize output + input labels into unified label. 
+""" + +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.prompt_helpers import ( + build_fires_on_examples, + build_input_section, + build_output_section, + build_says_examples, + density_note, + human_layer_desc, + layer_position_note, +) +from spd.autointerp.schemas import ModelMetadata +from spd.graph_interp.graph_context import RelatedComponent +from spd.graph_interp.schemas import LabelResult +from spd.harvest.analysis import TokenPRLift +from spd.harvest.schemas import ComponentData + +LABEL_SCHEMA: dict[str, object] = { + "type": "object", + "properties": { + "label": {"type": "string"}, + "confidence": {"type": "string", "enum": ["low", "medium", "high"]}, + "reasoning": {"type": "string"}, + }, + "required": ["label", "confidence", "reasoning"], + "additionalProperties": False, +} + + +def _component_header( + component: ComponentData, + model_metadata: ModelMetadata, +) -> str: + canonical = model_metadata.layer_descriptions.get(component.layer, component.layer) + layer_desc = human_layer_desc(canonical, model_metadata.n_blocks) + position_note = layer_position_note(canonical, model_metadata.n_blocks) + dens_note = density_note(component.firing_density) + + rate_str = ( + f"~1 in {int(1 / component.firing_density)} tokens" + if component.firing_density > 0.0 + else "extremely rare" + ) + + context_notes = " ".join(filter(None, [position_note, dens_note])) + + return f"""\ +## Context +- Component: {layer_desc} (component {component.component_idx}), {model_metadata.n_blocks}-block model +- Firing rate: {component.firing_density * 100:.2f}% ({rate_str}) +{context_notes}""" + + +def format_output_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + output_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + + output_pmi = ( + [(app_tok.get_tok_display(tid), pmi) for 
tid, pmi in component.output_token_pmi.top] + if component.output_token_pmi.top + else None + ) + output_section = build_output_section(output_token_stats, output_pmi) + says = build_says_examples(component, app_tok, max_examples) + related_table = _format_related_table(related, model_metadata, app_tok) + + return f"""\ +You are analyzing a component in a neural network to understand its OUTPUT FUNCTION — what it does when it fires. + +{header} + +## Output tokens (what the model produces when this component fires) +{output_section} +## Activation examples — what the model produces +{says} +## Downstream components (what this component influences) +These components in later layers are most influenced by this component (by gradient attribution): +{related_table} +## Task +Give a {label_max_words}-word-or-fewer label describing this component's OUTPUT FUNCTION — what it does when it fires. + +Say "unclear" if the evidence is too weak. + +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def format_input_prompt( + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + input_token_stats: TokenPRLift, + related: list[RelatedComponent], + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + + input_pmi = ( + [(app_tok.get_tok_display(tid), pmi) for tid, pmi in component.input_token_pmi.top] + if component.input_token_pmi.top + else None + ) + input_section = build_input_section(input_token_stats, input_pmi) + fires_on = build_fires_on_examples(component, app_tok, max_examples) + related_table = _format_related_table(related, model_metadata, app_tok) + + return f"""\ +You are analyzing a component in a neural network to understand its INPUT FUNCTION — what triggers it to fire. 
+ +{header} + +## Input tokens (what causes this component to fire) +{input_section} +## Activation examples — where the component fires +{fires_on} +## Upstream components (what feeds into this component) +These components in earlier layers most strongly attribute to this component: +{related_table} +## Task +Give a {label_max_words}-word-or-fewer label describing this component's INPUT FUNCTION — what conditions trigger it to fire. + +Say "unclear" if the evidence is too weak. + +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def format_unification_prompt( + output_label: LabelResult, + input_label: LabelResult, + component: ComponentData, + model_metadata: ModelMetadata, + app_tok: AppTokenizer, + label_max_words: int, + max_examples: int, +) -> str: + header = _component_header(component, model_metadata) + fires_on = build_fires_on_examples(component, app_tok, max_examples) + says = build_says_examples(component, app_tok, max_examples) + + return f"""\ +A neural network component has been analyzed from two perspectives. + +{header} + +## Activation examples — where the component fires +{fires_on} +## Activation examples — what the model produces +{says} +## Two-perspective analysis + +OUTPUT FUNCTION: "{output_label.label}" (confidence: {output_label.confidence}) + Reasoning: {output_label.reasoning} + +INPUT FUNCTION: "{input_label.label}" (confidence: {input_label.confidence}) + Reasoning: {input_label.reasoning} + +## Task +Synthesize these into a single unified label (max {label_max_words} words) that captures the component's complete role. If input and output suggest the same concept, unify them. If they describe genuinely different aspects (e.g. fires on X, produces Y), combine both. 
+ +Respond with JSON: {{"label": "...", "confidence": "low|medium|high", "reasoning": "..."}} +""" + + +def _format_related_table( + components: list[RelatedComponent], + model_metadata: ModelMetadata, + app_tok: AppTokenizer, +) -> str: + # Filter: only show labeled components and token entries (embed/output) + visible = [n for n in components if n.label is not None or _is_token_entry(n.component_key)] + if not visible: + return "(no related components with labels found)\n" + + # Normalize attributions: strongest = 1.0 + max_attr = max(abs(n.attribution) for n in visible) + norm = max_attr if max_attr > 0 else 1.0 + + lines: list[str] = [] + for n in visible: + display = _component_display(n.component_key, model_metadata, app_tok) + rel_attr = n.attribution / norm + + parts = [f" {display} (relative attribution: {rel_attr:+.2f}"] + if n.jaccard is not None: + parts.append(f", co-firing Jaccard: {n.jaccard:.3f}") + parts.append(")") + + line = "".join(parts) + if n.label is not None: + line += f'\n label: "{n.label}" (confidence: {n.confidence})' + lines.append(line) + + return "\n".join(lines) + "\n" + + +def _is_token_entry(key: str) -> bool: + layer = key.rsplit(":", 1)[0] + return layer in ("embed", "output") + + +def _component_display(key: str, model_metadata: ModelMetadata, app_tok: AppTokenizer) -> str: + layer, idx_str = key.rsplit(":", 1) + match layer: + case "embed": + return f'input token "{app_tok.get_tok_display(int(idx_str))}"' + case "output": + return f'output token "{app_tok.get_tok_display(int(idx_str))}"' + case _: + canonical = model_metadata.layer_descriptions.get(layer, layer) + desc = human_layer_desc(canonical, model_metadata.n_blocks) + return f"component from {desc}" diff --git a/spd/graph_interp/repo.py b/spd/graph_interp/repo.py new file mode 100644 index 000000000..6667c4e1e --- /dev/null +++ b/spd/graph_interp/repo.py @@ -0,0 +1,95 @@ +"""Graph interpretation data repository. 
+ +Owns SPD_OUT_DIR/graph_interp/<run_id>/ and provides read access +to output, input, and unified labels. + +Use GraphInterpRepo.open() to construct — returns None if no data exists. +""" + +from pathlib import Path +from typing import Any + +import yaml + +from spd.graph_interp.db import DONE_MARKER, GraphInterpDB +from spd.graph_interp.schemas import LabelResult, PromptEdge, get_graph_interp_dir + + +class GraphInterpRepo: + """Read access to graph interpretation data for a single run.""" + + def __init__(self, db: GraphInterpDB, subrun_dir: Path, run_id: str) -> None: + self._db = db + self._subrun_dir = subrun_dir + self.subrun_id = subrun_dir.name + self.run_id = run_id + + @classmethod + def open(cls, run_id: str) -> "GraphInterpRepo | None": + """Open graph interp data for a run. Returns None if no data exists.""" + base_dir = get_graph_interp_dir(run_id) + if not base_dir.exists(): + return None + candidates = sorted( + [ + d + for d in base_dir.iterdir() + if d.is_dir() and d.name.startswith("ti-") and (d / DONE_MARKER).exists() + ], + key=lambda d: d.name, + ) + if not candidates: + return None + subrun_dir = candidates[-1] + db_path = subrun_dir / "interp.db" + if not db_path.exists(): + return None + return cls( + db=GraphInterpDB(db_path, readonly=True), + subrun_dir=subrun_dir, + run_id=run_id, + ) + + def get_config(self) -> dict[str, Any] | None: + config_path = self._subrun_dir / "config.yaml" + if not config_path.exists(): + return None + with open(config_path) as f: + return yaml.safe_load(f) + + # -- Labels ---------------------------------------------------------------- + + def get_all_output_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_output_labels() + + def get_all_input_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_input_labels() + + def get_all_unified_labels(self) -> dict[str, LabelResult]: + return self._db.get_all_unified_labels() + + def get_output_label(self, component_key: str) -> LabelResult | None: 
+ return self._db.get_output_label(component_key) + + def get_input_label(self, component_key: str) -> LabelResult | None: + return self._db.get_input_label(component_key) + + def get_unified_label(self, component_key: str) -> LabelResult | None: + return self._db.get_unified_label(component_key) + + # -- Edges ----------------------------------------------------------------- + + def get_prompt_edges(self, component_key: str) -> list[PromptEdge]: + return self._db.get_prompt_edges(component_key) + + def get_all_prompt_edges(self) -> list[PromptEdge]: + return self._db.get_all_prompt_edges() + + # -- Stats ----------------------------------------------------------------- + + def get_label_counts(self) -> dict[str, int]: + return { + "output": self._db.get_label_count("output_labels"), + "input": self._db.get_label_count("input_labels"), + "unified": self._db.get_label_count("unified_labels"), + } diff --git a/spd/graph_interp/schemas.py b/spd/graph_interp/schemas.py new file mode 100644 index 000000000..ad391e270 --- /dev/null +++ b/spd/graph_interp/schemas.py @@ -0,0 +1,37 @@ +"""Data types and path helpers for graph interpretation.""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +from spd.settings import SPD_OUT_DIR + +GRAPH_INTERP_DIR = SPD_OUT_DIR / "graph_interp" + + +def get_graph_interp_dir(decomposition_id: str) -> Path: + return GRAPH_INTERP_DIR / decomposition_id + + +def get_graph_interp_subrun_dir(decomposition_id: str, subrun_id: str) -> Path: + return get_graph_interp_dir(decomposition_id) / subrun_id + + +@dataclass +class LabelResult: + component_key: str + label: str + confidence: str + reasoning: str + raw_response: str + prompt: str + + +@dataclass +class PromptEdge: + component_key: str + related_key: str + pass_name: Literal["output", "input"] + attribution: float + related_label: str | None + related_confidence: str | None diff --git a/spd/graph_interp/scripts/__init__.py 
b/spd/graph_interp/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/graph_interp/scripts/export_html.py b/spd/graph_interp/scripts/export_html.py new file mode 100644 index 000000000..14a5955da --- /dev/null +++ b/spd/graph_interp/scripts/export_html.py @@ -0,0 +1,268 @@ +"""Export graph interpretation data to JSON for the static HTML page. + +Usage: + python -m spd.graph_interp.scripts.export_html s-17805b61 + python -m spd.graph_interp.scripts.export_html s-17805b61 --subrun_id ti-20260223_213443 + python -m spd.graph_interp.scripts.export_html s-17805b61 --mock +""" + +import json +import random +from dataclasses import asdict +from typing import Any + +from spd.graph_interp.repo import GraphInterpRepo +from spd.graph_interp.schemas import LabelResult, get_graph_interp_dir +from spd.settings import SPD_OUT_DIR + +WWW_DIR = SPD_OUT_DIR / "www" +DATA_DIR = WWW_DIR / "data" + + +def _label_to_dict(label: LabelResult) -> dict[str, str]: + return { + "label": label.label, + "confidence": label.confidence, + "reasoning": label.reasoning, + } + + +def _parse_component_key(key: str) -> tuple[str, int]: + layer, idx_str = key.rsplit(":", 1) + return layer, int(idx_str) + + +def export_from_repo(repo: GraphInterpRepo) -> dict[str, Any]: + output_labels = repo.get_all_output_labels() + input_labels = repo.get_all_input_labels() + unified_labels = repo.get_all_unified_labels() + + all_keys = sorted( + set(output_labels) | set(input_labels) | set(unified_labels), + key=lambda k: (_parse_component_key(k)[0], _parse_component_key(k)[1]), + ) + + components = [] + for key in all_keys: + layer, component_idx = _parse_component_key(key) + entry: dict[str, Any] = { + "key": key, + "layer": layer, + "component_idx": component_idx, + } + if key in output_labels: + entry["output_label"] = _label_to_dict(output_labels[key]) + if key in input_labels: + entry["input_label"] = _label_to_dict(input_labels[key]) + if key in unified_labels: + 
entry["unified_label"] = _label_to_dict(unified_labels[key]) + + edges = repo.get_prompt_edges(key) + if edges: + entry["edges"] = [asdict(e) for e in edges] + + components.append(entry) + + label_counts = repo.get_label_counts() + + return { + "decomposition_id": repo.run_id, + "subrun_id": repo.subrun_id, + "label_counts": label_counts, + "components": components, + } + + +def generate_mock_data(decomposition_id: str) -> dict[str, Any]: + random.seed(42) + + layers = [ + "h.0.mlp.c_fc", + "h.0.mlp.down_proj", + "h.0.attn.q_proj", + "h.0.attn.k_proj", + "h.0.attn.v_proj", + "h.0.attn.o_proj", + "h.1.mlp.c_fc", + "h.1.mlp.down_proj", + "h.1.attn.q_proj", + "h.1.attn.k_proj", + "h.1.attn.v_proj", + "h.1.attn.o_proj", + ] + + output_labels_pool = [ + "sentence-final punctuation and period prediction", + "proper nouns and character name completions", + "emotional adjectives describing characters", + "temporal adverbs and time-related transitions", + "morphological suffix completion (-ing, -ed, -ly)", + "determiners preceding concrete nouns", + "dialogue-opening quotation marks and speech verbs", + "plural noun suffixes after quantity words", + "conjunction and clause boundary detection", + "verb tense agreement and auxiliary verbs", + "spatial prepositions and location descriptors", + "possessive pronouns and genitive markers", + "narrative action verbs (walked, looked, said)", + "abstract emotion nouns (fear, joy, anger)", + "comparative and superlative adjective forms", + ] + + input_labels_pool = [ + "punctuation and common function words", + "sentence-initial capital letters and proper nouns", + "mid-sentence verbs following subject nouns", + "adjective-noun boundaries in descriptive phrases", + "clause-final positions before conjunctions", + "article-noun sequences in noun phrases", + "subject pronouns at clause boundaries", + "preposition-object sequences", + "verb stems preceding inflectional suffixes", + "quotation marks and dialogue boundaries", + 
"comma-separated list items", + "sentence-medial adverbs after auxiliaries", + "concrete nouns following determiners", + "coordinating conjunctions between clauses", + "word stems requiring morphological completion", + ] + + unified_labels_pool = [ + "sentence termination tracking and terminal punctuation prediction", + "character name recognition and proper noun completion", + "emotional state description through adjective selection", + "temporal transition signaling via adverbs and tense markers", + "morphological word completion from stems to suffixed forms", + "noun phrase construction: determiners predicting concrete nouns", + "dialogue framing through quotation marks and speech attribution", + "plural morphology following quantifiers and numerals", + "clause coordination and syntactic boundary marking", + "verbal agreement and auxiliary verb selection", + "spatial relationship encoding via prepositional phrases", + "possessive construction and genitive case marking", + "narrative action sequencing through core verbs", + "abstract emotional vocabulary and sentiment expression", + "degree modification and comparative construction", + ] + + confidences = ["high", "high", "high", "medium", "medium", "low"] + + reasoning_templates = [ + "The output function focuses on {output_focus}, while the input function responds to {input_focus}. Together, this component acts as a bridge between {bridge_from} and {bridge_to}, consistent with its position in {layer}.", + "This component's output pattern of {output_focus} is activated by {input_focus} in the input. The unified interpretation captures how {bridge_from} contexts trigger {bridge_to} predictions.", + "Downstream context shows this component feeds into {output_focus} pathways, while upstream context reveals activation by {input_focus}. 
The synthesis reflects a coherent role in {bridge_from}-to-{bridge_to} processing.", + ] + + focus_terms = [ + "punctuation patterns", + "noun completions", + "verb inflections", + "emotional descriptors", + "syntactic boundaries", + "morphological suffixes", + "dialogue markers", + "temporal signals", + "spatial relationships", + ] + + components = [] + for layer in layers: + n_components = random.randint(8, 20) + indices = sorted(random.sample(range(500), n_components)) + for idx in indices: + key = f"{layer}:{idx}" + conf = random.choice(confidences) + output_conf = random.choice(confidences) + input_conf = random.choice(confidences) + + output_label = random.choice(output_labels_pool) + input_label = random.choice(input_labels_pool) + unified_label = random.choice(unified_labels_pool) + + reasoning = random.choice(reasoning_templates).format( + output_focus=random.choice(focus_terms), + input_focus=random.choice(focus_terms), + bridge_from=random.choice(focus_terms), + bridge_to=random.choice(focus_terms), + layer=layer, + ) + + components.append( + { + "key": key, + "layer": layer, + "component_idx": idx, + "output_label": { + "label": output_label, + "confidence": output_conf, + "reasoning": f"Output: {reasoning}", + }, + "input_label": { + "label": input_label, + "confidence": input_conf, + "reasoning": f"Input: {reasoning}", + }, + "unified_label": { + "label": unified_label, + "confidence": conf, + "reasoning": reasoning, + }, + } + ) + + return { + "decomposition_id": decomposition_id, + "subrun_id": "ti-mock", + "label_counts": { + "output": len(components), + "input": len(components), + "unified": len(components), + }, + "components": components, + } + + +def main( + decomposition_id: str, + subrun_id: str | None = None, + mock: bool = False, +) -> None: + DATA_DIR.mkdir(parents=True, exist_ok=True) + out_path = DATA_DIR / f"graph_interp_{decomposition_id}.json" + + if mock: + data = generate_mock_data(decomposition_id) + print(f"Generated mock data: 
{len(data['components'])} components") + else: + if subrun_id is not None: + base_dir = get_graph_interp_dir(decomposition_id) + subrun_dir = base_dir / subrun_id + assert subrun_dir.exists(), f"Subrun dir not found: {subrun_dir}" + db_path = subrun_dir / "interp.db" + assert db_path.exists(), f"No interp.db in {subrun_dir}" + from spd.graph_interp.db import GraphInterpDB + + db = GraphInterpDB(db_path, readonly=True) + repo = GraphInterpRepo(db=db, subrun_dir=subrun_dir, run_id=decomposition_id) + else: + repo = GraphInterpRepo.open(decomposition_id) + if repo is None: + print(f"No graph interp data for {decomposition_id}. Generating mock data instead.") + data = generate_mock_data(decomposition_id) + with open(out_path, "w") as f: + json.dump(data, f) + print(f"Wrote mock data to {out_path}") + return + + data = export_from_repo(repo) + print(f"Exported {len(data['components'])} components from {data['subrun_id']}") + + with open(out_path, "w") as f: + json.dump(data, f) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/graph_interp/scripts/run.py b/spd/graph_interp/scripts/run.py new file mode 100644 index 000000000..f6c1ff492 --- /dev/null +++ b/spd/graph_interp/scripts/run.py @@ -0,0 +1,104 @@ +"""CLI entry point for graph interpretation. 
+ +Called by SLURM or directly: + python -m spd.graph_interp.scripts.run --decomposition_id s-17805b61 --config_json '{...}' +""" + +import os +from datetime import datetime +from typing import Any + +from dotenv import load_dotenv + +from spd.adapters import adapter_from_id +from spd.adapters.spd import SPDAdapter +from spd.dataset_attributions.repo import AttributionRepo +from spd.graph_interp.config import GraphInterpConfig +from spd.graph_interp.interpret import run_graph_interp +from spd.graph_interp.schemas import get_graph_interp_subrun_dir +from spd.harvest.repo import HarvestRepo +from spd.log import logger + + +def main( + decomposition_id: str, + config_json: dict[str, Any], + harvest_subrun_id: str | None = None, + subrun_id: str | None = None, +) -> None: + assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" + config = GraphInterpConfig.model_validate(config_json) + + load_dotenv() + openrouter_api_key = os.environ.get("OPENROUTER_API_KEY") + assert openrouter_api_key, "OPENROUTER_API_KEY not set" + + if subrun_id is None: + subrun_id = "ti-" + datetime.now().strftime("%Y%m%d_%H%M%S") + subrun_dir = get_graph_interp_subrun_dir(decomposition_id, subrun_id) + subrun_dir.mkdir(parents=True, exist_ok=True) + config.to_file(subrun_dir / "config.yaml") + db_path = subrun_dir / "interp.db" + logger.info(f"Graph interp run: {subrun_dir}") + + logger.info("Loading adapter and model metadata...") + adapter = adapter_from_id(decomposition_id) + assert isinstance(adapter, SPDAdapter) + logger.info("Loading harvest data...") + if harvest_subrun_id is not None: + harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) + else: + harvest = HarvestRepo.open_most_recent(decomposition_id, readonly=True) + assert harvest is not None, f"No harvest data for {decomposition_id}" + + logger.info("Loading dataset attributions...") + attributions = AttributionRepo.open(decomposition_id) + assert attributions is not None, f"Dataset attributions 
required for {decomposition_id}" + attribution_storage = attributions.get_attributions() + logger.info( + f" {attribution_storage.n_components} components, {attribution_storage.n_tokens_processed:,} tokens" + ) + + logger.info("Loading component correlations...") + correlations = harvest.get_correlations() + assert correlations is not None, f"Component correlations required for {decomposition_id}" + + logger.info("Loading token stats...") + token_stats = harvest.get_token_stats() + assert token_stats is not None, f"Token stats required for {decomposition_id}" + + logger.info("Data loading complete") + + run_graph_interp( + openrouter_api_key=openrouter_api_key, + config=config, + harvest=harvest, + attribution_storage=attribution_storage, + correlation_storage=correlations, + token_stats=token_stats, + model_metadata=adapter.model_metadata, + db_path=db_path, + tokenizer_name=adapter.tokenizer_name, + ) + + +def get_command( + decomposition_id: str, + config: GraphInterpConfig, + harvest_subrun_id: str | None = None, +) -> str: + config_json = config.model_dump_json(exclude_none=True) + cmd = ( + "python -m spd.graph_interp.scripts.run " + f"--decomposition_id {decomposition_id} " + f"--config_json '{config_json}' " + ) + if harvest_subrun_id is not None: + cmd += f"--harvest_subrun_id {harvest_subrun_id} " + return cmd + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/graph_interp/scripts/run_slurm.py b/spd/graph_interp/scripts/run_slurm.py new file mode 100644 index 000000000..fed1c146b --- /dev/null +++ b/spd/graph_interp/scripts/run_slurm.py @@ -0,0 +1,69 @@ +"""SLURM launcher for graph interpretation. + +Submits a single CPU job that runs the three-phase interpretation pipeline. +Depends on both harvest merge and attribution merge jobs. 
+""" + +from dataclasses import dataclass + +from spd.graph_interp.config import GraphInterpSlurmConfig +from spd.graph_interp.scripts import run +from spd.log import logger +from spd.utils.slurm import SlurmConfig, SubmitResult, generate_script, submit_slurm_job + + +@dataclass +class GraphInterpSubmitResult: + result: SubmitResult + + +def submit_graph_interp( + decomposition_id: str, + config: GraphInterpSlurmConfig, + dependency_job_ids: list[str], + snapshot_branch: str | None = None, + harvest_subrun_id: str | None = None, +) -> GraphInterpSubmitResult: + """Submit graph interpretation to SLURM. + + Args: + decomposition_id: ID of the target decomposition. + config: Graph interp SLURM configuration. + dependency_job_ids: Jobs to wait for (harvest merge + attribution merge). + snapshot_branch: Git snapshot branch to use. + harvest_subrun_id: Specific harvest subrun to use. + """ + cmd = run.get_command( + decomposition_id=decomposition_id, + config=config.config, + harvest_subrun_id=harvest_subrun_id, + ) + + dependency_str = ":".join(dependency_job_ids) if dependency_job_ids else None + + slurm_config = SlurmConfig( + job_name="spd-graph-interp", + partition=config.partition, + n_gpus=0, + cpus_per_task=16, + mem="240G", + time=config.time, + snapshot_branch=snapshot_branch, + dependency_job_id=dependency_str, + comment=decomposition_id, + ) + script_content = generate_script(slurm_config, cmd) + result = submit_slurm_job(script_content, "spd-graph-interp") + + logger.section("Graph interp job submitted") + logger.values( + { + "Job ID": result.job_id, + "Decomposition ID": decomposition_id, + "Model": config.config.model, + "Depends on": ", ".join(dependency_job_ids), + "Log": result.log_pattern, + } + ) + + return GraphInterpSubmitResult(result=result) diff --git a/spd/graph_interp/scripts/run_slurm_cli.py b/spd/graph_interp/scripts/run_slurm_cli.py new file mode 100644 index 000000000..a40fbee0b --- /dev/null +++ b/spd/graph_interp/scripts/run_slurm_cli.py 
@@ -0,0 +1,27 @@ +"""CLI entry point for graph interp SLURM launcher. + +Thin wrapper for fast --help. Heavy imports deferred to run_slurm.py. + +Usage: + spd-graph-interp s-17805b61 --config graph_interp_config.yaml +""" + +import fire + + +def main(decomposition_id: str, config: str) -> None: + """Submit graph interpretation pipeline to SLURM. + + Args: + decomposition_id: ID of the target decomposition run. + config: Path to GraphInterpSlurmConfig YAML/JSON. + """ + from spd.graph_interp.config import GraphInterpSlurmConfig + from spd.graph_interp.scripts.run_slurm import submit_graph_interp + + slurm_config = GraphInterpSlurmConfig.from_file(config) + submit_graph_interp(decomposition_id, slurm_config, dependency_job_ids=[]) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/harvest/config.py b/spd/harvest/config.py index 8c3afba4e..cc01b3cd0 100644 --- a/spd/harvest/config.py +++ b/spd/harvest/config.py @@ -68,7 +68,7 @@ class IntruderEvalConfig(BaseConfig): max_concurrent: int = 50 limit: int | None = None cost_limit_usd: float | None = None - max_requests_per_minute: int = 200 + max_requests_per_minute: int = 500 class IntruderSlurmConfig(BaseConfig): @@ -83,7 +83,7 @@ class HarvestConfig(BaseConfig): method_config: Annotated[DecompositionMethodHarvestConfig, Field(discriminator="type")] n_batches: int | Literal["whole_dataset"] = 20_000 batch_size: int = 32 - activation_examples_per_component: int = 1000 + activation_examples_per_component: int = 400 activation_context_tokens_per_side: int = 20 pmi_token_top_k: int = 40 max_examples_per_batch_per_component: int = 5 diff --git a/spd/harvest/db.py b/spd/harvest/db.py index 10573c276..ebcd0fb02 100644 --- a/spd/harvest/db.py +++ b/spd/harvest/db.py @@ -74,9 +74,6 @@ def _deserialize_component(row: sqlite3.Row) -> ComponentData: class HarvestDB: def __init__(self, db_path: Path, readonly: bool = False) -> None: if readonly: - # immutable=1 skips ALL locking — required on network filesystems where - # SQLite's 
locking protocol fails. Safe because readers only open DBs - # that are fully written and closed by a prior pipeline stage. self._conn = sqlite3.connect( f"file:{db_path}?immutable=1", uri=True, check_same_thread=False ) diff --git a/spd/harvest/intruder.py b/spd/harvest/intruder.py index f91a5e0c2..4fb9b2052 100644 --- a/spd/harvest/intruder.py +++ b/spd/harvest/intruder.py @@ -175,7 +175,9 @@ async def run_intruder_scoring( jobs: list[LLMJob] = [] ground_truth: dict[str, _TrialGroundTruth] = {} - for component in remaining: + for i, component in enumerate(remaining): + if i > 0 and i % 1000 == 0: + logger.info(f"Building trials: {i}/{len(remaining)} components") for trial_idx in range(n_trials): real_examples = rng.sample(component.activation_examples, n_real) intruder = _sample_intruder(component, density_index, rng, density_tolerance) @@ -194,6 +196,7 @@ async def run_intruder_scoring( component_key=component.component_key, correct_answer=correct_answer, ) + logger.info(f"Built {len(jobs)} trials") component_trials: defaultdict[str, list[IntruderTrial]] = defaultdict(list) component_errors: defaultdict[str, int] = defaultdict(int) diff --git a/spd/harvest/repo.py b/spd/harvest/repo.py index f3a2be107..d1eb7019b 100644 --- a/spd/harvest/repo.py +++ b/spd/harvest/repo.py @@ -55,8 +55,10 @@ def open_most_recent( db_path = subrun_dir / "harvest.db" if not db_path.exists(): + logger.info(f"No harvest data found for {decomposition_id}") return None + logger.info(f"Opening harvest data for {decomposition_id} from {subrun_dir}") subrun_id = subrun_dir.name return cls(decomposition_id=decomposition_id, subrun_id=subrun_id, readonly=readonly) diff --git a/spd/harvest/scripts/run_intruder.py b/spd/harvest/scripts/run_intruder.py index 4dbef810a..1870d1a06 100644 --- a/spd/harvest/scripts/run_intruder.py +++ b/spd/harvest/scripts/run_intruder.py @@ -9,6 +9,7 @@ from spd.harvest.db import HarvestDB from spd.harvest.intruder import run_intruder_scoring from spd.harvest.repo 
import HarvestRepo +from spd.log import logger def main( @@ -28,7 +29,9 @@ def main( harvest = HarvestRepo(decomposition_id, subrun_id=harvest_subrun_id, readonly=True) score_db = HarvestDB(harvest._dir / "harvest.db") + logger.info("Loading components from harvest DB...") components = harvest.get_all_components() + logger.info(f"Loaded {len(components)} components") asyncio.run( run_intruder_scoring( diff --git a/spd/harvest/scripts/run_slurm.py b/spd/harvest/scripts/run_slurm.py index 77a4ce0ef..7775bedac 100644 --- a/spd/harvest/scripts/run_slurm.py +++ b/spd/harvest/scripts/run_slurm.py @@ -38,6 +38,7 @@ def submit_harvest( config: HarvestSlurmConfig, job_suffix: str | None = None, snapshot_branch: str | None = None, + dependency_job_id: str | None = None, ) -> HarvestSubmitResult: """Submit multi-GPU harvest job to SLURM. @@ -76,6 +77,7 @@ def submit_harvest( n_gpus=1, time=time, snapshot_branch=snapshot_branch, + dependency_job_id=dependency_job_id, comment=config.config.method_config.id, ) array_script = generate_array_script(array_config, worker_commands) diff --git a/spd/investigate/CLAUDE.md b/spd/investigate/CLAUDE.md new file mode 100644 index 000000000..922734220 --- /dev/null +++ b/spd/investigate/CLAUDE.md @@ -0,0 +1,118 @@ +# Investigation Module + +Launch a Claude Code agent to investigate a specific research question about an SPD model decomposition. + +## Usage + +```bash +spd-investigate <wandb_path> "How does the model handle gendered pronouns?" +spd-investigate <wandb_path> "What circuit handles verb agreement?" --max_turns 30 --time 4:00:00 +``` + +For parallel investigations, run the command multiple times with different prompts.
+ +## Architecture + +``` +spd/investigate/ +├── __init__.py # Public exports +├── CLAUDE.md # This file +├── schemas.py # Pydantic models for outputs (BehaviorExplanation, InvestigationEvent) +├── agent_prompt.py # System prompt template with model info injection +└── scripts/ + ├── __init__.py + ├── run_slurm_cli.py # CLI entry point (spd-investigate) + ├── run_slurm.py # SLURM submission logic + └── run_agent.py # Worker script (runs in SLURM job) +``` + +## How It Works + +1. `spd-investigate` creates output dir, metadata, git snapshot, and submits a single SLURM job +2. The SLURM job runs `run_agent.py` which: + - Starts an isolated FastAPI backend with MCP support + - Loads the SPD run onto GPU + - Fetches model architecture info + - Generates the agent prompt (research question + model context + methodology) + - Launches Claude Code with MCP tools +3. The agent investigates using MCP tools and writes findings to the output directory + +## MCP Tools + +The agent accesses all SPD functionality via MCP at `/mcp`: + +**Circuit Discovery:** +- `optimize_graph` — Find minimal circuit for a behavior (streams progress) +- `create_prompt` — Tokenize text and get next-token probabilities + +**Component Analysis:** +- `get_component_info` — Interpretation, token stats, correlations +- `probe_component` — Fast CI probing on custom text +- `get_component_activation_examples` — Training examples where a component fires +- `get_component_attributions` — Dataset-level component dependencies +- `get_attribution_strength` — Attribution between specific component pairs + +**Testing:** +- `run_ablation` — Test circuit with only selected components +- `search_dataset` — Search training data + +**Metadata:** +- `get_model_info` — Architecture details + +**Output:** +- `update_research_log` — Append to research log (PRIMARY OUTPUT) +- `save_graph_artifact` — Save graph for inline visualization +- `save_explanation` — Save complete behavior explanation +- 
`set_investigation_summary` — Set title/summary for UI + +## Output Structure + +``` +SPD_OUT_DIR/investigations/<inv_id>/ +├── metadata.json # Investigation config (wandb_path, prompt, etc.) +├── research_log.md # Human-readable progress log (PRIMARY OUTPUT) +├── events.jsonl # Structured progress events +├── explanations.jsonl # Complete behavior explanations +├── summary.json # Agent-provided title/summary for UI +├── artifacts/ # Graph artifacts for visualization +│ └── graph_001.json +├── app.db # Isolated SQLite database +├── backend.log # Backend subprocess output +├── claude_output.jsonl # Raw Claude Code output +├── agent_prompt.md # The prompt given to the agent +└── mcp_config.json # MCP server configuration +``` + +## Environment + +The backend runs with `SPD_INVESTIGATION_DIR` set to the investigation directory. This controls: +- Database location: `$SPD_INVESTIGATION_DIR/app.db` +- Events log: `$SPD_INVESTIGATION_DIR/events.jsonl` +- Research log: `$SPD_INVESTIGATION_DIR/research_log.md` + +## Configuration + +CLI arguments: +- `wandb_path` — Required. WandB run path for the SPD decomposition. +- `prompt` — Required. Research question or investigation directive. +- `--context_length` — Token context length (default: 128) +- `--max_turns` — Max Claude turns (default: 50, prevents runaway) +- `--partition` — SLURM partition (default: h200-reserved) +- `--time` — Job time limit (default: 8:00:00) +- `--job_suffix` — Optional suffix for job names + +## Monitoring + +```bash +# Watch research log +tail -f SPD_OUT_DIR/investigations/<inv_id>/research_log.md + +# Watch events +tail -f SPD_OUT_DIR/investigations/<inv_id>/events.jsonl + +# View explanations +cat SPD_OUT_DIR/investigations/<inv_id>/explanations.jsonl | jq . + +# Check SLURM job status +squeue --me +``` diff --git a/spd/investigate/__init__.py b/spd/investigate/__init__.py new file mode 100644 index 000000000..9e666dd7d --- /dev/null +++ b/spd/investigate/__init__.py @@ -0,0 +1,22 @@ +"""Investigation: SLURM-based agent investigation of model behaviors.
+ +This module provides infrastructure for launching a Claude Code agent to investigate +behaviors in an SPD model decomposition. Each investigation: +1. Starts an isolated app backend instance (separate database, unique port) +2. Receives a specific research question and detailed instructions +3. Investigates behaviors and writes findings to append-only JSONL files +""" + +from spd.investigate.schemas import ( + BehaviorExplanation, + ComponentInfo, + Evidence, + InvestigationEvent, +) + +__all__ = [ + "BehaviorExplanation", + "ComponentInfo", + "Evidence", + "InvestigationEvent", +] diff --git a/spd/investigate/agent_prompt.py b/spd/investigate/agent_prompt.py new file mode 100644 index 000000000..d53a47ac3 --- /dev/null +++ b/spd/investigate/agent_prompt.py @@ -0,0 +1,211 @@ +"""System prompt for SPD investigation agents. + +This module contains the detailed instructions given to the investigation agent. +The agent has access to SPD tools via MCP - tools are self-documenting. +""" + +from typing import Any + +AGENT_SYSTEM_PROMPT = """ +# SPD Behavior Investigation Agent + +You are a research agent investigating behaviors in a neural network model decomposition. +A researcher has given you a specific question to investigate. Your job is to answer it +thoroughly using the SPD analysis tools available to you. + +## Your Mission + +{prompt} + +## Available Tools (via MCP) + +You have access to SPD analysis tools. Use them directly - they have full documentation. 
+ +**Circuit Discovery:** +- **optimize_graph**: Find the minimal circuit for a behavior (e.g., "boy" → "he") +- **create_prompt**: Tokenize text and get next-token probabilities + +**Component Analysis:** +- **get_component_info**: Get interpretation and token stats for a component +- **probe_component**: Fast CI probing - test if a component activates on specific text +- **get_component_activation_examples**: See training examples where a component fires +- **get_component_attributions**: Dataset-level component dependencies (sources and targets) +- **get_attribution_strength**: Query attribution strength between two specific components + +**Testing:** +- **run_ablation**: Test a circuit by running with only selected components +- **search_dataset**: Find examples in the training data + +**Metadata:** +- **get_model_info**: Get model architecture details +- **get_stored_graphs**: Retrieve previously computed graphs + +**Output:** +- **update_research_log**: Append to your research log (PRIMARY OUTPUT - use frequently!) +- **save_graph_artifact**: Save a graph for inline visualization in your research log +- **save_explanation**: Save a complete, validated behavior explanation +- **set_investigation_summary**: Set a title and summary for your investigation + +## Investigation Methodology + +### Step 1: Understand the Question + +Read the research question carefully. Think about what behaviors, components, or mechanisms +might be relevant. Use `get_model_info` if you need to understand the model architecture. 
+ +### Step 2: Explore and Hypothesize + +- Use `create_prompt` to test prompts and see what the model predicts +- Use `search_dataset` to find relevant examples in the training data +- Use `probe_component` to quickly test whether specific components respond to your prompts +- Use `get_component_info` to understand what components do + +### Step 3: Find Circuits + +- Use `optimize_graph` to find the minimal circuit for specific behaviors +- Examine which components have high CI values +- Note the circuit size (fewer active components = cleaner mechanism) + +### Step 4: Understand Component Roles + +For each important component in a circuit: +1. Use `get_component_info` for interpretation and token associations +2. Use `probe_component` to test activation on different inputs +3. Use `get_component_activation_examples` to see training examples +4. Use `get_component_attributions` to understand information flow +5. Check correlated components for related functions + +### Step 5: Test with Ablations + +Form hypotheses and test them: +1. Use `run_ablation` with the circuit's components +2. Verify predictions match expectations +3. Try removing individual components to find critical ones + +### Step 6: Document Your Findings + +Use `update_research_log` frequently - this is how humans monitor your work! +When you have a complete explanation, use `save_explanation` to create a structured record. + +## Scientific Principles + +- **Be skeptical**: Your first hypothesis is probably incomplete +- **Triangulate**: Don't rely on a single type of evidence +- **Document uncertainty**: Note what you're confident in vs. uncertain about +- **Consider alternatives**: What else could explain the behavior? + +## Output Format + +### Research Log (PRIMARY OUTPUT - Update frequently!) + +Use `update_research_log` with markdown content. 
Call it every few minutes to show progress: + +Example calls: +``` +update_research_log("## Hypothesis: Gendered Pronoun Circuit\\n\\nTesting prompt: 'The boy said that' → expecting ' he'\\n\\n") + +update_research_log("## Ablation Test\\n\\nResult: P(he) = 0.89 (vs 0.22 baseline)\\n\\nThis confirms the circuit is sufficient!\\n\\n") +``` + +### Including Graph Visualizations + +After running `optimize_graph`, embed the circuit visualization in your research log: + +1. Call `save_graph_artifact` with the graph_id returned by optimize_graph +2. Reference it in your research log using the `spd:graph` code block + +Example: +``` +save_graph_artifact(graph_id=42, caption="Circuit predicting 'he' after 'The boy'") + +update_research_log('''## Circuit Visualization + +```spd:graph +artifact: graph_001 +``` + +This circuit shows the key components involved in predicting "he"... +''') +``` + +### Saving Explanations + +When you have a complete explanation, use `save_explanation`: + +``` +save_explanation( + subject_prompt="The boy said that", + behavior_description="Predicts masculine pronoun 'he' after male subject", + components_involved=[ + {{"component_key": "h.0.mlp.c_fc:407", "role": "Male subject detector"}}, + {{"component_key": "h.3.attn.o_proj:262", "role": "Masculine pronoun promoter"}} + ], + explanation="Component h.0.mlp.c_fc:407 activates on male subjects...", + confidence="medium", + limitations=["Only tested on simple sentences"] +) +``` + +## Getting Started + +1. **Create your research log** with `update_research_log` +2. Understand the research question and plan your approach +3. Use analysis tools to explore the model +4. **Call `update_research_log` frequently** - humans are watching! +5. Use `save_explanation` for complete findings +6. **Call `set_investigation_summary`** with a title and summary when done + +Document what you learn, even if it's "this was more complicated than expected." 
+""" + + +def _format_model_info(model_info: dict[str, Any]) -> str: + """Format model architecture info for inclusion in the agent prompt.""" + parts = [f"- **Architecture**: {model_info.get('summary', 'Unknown')}"] + + target_config = model_info.get("target_model_config") + if target_config: + if "n_layer" in target_config: + parts.append(f"- **Layers**: {target_config['n_layer']}") + if "n_embd" in target_config: + parts.append(f"- **Hidden dim**: {target_config['n_embd']}") + if "vocab_size" in target_config: + parts.append(f"- **Vocab size**: {target_config['vocab_size']}") + if "n_ctx" in target_config: + parts.append(f"- **Context length**: {target_config['n_ctx']}") + + topology = model_info.get("topology") + if topology and topology.get("block_structure"): + block = topology["block_structure"][0] + attn = ", ".join(block.get("attn_projections", [])) + ffn = ", ".join(block.get("ffn_projections", [])) + parts.append(f"- **Attention projections**: {attn}") + parts.append(f"- **FFN projections**: {ffn}") + + return "\n".join(parts) + + +def get_agent_prompt( + wandb_path: str, + prompt: str, + model_info: dict[str, Any], +) -> str: + """Generate the full agent prompt with runtime parameters filled in.""" + formatted_prompt = AGENT_SYSTEM_PROMPT.format(prompt=prompt) + + model_section = f""" +## Model Architecture + +{_format_model_info(model_info)} + +## Runtime Context + +- **Model Run**: {wandb_path} + +Use the MCP tools for ALL output: +- `update_research_log` → **PRIMARY OUTPUT** - Update frequently with your progress! +- `save_explanation` → Save complete, validated behavior explanations + +**Start by calling update_research_log to create your log, then investigate!** +""" + return formatted_prompt + model_section diff --git a/spd/investigate/schemas.py b/spd/investigate/schemas.py new file mode 100644 index 000000000..d4da1a896 --- /dev/null +++ b/spd/investigate/schemas.py @@ -0,0 +1,104 @@ +"""Schemas for investigation outputs. 
+ +All agent outputs are append-only JSONL files. Each line is a JSON object +conforming to one of the schemas defined here. +""" + +from datetime import UTC, datetime +from typing import Any, Literal + +from pydantic import BaseModel, Field + + +class ComponentInfo(BaseModel): + """Information about a component involved in a behavior.""" + + component_key: str = Field( + ..., + description="Component key in format 'layer:component_idx' (e.g., 'h.0.mlp.c_fc:5')", + ) + role: str = Field( + ..., + description="The role this component plays in the behavior (e.g., 'stores subject gender')", + ) + interpretation: str | None = Field( + default=None, + description="Auto-interp label for this component if available", + ) + + +class Evidence(BaseModel): + """A piece of supporting evidence for an explanation.""" + + evidence_type: Literal["ablation", "attribution", "activation_pattern", "correlation", "other"] + description: str = Field( + ..., + description="Description of the evidence", + ) + details: dict[str, Any] = Field( + default_factory=dict, + description="Additional structured details (e.g., ablation results, attribution values)", + ) + + +class BehaviorExplanation(BaseModel): + """A candidate explanation for a behavior discovered by an agent. + + This is the primary output schema for agent investigations. Each explanation + describes a behavior (demonstrated by a subject prompt), the components involved, + and supporting evidence. 
+ """ + + subject_prompt: str = Field( + ..., + description="A prompt that demonstrates the behavior being explained", + ) + behavior_description: str = Field( + ..., + description="Clear description of the behavior (e.g., 'correctly predicts gendered pronoun')", + ) + components_involved: list[ComponentInfo] = Field( + ..., + description="List of components involved in this behavior and their roles", + ) + explanation: str = Field( + ..., + description="Explanation of how the components work together to produce the behavior", + ) + supporting_evidence: list[Evidence] = Field( + default_factory=list, + description="Evidence supporting this explanation (ablations, attributions, etc.)", + ) + confidence: Literal["high", "medium", "low"] = Field( + ..., + description="Agent's confidence in this explanation", + ) + alternative_hypotheses: list[str] = Field( + default_factory=list, + description="Alternative hypotheses that were considered but not fully supported", + ) + limitations: list[str] = Field( + default_factory=list, + description="Known limitations of this explanation", + ) + + +class InvestigationEvent(BaseModel): + """A generic event logged by an agent during investigation. + + Used for logging progress, observations, and other non-explanation events. 
+ """ + + event_type: Literal[ + "start", + "progress", + "observation", + "hypothesis", + "test_result", + "explanation", + "error", + "complete", + ] + timestamp: datetime = Field(default_factory=lambda: datetime.now(UTC)) + message: str + details: dict[str, Any] = Field(default_factory=dict) diff --git a/spd/investigate/scripts/__init__.py b/spd/investigate/scripts/__init__.py new file mode 100644 index 000000000..ff51f7654 --- /dev/null +++ b/spd/investigate/scripts/__init__.py @@ -0,0 +1 @@ +"""Investigation SLURM scripts.""" diff --git a/spd/investigate/scripts/run_agent.py b/spd/investigate/scripts/run_agent.py new file mode 100644 index 000000000..54806ed36 --- /dev/null +++ b/spd/investigate/scripts/run_agent.py @@ -0,0 +1,306 @@ +"""Worker script that runs inside each SLURM job. + +This script: +1. Reads the research question from the investigation metadata +2. Starts the app backend with an isolated database +3. Loads the SPD run and fetches model architecture info +4. Configures MCP server for Claude Code +5. Launches Claude Code with the investigation question +6. 
Handles cleanup on exit +""" + +import json +import os +import signal +import socket +import subprocess +import sys +import time +from pathlib import Path +from types import FrameType +from typing import Any + +import fire +import requests + +from spd.investigate.agent_prompt import get_agent_prompt +from spd.investigate.schemas import InvestigationEvent +from spd.investigate.scripts.run_slurm import get_investigation_output_dir +from spd.log import logger + + +def write_mcp_config(inv_dir: Path, port: int) -> Path: + """Write MCP configuration file for Claude Code.""" + mcp_config = { + "mcpServers": { + "spd": { + "type": "http", + "url": f"http://localhost:{port}/mcp", + } + } + } + config_path = inv_dir / "mcp_config.json" + config_path.write_text(json.dumps(mcp_config, indent=2)) + return config_path + + +def write_claude_settings(inv_dir: Path) -> None: + """Write Claude Code settings to pre-grant MCP tool permissions.""" + claude_dir = inv_dir / ".claude" + claude_dir.mkdir(exist_ok=True) + settings = {"permissions": {"allow": ["mcp__spd__*"]}} + (claude_dir / "settings.json").write_text(json.dumps(settings, indent=2)) + + +def find_available_port(start_port: int = 8000, max_attempts: int = 100) -> int: + """Find an available port starting from start_port.""" + for offset in range(max_attempts): + port = start_port + offset + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + try: + s.bind(("localhost", port)) + return port + except OSError: + continue + raise RuntimeError( + f"Could not find available port in range {start_port}-{start_port + max_attempts}" + ) + + +def wait_for_backend(port: int, timeout: float = 120.0) -> bool: + """Wait for the backend to become healthy.""" + url = f"http://localhost:{port}/api/health" + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + return True + except requests.exceptions.ConnectionError: + pass + time.sleep(1) + return 
False + + +def load_run(port: int, wandb_path: str, context_length: int) -> None: + """Load the SPD run into the backend. Raises on failure.""" + url = f"http://localhost:{port}/api/runs/load" + params = {"wandb_path": wandb_path, "context_length": context_length} + resp = requests.post(url, params=params, timeout=300) + assert resp.status_code == 200, ( + f"Failed to load run {wandb_path}: {resp.status_code} {resp.text}" + ) + + +def fetch_model_info(port: int) -> dict[str, Any]: + """Fetch model architecture info from the backend.""" + resp = requests.get(f"http://localhost:{port}/api/pretrain_info/loaded", timeout=30) + assert resp.status_code == 200, f"Failed to fetch model info: {resp.status_code} {resp.text}" + result: dict[str, Any] = resp.json() + return result + + +def log_event(events_path: Path, event: InvestigationEvent) -> None: + """Append an event to the events log.""" + with open(events_path, "a") as f: + f.write(event.model_dump_json() + "\n") + + +def run_agent( + wandb_path: str, + inv_id: str, + context_length: int = 128, + max_turns: int = 50, +) -> None: + """Run a single investigation agent. + + Args: + wandb_path: WandB path of the SPD run. + inv_id: Unique identifier for this investigation. + context_length: Context length for prompts. + max_turns: Maximum agentic turns before stopping (prevents runaway agents). 
+ """ + inv_dir = get_investigation_output_dir(inv_id) + assert inv_dir.exists(), f"Investigation directory does not exist: {inv_dir}" + + # Read prompt from metadata + metadata: dict[str, Any] = json.loads((inv_dir / "metadata.json").read_text()) + prompt = metadata["prompt"] + + write_claude_settings(inv_dir) + + events_path = inv_dir / "events.jsonl" + (inv_dir / "explanations.jsonl").touch() + + log_event( + events_path, + InvestigationEvent( + event_type="start", + message=f"Investigation {inv_id} starting", + details={"wandb_path": wandb_path, "inv_id": inv_id, "prompt": prompt}, + ), + ) + + port = find_available_port() + logger.info(f"[{inv_id}] Using port {port}") + + log_event( + events_path, + InvestigationEvent( + event_type="progress", + message=f"Starting backend on port {port}", + details={"port": port}, + ), + ) + + # Start backend with investigation configuration + env = os.environ.copy() + env["SPD_INVESTIGATION_DIR"] = str(inv_dir) + + backend_cmd = [ + sys.executable, + "-m", + "spd.app.backend.server", + "--port", + str(port), + ] + + backend_log_path = inv_dir / "backend.log" + backend_log = open(backend_log_path, "w") # noqa: SIM115 - managed manually + backend_proc = subprocess.Popen( + backend_cmd, + env=env, + stdout=backend_log, + stderr=subprocess.STDOUT, + ) + + def cleanup(signum: int | None = None, frame: FrameType | None = None) -> None: + _ = frame + logger.info(f"[{inv_id}] Cleaning up...") + if backend_proc.poll() is None: + backend_proc.terminate() + try: + backend_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + backend_proc.kill() + backend_log.close() + if signum is not None: + sys.exit(1) + + signal.signal(signal.SIGTERM, cleanup) + signal.signal(signal.SIGINT, cleanup) + + try: + logger.info(f"[{inv_id}] Waiting for backend...") + if not wait_for_backend(port): + log_event( + events_path, + InvestigationEvent(event_type="error", message="Backend failed to start"), + ) + raise RuntimeError("Backend failed to start") 
+ + logger.info(f"[{inv_id}] Backend ready, loading run...") + log_event( + events_path, + InvestigationEvent(event_type="progress", message="Backend ready, loading run"), + ) + + load_run(port, wandb_path, context_length) + + logger.info(f"[{inv_id}] Run loaded, fetching model info...") + model_info = fetch_model_info(port) + + logger.info(f"[{inv_id}] Launching Claude Code...") + log_event( + events_path, + InvestigationEvent( + event_type="progress", message="Run loaded, launching Claude Code agent" + ), + ) + + agent_prompt = get_agent_prompt( + wandb_path=wandb_path, + prompt=prompt, + model_info=model_info, + ) + + (inv_dir / "agent_prompt.md").write_text(agent_prompt) + + mcp_config_path = write_mcp_config(inv_dir, port) + logger.info(f"[{inv_id}] MCP config written to {mcp_config_path}") + + claude_output_path = inv_dir / "claude_output.jsonl" + claude_cmd = [ + "claude", + "--print", + "--verbose", + "--output-format", + "stream-json", + "--max-turns", + str(max_turns), + # MCP: only our backend, no inherited servers + "--mcp-config", + str(mcp_config_path), + # Permissions: only MCP tools, deny everything else + "--permission-mode", + "dontAsk", + "--allowedTools", + "mcp__spd__*", + # Isolation: skip all user/project settings (no plugins, no inherited config) + "--setting-sources", + "", + "--model", + "opus", + ] + + logger.info(f"[{inv_id}] Starting Claude Code (max_turns={max_turns})...") + logger.info(f"[{inv_id}] Monitor with: tail -f {claude_output_path}") + + with open(claude_output_path, "w") as output_file: + claude_proc = subprocess.Popen( + claude_cmd, + stdin=subprocess.PIPE, + stdout=output_file, + stderr=subprocess.STDOUT, + text=True, + cwd=str(inv_dir), + ) + + assert claude_proc.stdin is not None + claude_proc.stdin.write(agent_prompt) + claude_proc.stdin.close() + + claude_proc.wait() + + log_event( + events_path, + InvestigationEvent( + event_type="complete", + message="Investigation complete", + details={"exit_code": 
claude_proc.returncode}, + ), + ) + + logger.info(f"[{inv_id}] Investigation complete") + + except Exception as e: + log_event( + events_path, + InvestigationEvent( + event_type="error", + message=f"Agent failed: {e}", + details={"error_type": type(e).__name__}, + ), + ) + logger.error(f"[{inv_id}] Failed: {e}") + raise + finally: + cleanup() + + +def cli() -> None: + fire.Fire(run_agent) + + +if __name__ == "__main__": + cli() diff --git a/spd/investigate/scripts/run_slurm.py b/spd/investigate/scripts/run_slurm.py new file mode 100644 index 000000000..703ed2f78 --- /dev/null +++ b/spd/investigate/scripts/run_slurm.py @@ -0,0 +1,95 @@ +"""SLURM submission logic for investigation jobs.""" + +import json +import secrets +import sys +from dataclasses import dataclass +from pathlib import Path + +from spd.log import logger +from spd.settings import SPD_OUT_DIR +from spd.utils.git_utils import create_git_snapshot +from spd.utils.slurm import SlurmConfig, generate_script, submit_slurm_job +from spd.utils.wandb_utils import parse_wandb_run_path + + +@dataclass +class InvestigationResult: + inv_id: str + job_id: str + output_dir: Path + + +def get_investigation_output_dir(inv_id: str) -> Path: + return SPD_OUT_DIR / "investigations" / inv_id + + +def launch_investigation( + wandb_path: str, + prompt: str, + context_length: int, + max_turns: int, + partition: str, + time: str, + job_suffix: str | None, +) -> InvestigationResult: + """Launch a single investigation agent via SLURM. + + Creates a SLURM job that starts an isolated app backend, loads the SPD run, + and launches a Claude Code agent with the given research question. 
+ """ + # Normalize wandb_path to canonical form (entity/project/run_id) + entity, project, run_id = parse_wandb_run_path(wandb_path) + canonical_wandb_path = f"{entity}/{project}/{run_id}" + + inv_id = f"inv-{secrets.token_hex(4)}" + output_dir = get_investigation_output_dir(inv_id) + output_dir.mkdir(parents=True, exist_ok=True) + + snapshot_branch, commit_hash = create_git_snapshot(inv_id) + + suffix = f"-{job_suffix}" if job_suffix else "" + job_name = f"spd-investigate{suffix}" + + metadata = { + "inv_id": inv_id, + "wandb_path": canonical_wandb_path, + "prompt": prompt, + "context_length": context_length, + "max_turns": max_turns, + "snapshot_branch": snapshot_branch, + "commit_hash": commit_hash, + } + (output_dir / "metadata.json").write_text(json.dumps(metadata, indent=2)) + + cmd = ( + f"{sys.executable} -m spd.investigate.scripts.run_agent " + f'"{wandb_path}" ' + f"--inv_id {inv_id} " + f"--context_length {context_length} " + f"--max_turns {max_turns}" + ) + + slurm_config = SlurmConfig( + job_name=job_name, + partition=partition, + n_gpus=1, + time=time, + snapshot_branch=snapshot_branch, + ) + script = generate_script(slurm_config, cmd) + result = submit_slurm_job(script, "investigate") + + logger.section("Investigation submitted") + logger.values( + { + "Investigation ID": inv_id, + "Job ID": result.job_id, + "WandB path": canonical_wandb_path, + "Prompt": prompt[:100] + ("..." if len(prompt) > 100 else ""), + "Output directory": str(output_dir), + "Logs": result.log_pattern, + } + ) + + return InvestigationResult(inv_id=inv_id, job_id=result.job_id, output_dir=output_dir) diff --git a/spd/investigate/scripts/run_slurm_cli.py b/spd/investigate/scripts/run_slurm_cli.py new file mode 100644 index 000000000..df784de61 --- /dev/null +++ b/spd/investigate/scripts/run_slurm_cli.py @@ -0,0 +1,59 @@ +"""CLI entry point for investigation SLURM launcher. 
+ +Usage: + spd-investigate <wandb_path> "<prompt>" + spd-investigate <wandb_path> @prompt.txt + spd-investigate <wandb_path> "<prompt>" --max_turns 30 +""" + +from pathlib import Path + +import fire + +from spd.settings import DEFAULT_PARTITION_NAME + + +def _resolve_prompt(prompt: str) -> str: + """If prompt starts with @, read from that file path. Otherwise return as-is.""" + if prompt.startswith("@"): + path = Path(prompt[1:]) + assert path.exists(), f"Prompt file not found: {path}" + return path.read_text().strip() + return prompt + + +def main( + wandb_path: str, + prompt: str, + context_length: int = 128, + max_turns: int = 50, + partition: str = DEFAULT_PARTITION_NAME, + time: str = "8:00:00", + job_suffix: str | None = None, +) -> None: + """Launch a single investigation agent for a specific question. + + Args: + wandb_path: WandB run path for the SPD decomposition to investigate. + prompt: The research question, or @filepath to read from a file. + context_length: Context length for prompts (default 128). + max_turns: Maximum agentic turns (default 50, prevents runaway). + partition: SLURM partition name. + time: Job time limit (default 8 hours). + job_suffix: Optional suffix for SLURM job names.
+ """ + from spd.investigate.scripts.run_slurm import launch_investigation + + launch_investigation( + wandb_path=wandb_path, + prompt=_resolve_prompt(prompt), + context_length=context_length, + max_turns=max_turns, + partition=partition, + time=time, + job_suffix=job_suffix, + ) + + +def cli() -> None: + fire.Fire(main) diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index fb708ceb5..6d546c86a 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -30,11 +30,11 @@ def _init_adv_sources( match pgd_config.mask_scope: case "unique_per_datapoint": shape = torch.Size([*batch_dims, mask_c]) - source = _get_pgd_init_tensor(pgd_config.init, shape, device) + source = get_pgd_init_tensor(pgd_config.init, shape, device) case "shared_across_batch": singleton_batch_dims = [1 for _ in batch_dims] shape = torch.Size([*singleton_batch_dims, mask_c]) - source = broadcast_tensor(_get_pgd_init_tensor(pgd_config.init, shape, device)) + source = broadcast_tensor(get_pgd_init_tensor(pgd_config.init, shape, device)) adv_sources[module_name] = source.requires_grad_(True) return adv_sources @@ -86,7 +86,7 @@ def _construct_mask_infos_from_adv_sources( adv_sources_components = {k: v[..., :-1] for k, v in expanded_adv_sources.items()} return make_mask_infos( - component_masks=_interpolate_component_mask(ci, adv_sources_components), + component_masks=interpolate_pgd_mask(ci, adv_sources_components), weight_deltas_and_masks=weight_deltas_and_masks, routing_masks=routing_masks, ) @@ -171,7 +171,7 @@ def calc_multibatch_pgd_masked_recon_loss( mask_c = module_c if not use_delta_component else module_c + 1 shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = broadcast_tensor( - _get_pgd_init_tensor(pgd_config.init, shape, device) + get_pgd_init_tensor(pgd_config.init, shape, device) ).requires_grad_(True) fwd_bwd_fn = partial( @@ -301,11 +301,12 @@ def _multibatch_pgd_fwd_bwd( return pgd_step_accum_sum_loss, pgd_step_accum_n_examples, 
pgd_step_accum_grads -def _get_pgd_init_tensor( +def get_pgd_init_tensor( init: PGDInitStrategy, shape: tuple[int, ...], device: torch.device | str, ) -> Float[Tensor, "... shape"]: + """Create initial PGD source tensor (random, ones, or zeroes). Shared by training PGD and app eval PGD.""" match init: case "random": return torch.rand(shape, device=device) @@ -315,7 +316,7 @@ def _get_pgd_init_tensor( return torch.zeros(shape, device=device) -def _interpolate_component_mask( +def interpolate_pgd_mask( ci: dict[str, Float[Tensor, "*batch_dims C"]], adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]], ) -> dict[str, Float[Tensor, "*batch_dims C"]]: diff --git a/spd/persistent_pgd.py b/spd/persistent_pgd.py index 1200829f1..fc7a77be1 100644 --- a/spd/persistent_pgd.py +++ b/spd/persistent_pgd.py @@ -138,6 +138,7 @@ def __init__( self._use_sigmoid_parameterization = cfg.use_sigmoid_parameterization self._router = _get_router_for_ppgd_config(cfg, device) self._n_warmup_steps = cfg.n_warmup_steps + self._n_samples = cfg.n_samples self._output_loss_type: Literal["mse", "kl"] = output_loss_type self._lr_schedule = cfg.optimizer.lr_schedule @@ -219,8 +220,11 @@ def warmup( Each step computes the recon loss, extracts gradients, and updates sources in-place. When n_warmup_steps=0 (default), this is a no-op. """ + all_layers = AllLayersRouter() for _ in range(self._n_warmup_steps): - loss = self.compute_recon_loss(model, batch, target_out, ci, weight_deltas) + loss = self.compute_recon_loss( + model, batch, target_out, ci, weight_deltas, router=all_layers + ) grads = self.get_grads(loss, retain_graph=False) self.step(grads) @@ -231,23 +235,32 @@ def compute_recon_loss( target_out: Float[Tensor, "... vocab"], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + router: Router | None = None, ) -> Float[Tensor, ""]: """Pure forward pass that returns the PPGD reconstruction loss. 
No source mutation.""" batch_dims = next(iter(ci.values())).shape[:-1] - routing_masks = self._router.get_masks( - module_names=model.target_module_paths, mask_shape=batch_dims - ) + router = router or self._router ppgd_sources = self.get_effective_sources() - sum_loss, n_examples = _compute_ppgd_recon_loss( - model=model, - ppgd_sources=ppgd_sources, - output_loss_type=self._output_loss_type, - batch=batch, - target_out=target_out, - ci=ci, - weight_deltas=weight_deltas, - routing_masks=routing_masks, - ) + + device = next(iter(ci.values())).device + sum_loss = torch.tensor(0.0, device=device) + n_examples = 0 + for _ in range(self._n_samples): + routing_masks = router.get_masks( + module_names=model.target_module_paths, mask_shape=batch_dims + ) + loss, n = _compute_ppgd_recon_loss( + model=model, + ppgd_sources=ppgd_sources, + output_loss_type=self._output_loss_type, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=weight_deltas, + routing_masks=routing_masks, + ) + sum_loss = sum_loss + loss + n_examples += n return sum_loss / n_examples diff --git a/spd/postprocess/__init__.py b/spd/postprocess/__init__.py index e2feab509..35922436b 100644 --- a/spd/postprocess/__init__.py +++ b/spd/postprocess/__init__.py @@ -20,6 +20,7 @@ from spd.autointerp.scripts.run_slurm import AutointerpSubmitResult, submit_autointerp from spd.dataset_attributions.scripts.run_slurm import submit_attributions +from spd.graph_interp.scripts.run_slurm import GraphInterpSubmitResult, submit_graph_interp from spd.harvest.config import SPDHarvestConfig from spd.harvest.scripts import run_intruder from spd.harvest.scripts.run_slurm import submit_harvest @@ -30,9 +31,14 @@ from spd.utils.slurm import SlurmConfig, SubmitResult, generate_script, submit_slurm_job -def postprocess(config: PostprocessConfig) -> Path: +def postprocess(config: PostprocessConfig, dependency_job_id: str | None = None) -> Path: """Submit all postprocessing jobs with SLURM dependency chaining. 
+ Args: + config: Postprocessing configuration. + dependency_job_id: SLURM job to wait for before starting harvest + (e.g. a training job that must complete first). + Returns: Path to the manifest YAML file. """ @@ -43,7 +49,11 @@ def postprocess(config: PostprocessConfig) -> Path: decomp_cfg = config.harvest.config.method_config # === 1. Harvest (always runs, upserts into harvest.db) === - harvest_result = submit_harvest(config.harvest, snapshot_branch=snapshot_branch) + harvest_result = submit_harvest( + config.harvest, + snapshot_branch=snapshot_branch, + dependency_job_id=dependency_job_id, + ) # === 2. Autointerp (depends on harvest, resumes via completed keys) === autointerp_result: AutointerpSubmitResult | None = None @@ -97,6 +107,21 @@ def postprocess(config: PostprocessConfig) -> Path: harvest_subrun_id=harvest_result.subrun_id, ) + # === 5. Graph interp (depends on harvest merge + attribution merge) === + graph_interp_result: GraphInterpSubmitResult | None = None + if config.graph_interp is not None: + assert attr_result is not None + graph_interp_result = submit_graph_interp( + decomposition_id=decomp_cfg.id, + config=config.graph_interp, + dependency_job_ids=[ + harvest_result.merge_result.job_id, + attr_result.merge_result.job_id, + ], + snapshot_branch=snapshot_branch, + harvest_subrun_id=harvest_result.subrun_id, + ) + # === Write manifest === manifest_id = "pp-" + datetime.now().strftime("%Y%m%d_%H%M%S") manifest_dir = SPD_OUT_DIR / "postprocess" / manifest_id @@ -120,6 +145,8 @@ def postprocess(config: PostprocessConfig) -> Path: jobs["detection"] = autointerp_result.detection_result.job_id if autointerp_result.fuzzing_result is not None: jobs["fuzzing"] = autointerp_result.fuzzing_result.job_id + if graph_interp_result is not None: + jobs["graph_interp"] = graph_interp_result.result.job_id manifest = { "timestamp": datetime.now().isoformat(timespec="seconds"), diff --git a/spd/postprocess/cli.py b/spd/postprocess/cli.py index 7a3c7db1d..7bd4672b0 
100644 --- a/spd/postprocess/cli.py +++ b/spd/postprocess/cli.py @@ -3,35 +3,39 @@ Thin wrapper for fast --help. Heavy imports deferred to postprocess.py. Usage: - spd-postprocess - spd-postprocess --config my_config.yaml + spd-postprocess config.yaml + spd-postprocess config.yaml --dependency 311644_1 """ -import fire +import argparse -def main(config: str, dry_run: bool = False) -> None: - """Submit all postprocessing jobs for an SPD run. +def main() -> None: + parser = argparse.ArgumentParser(description="Submit all postprocessing jobs for an SPD run.") + parser.add_argument("config", help="Path to PostprocessConfig YAML.") + parser.add_argument( + "--dependency", + help="SLURM job ID to wait for before starting (e.g. a training job).", + ) + parser.add_argument("--dry_run", action="store_true") + args = parser.parse_args() - Args: - config: Path to PostprocessConfig YAML. - """ import yaml from spd.log import logger from spd.postprocess import postprocess from spd.postprocess.config import PostprocessConfig - cfg = PostprocessConfig.from_file(config) + cfg = PostprocessConfig.from_file(args.config) - if dry_run: + if args.dry_run: logger.info("Dry run: skipping submission\n\nConfig:\n") logger.info(yaml.dump(cfg.model_dump(), indent=2, sort_keys=False)) return - manifest_path = postprocess(config=cfg) + manifest_path = postprocess(config=cfg, dependency_job_id=args.dependency) logger.info(f"Manifest: {manifest_path}") def cli() -> None: - fire.Fire(main) + main() diff --git a/spd/postprocess/config.py b/spd/postprocess/config.py index 2164ab97c..858df94bb 100644 --- a/spd/postprocess/config.py +++ b/spd/postprocess/config.py @@ -9,6 +9,7 @@ from spd.autointerp.config import AutointerpSlurmConfig from spd.base_config import BaseConfig from spd.dataset_attributions.config import AttributionsSlurmConfig +from spd.graph_interp.config import GraphInterpSlurmConfig from spd.harvest.config import HarvestSlurmConfig, IntruderSlurmConfig, SPDHarvestConfig @@ -32,6 +33,7 
@@ class PostprocessConfig(BaseConfig): autointerp: AutointerpSlurmConfig | None intruder: IntruderSlurmConfig | None attributions: AttributionsSlurmConfig | None + graph_interp: GraphInterpSlurmConfig | None @override def model_post_init(self, __context: Any) -> None: @@ -39,3 +41,12 @@ def model_post_init(self, __context: Any) -> None: is_not_spd = not isinstance(self.harvest.config.method_config, SPDHarvestConfig) if expects_attributions and is_not_spd: raise ValueError("Attributions only work for SPD decompositions") + if self.graph_interp is not None and self.attributions is None: + raise ValueError("Graph interp requires attributions") + + +if __name__ == "__main__": + import json + + with open("spd/postprocess/postprocess.schema.json", "w") as f: + json.dump(PostprocessConfig.model_json_schema(), f, indent=2) diff --git a/spd/postprocess/s-55ea3f9b.yaml b/spd/postprocess/s-55ea3f9b.yaml new file mode 100644 index 000000000..1504a06d0 --- /dev/null +++ b/spd/postprocess/s-55ea3f9b.yaml @@ -0,0 +1,32 @@ +harvest: + n_gpus: 16 + time: "24:00:00" + merge_time: "24:00:00" + config: + method_config: + type: SPDHarvestConfig + wandb_path: "wandb:goodfire/spd/s-55ea3f9b" + +autointerp: + time: "24:00:00" + config: + template_strategy: + type: compact_skeptical + forbidden_words: [] + cost_limit_usd: 400 + evals: null + +intruder: null + +attributions: + n_gpus: 16 + time: "24:00:00" + merge_time: "24:00:00" + config: + spd_run_wandb_path: "wandb:goodfire/spd/s-55ea3f9b" + n_batches: 640 + +graph_interp: + time: "24:00:00" + config: + cost_limit_usd: 400 diff --git a/spd/postprocess/s-82ffb969.yaml b/spd/postprocess/s-82ffb969.yaml new file mode 100644 index 000000000..b48bc4add --- /dev/null +++ b/spd/postprocess/s-82ffb969.yaml @@ -0,0 +1,71 @@ +# yaml-language-server: $schema=postprocess.schema.json + +harvest: + n_gpus: 32 + time: "24:00:00" + merge_time: "24:00:00" + config: + method_config: + type: SPDHarvestConfig + wandb_path: "wandb:goodfire/spd/s-82ffb969" 
+ n_batches: 20000 + batch_size: 32 + activation_examples_per_component: 400 + activation_context_tokens_per_side: 20 + pmi_token_top_k: 40 + max_examples_per_batch_per_component: 5 + +autointerp: + time: "24:00:00" + config: + model: "google/gemini-3-flash-preview" + reasoning_effort: "low" + # model: "google/gemini-3.1-pro-preview" + # reasoning_effort: "low" + template_strategy: + type: dual_view + cost_limit_usd: 2000 + evals: + model: "google/gemini-3-flash-preview" + reasoning_effort: "low" + detection_config: + type: detection + n_activating: 5 + n_non_activating: 5 + n_trials: 5 + fuzzing_config: + type: fuzzing + n_correct: 5 + n_incorrect: 2 + n_trials: 5 + cost_limit_usd: 1000 + + +intruder: + config: + model: google/gemini-3-flash-preview + reasoning_effort: "none" + n_real: 4 + n_trials: 10 + density_tolerance: 0.05 + max_concurrent: 50 + max_requests_per_minute: 200 + cost_limit_usd: 1000 + +attributions: + n_gpus: 32 + time: "24:00:00" + merge_time: "24:00:00" + config: + spd_run_wandb_path: "wandb:goodfire/spd/s-82ffb969" + n_batches: 1280 + +graph_interp: + time: "24:00:00" + config: + model: "google/gemini-3-flash-preview" + reasoning_effort: "low" + # model: "google/gemini-3.1-pro-preview" + # reasoning_effort: "low" + cost_limit_usd: 2000 + diff --git a/spd/settings.py b/spd/settings.py index 9e3e37f7b..56d60ecfe 100644 --- a/spd/settings.py +++ b/spd/settings.py @@ -24,3 +24,5 @@ DEFAULT_PARTITION_NAME = "h200-reserved" DEFAULT_PROJECT_NAME = "spd" + +SPD_APP_DEFAULT_RUN: str | None = os.environ.get("SPD_APP_DEFAULT_RUN") diff --git a/spd/topology/gradient_connectivity.py b/spd/topology/gradient_connectivity.py index bcaac8423..31ba61b5a 100644 --- a/spd/topology/gradient_connectivity.py +++ b/spd/topology/gradient_connectivity.py @@ -74,19 +74,20 @@ def embed_hook( cache[f"{embed_path}_post_detach"] = embed_cache[f"{embed_path}_post_detach"] cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output - layers = [embed_path, 
*model.target_module_paths, unembed_path] + source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source + target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target # Test all distinct pairs for gradient flow test_pairs = [] - for in_layer in layers[:-1]: # Don't include "output" as source - for out_layer in layers[1:]: # Don't include embed as target - if in_layer != out_layer: - test_pairs.append((in_layer, out_layer)) + for source_layer in source_layers: + for target_layer in target_layers: + if source_layer != target_layer: + test_pairs.append((source_layer, target_layer)) sources_by_target: dict[str, list[str]] = defaultdict(list) - for in_layer, out_layer in test_pairs: - out_pre_detach = cache[f"{out_layer}_pre_detach"] - in_post_detach = cache[f"{in_layer}_post_detach"] + for source_layer, target_layer in test_pairs: + out_pre_detach = cache[f"{target_layer}_pre_detach"] + in_post_detach = cache[f"{source_layer}_post_detach"] out_value = out_pre_detach[0, 0, 0] grads = torch.autograd.grad( outputs=out_value, @@ -97,5 +98,5 @@ def embed_hook( assert len(grads) == 1 grad = grads[0] if grad is not None: # pyright: ignore[reportUnnecessaryComparison] - sources_by_target[out_layer].append(in_layer) + sources_by_target[target_layer].append(source_layer) return dict(sources_by_target) diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index a8e878407..4d71807d6 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -31,7 +31,11 @@ # Regex patterns for parsing W&B run references # Run IDs can be 8 chars (e.g., "d2ec3bfe") or prefixed with char-dash (e.g., "s-d2ec3bfe") +DEFAULT_WANDB_ENTITY = "goodfire" +DEFAULT_WANDB_PROJECT = "spd" + _RUN_ID_PATTERN = r"(?:[a-z0-9]-)?[a-z0-9]{8}" +_BARE_RUN_ID_RE = re.compile(r"^(s-[a-z0-9]{8})$") _WANDB_PATH_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/({_RUN_ID_PATTERN})$") _WANDB_PATH_WITH_RUNS_RE = 
re.compile(rf"^([^/\s]+)/([^/\s]+)/runs/({_RUN_ID_PATTERN})$") _WANDB_URL_RE = re.compile( @@ -174,6 +178,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: """Parse various W&B run reference formats into (entity, project, run_id). Accepts: + - "s-xxxxxxxx" (bare SPD run ID, assumes goodfire/spd) - "entity/project/runId" (compact form) - "entity/project/runs/runId" (with /runs/) - "wandb:entity/project/runId" (with wandb: prefix) @@ -192,6 +197,10 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: if s.startswith("wandb:"): s = s[6:] + # Bare run ID (e.g. "s-17805b61") → default entity/project + if m := _BARE_RUN_ID_RE.match(s): + return DEFAULT_WANDB_ENTITY, DEFAULT_WANDB_PROJECT, m.group(1) + # Try compact form: entity/project/runid if m := _WANDB_PATH_RE.match(s): return m.group(1), m.group(2), m.group(3) @@ -206,6 +215,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: raise ValueError( f"Invalid W&B run reference. Expected one of:\n" + f' - "s-xxxxxxxx" (bare run ID)\n' f' - "entity/project/xxxxxxxx"\n' f' - "entity/project/runs/xxxxxxxx"\n' f' - "wandb:entity/project/runs/xxxxxxxx"\n' diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index f39ef385f..d2cd74057 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -16,6 +16,7 @@ from spd.app.backend.app_tokenizer import AppTokenizer from spd.app.backend.database import PromptAttrDB from spd.app.backend.routers import graphs as graphs_router +from spd.app.backend.routers import intervention as intervention_router from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import RunState, StateManager @@ -54,6 +55,7 @@ def app_with_state(): # Patch DEVICE in all router modules to use CPU for tests with ( mock.patch.object(graphs_router, "DEVICE", DEVICE), + mock.patch.object(intervention_router, "DEVICE", DEVICE), 
mock.patch.object(runs_router, "DEVICE", DEVICE), ): db = PromptAttrDB(db_path=Path(":memory:"), check_same_thread=False) @@ -147,6 +149,7 @@ def app_with_state(): harvest=None, interp=None, attributions=None, + graph_interp=None, ) manager = StateManager.get() @@ -231,6 +234,49 @@ def test_compute_graph(app_with_prompt: tuple[TestClient, int]): assert "outputProbs" in data +def test_run_and_save_intervention_without_text(app_with_prompt: tuple[TestClient, int]): + """Run-and-save intervention should use graph-linked prompt tokens (no text in request).""" + client, prompt_id = app_with_prompt + + graph_response = client.post( + "/api/graphs", + params={"prompt_id": prompt_id, "normalize": "none", "ci_threshold": 0.0}, + ) + assert graph_response.status_code == 200 + events = [line for line in graph_response.text.strip().split("\n") if line.startswith("data:")] + final_data = json.loads(events[-1].replace("data: ", "")) + graph_data = final_data["data"] + graph_id = graph_data["id"] + + selected_nodes = [ + key + for key, ci in graph_data["nodeCiVals"].items() + if not key.startswith("embed:") and not key.startswith("output:") and ci > 0 + ] + assert len(selected_nodes) > 0 + + request = { + "graph_id": graph_id, + "selected_nodes": selected_nodes[: min(5, len(selected_nodes))], + "top_k": 5, + "adv_pgd": {"n_steps": 1, "step_size": 1.0}, + } + response = client.post("/api/intervention/run", json=request) + assert response.status_code == 200 + body = response.json() + assert body["selected_nodes"] == request["selected_nodes"] + result = body["result"] + assert len(result["input_tokens"]) > 0 + assert len(result["ci"]) > 0 + assert len(result["stochastic"]) > 0 + assert len(result["adversarial"]) > 0 + assert result["target_sans"] is None + assert "ci_loss" in result + assert "stochastic_loss" in result + assert "adversarial_loss" in result + assert result["target_sans_loss"] is None + + # ----------------------------------------------------------------------------- 
# Streaming: Prompt Generation # ----------------------------------------------------------------------------- diff --git a/tests/dataset_attributions/test_harvester.py b/tests/dataset_attributions/test_harvester.py deleted file mode 100644 index 96ebc5df8..000000000 --- a/tests/dataset_attributions/test_harvester.py +++ /dev/null @@ -1,265 +0,0 @@ -"""Tests for dataset attribution harvester logic.""" - -from pathlib import Path - -import torch - -from spd.dataset_attributions.storage import DatasetAttributionStorage - - -def _make_storage( - n_components: int = 2, - vocab_size: int = 3, - d_model: int = 4, - source_to_component: torch.Tensor | None = None, - source_to_out_residual: torch.Tensor | None = None, -) -> DatasetAttributionStorage: - """Helper to create storage with default values.""" - n_sources = vocab_size + n_components - if source_to_component is None: - source_to_component = torch.zeros(n_sources, n_components) - if source_to_out_residual is None: - source_to_out_residual = torch.zeros(n_sources, d_model) - - return DatasetAttributionStorage( - component_layer_keys=[f"layer1:{i}" for i in range(n_components)], - vocab_size=vocab_size, - d_model=d_model, - source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - n_batches_processed=10, - n_tokens_processed=1000, - ci_threshold=0.0, - ) - - -class TestDatasetAttributionStorage: - """Tests for DatasetAttributionStorage. 
- - Storage structure: - - source_to_component: (n_sources, n_components) for component target attributions - - source_to_out_residual: (n_sources, d_model) for output target attributions (via w_unembed) - """ - - def test_has_source_and_target(self) -> None: - """Test has_source and has_target methods.""" - storage = _make_storage(n_components=2, vocab_size=3) - - # wte tokens can only be sources - assert storage.has_source("wte:0") - assert storage.has_source("wte:2") - assert not storage.has_source("wte:3") # Out of vocab - assert not storage.has_target("wte:0") # wte can't be target - - # Component layers can be both sources and targets - assert storage.has_source("layer1:0") - assert storage.has_source("layer1:1") - assert storage.has_target("layer1:0") - assert storage.has_target("layer1:1") - assert not storage.has_source("layer1:2") - assert not storage.has_target("layer1:2") - - # output tokens can only be targets - assert storage.has_target("output:0") - assert storage.has_target("output:2") - assert not storage.has_target("output:3") # Out of vocab - assert not storage.has_source("output:0") # output can't be source - - def test_get_attribution_component_target(self) -> None: - """Test get_attribution for component targets (no w_unembed needed).""" - # 2 component layers: layer1:0, layer1:1 - # vocab_size=2, d_model=4 - # n_sources = 2 + 2 = 4 - # source_to_component shape: (4, 2) - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 -> components - [3.0, 4.0], # wte:1 -> components - [5.0, 6.0], # layer1:0 -> components - [7.0, 8.0], # layer1:1 -> components - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # wte:0 -> layer1:0 - assert storage.get_attribution("wte:0", "layer1:0") == 1.0 - # wte:1 -> layer1:1 - assert storage.get_attribution("wte:1", "layer1:1") == 4.0 - # layer1:0 -> layer1:1 - assert storage.get_attribution("layer1:0", "layer1:1") == 6.0 - - def 
test_get_attribution_output_target(self) -> None: - """Test get_attribution for output targets (requires w_unembed).""" - # source_to_out_residual shape: (4, 4) for n_sources=4, d_model=4 - source_to_out_residual = torch.tensor( - [ - [1.0, 0.0, 0.0, 0.0], # wte:0 -> out_residual - [0.0, 1.0, 0.0, 0.0], # wte:1 -> out_residual - [0.0, 0.0, 1.0, 0.0], # layer1:0 -> out_residual - [0.0, 0.0, 0.0, 1.0], # layer1:1 -> out_residual - ] - ) - # w_unembed shape: (d_model=4, vocab=2) - w_unembed = torch.tensor( - [ - [1.0, 2.0], # d0 -> outputs - [3.0, 4.0], # d1 -> outputs - [5.0, 6.0], # d2 -> outputs - [7.0, 8.0], # d3 -> outputs - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, d_model=4, source_to_out_residual=source_to_out_residual - ) - - # wte:0 -> output:0 = out_residual[0] @ w_unembed[:, 0] = [1,0,0,0] @ [1,3,5,7] = 1.0 - assert storage.get_attribution("wte:0", "output:0", w_unembed=w_unembed) == 1.0 - # wte:1 -> output:1 = [0,1,0,0] @ [2,4,6,8] = 4.0 - assert storage.get_attribution("wte:1", "output:1", w_unembed=w_unembed) == 4.0 - # layer1:0 -> output:0 = [0,0,1,0] @ [1,3,5,7] = 5.0 - assert storage.get_attribution("layer1:0", "output:0", w_unembed=w_unembed) == 5.0 - - def test_get_top_sources_component_target(self) -> None: - """Test get_top_sources for component targets.""" - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 - [5.0, 3.0], # wte:1 - [2.0, 4.0], # layer1:0 - [3.0, 1.0], # layer1:1 - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # Top sources TO layer1:0 (column 0): wte:0=1.0, wte:1=5.0, layer1:0=2.0, layer1:1=3.0 - sources = storage.get_top_sources("layer1:0", k=2, sign="positive") - assert len(sources) == 2 - assert sources[0].component_key == "wte:1" - assert sources[0].value == 5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == 3.0 - - def test_get_top_sources_negative(self) -> None: - """Test get_top_sources with 
negative sign.""" - source_to_component = torch.tensor( - [ - [-1.0, 2.0], - [-5.0, 3.0], - [-2.0, 4.0], - [-3.0, 1.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - sources = storage.get_top_sources("layer1:0", k=2, sign="negative") - assert len(sources) == 2 - # wte:1 has most negative (-5.0), then layer1:1 (-3.0) - assert sources[0].component_key == "wte:1" - assert sources[0].value == -5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == -3.0 - - def test_get_top_component_targets(self) -> None: - """Test get_top_component_targets (no w_unembed needed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - targets = storage.get_top_component_targets("layer1:0", k=2, sign="positive") - assert len(targets) == 2 - assert targets[0].component_key == "layer1:1" - assert targets[0].value == 4.0 - assert targets[1].component_key == "layer1:0" - assert targets[1].value == 2.0 - - def test_get_top_targets_with_outputs(self) -> None: - """Test get_top_targets including outputs (requires w_unembed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - # Make out_residual attribution that produces high output values - source_to_out_residual = torch.tensor( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0], # layer1:0 -> out_residual (sum=4 per output) - [0.0, 0.0, 0.0, 0.0], - ] - ) - # w_unembed that gives output:0=10, output:1=5 - w_unembed = torch.tensor( - [ - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - ] - ) - storage = _make_storage( - n_components=2, - vocab_size=2, - d_model=4, - source_to_component=source_to_component, - 
source_to_out_residual=source_to_out_residual, - ) - - targets = storage.get_top_targets("layer1:0", k=3, sign="positive", w_unembed=w_unembed) - assert len(targets) == 3 - # output:0 = 10.0, output:1 = 5.0, layer1:1 = 4.0 - assert targets[0].component_key == "output:0" - assert targets[0].value == 10.0 - assert targets[1].component_key == "output:1" - assert targets[1].value == 5.0 - assert targets[2].component_key == "layer1:1" - assert targets[2].value == 4.0 - - def test_save_and_load(self, tmp_path: Path) -> None: - """Test save and load roundtrip.""" - n_components = 2 - vocab_size = 3 - d_model = 4 - n_sources = vocab_size + n_components - - original = DatasetAttributionStorage( - component_layer_keys=["layer:0", "layer:1"], - vocab_size=vocab_size, - d_model=d_model, - source_to_component=torch.randn(n_sources, n_components), - source_to_out_residual=torch.randn(n_sources, d_model), - n_batches_processed=100, - n_tokens_processed=10000, - ci_threshold=0.01, - ) - - path = tmp_path / "test_attributions.pt" - original.save(path) - - loaded = DatasetAttributionStorage.load(path) - - assert loaded.component_layer_keys == original.component_layer_keys - assert loaded.vocab_size == original.vocab_size - assert loaded.d_model == original.d_model - assert loaded.n_batches_processed == original.n_batches_processed - assert loaded.n_tokens_processed == original.n_tokens_processed - assert loaded.ci_threshold == original.ci_threshold - assert torch.allclose(loaded.source_to_component, original.source_to_component) - assert torch.allclose(loaded.source_to_out_residual, original.source_to_out_residual) diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py new file mode 100644 index 000000000..95fb788f7 --- /dev/null +++ b/tests/dataset_attributions/test_storage.py @@ -0,0 +1,181 @@ +"""Tests for DatasetAttributionStorage.""" + +from pathlib import Path + +import torch +from torch import Tensor + +from 
spd.dataset_attributions.storage import DatasetAttributionStorage + +VOCAB_SIZE = 4 +D_MODEL = 4 +LAYER_0 = "0.glu.up" +LAYER_1 = "1.glu.up" +C0 = 3 # components in layer 0 +C1 = 2 # components in layer 1 + + +def _make_storage(seed: int = 0, n_tokens: int = 640) -> DatasetAttributionStorage: + """Build storage for test topology. + + Sources by target: + "0.glu.up": ["embed"] -> embed edge (C0, VOCAB_SIZE) + "1.glu.up": ["embed", "0.glu.up"] -> embed edge (C1, VOCAB_SIZE) + regular (C1, C0) + "output": ["0.glu.up", "1.glu.up"] -> unembed (D_MODEL, C0), (D_MODEL, C1) + "output": ["embed"] -> embed_unembed (D_MODEL, VOCAB_SIZE) + """ + g = torch.Generator().manual_seed(seed) + + def rand(*shape: int) -> Tensor: + return torch.randn(*shape, generator=g) + + return DatasetAttributionStorage( + regular_attr={LAYER_1: {LAYER_0: rand(C1, C0)}}, + regular_attr_abs={LAYER_1: {LAYER_0: rand(C1, C0)}}, + embed_attr={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + embed_attr_abs={LAYER_0: rand(C0, VOCAB_SIZE), LAYER_1: rand(C1, VOCAB_SIZE)}, + unembed_attr={LAYER_0: rand(D_MODEL, C0), LAYER_1: rand(D_MODEL, C1)}, + embed_unembed_attr=rand(D_MODEL, VOCAB_SIZE), + w_unembed=rand(D_MODEL, VOCAB_SIZE), + ci_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + component_act_sq_sum={LAYER_0: rand(C0).abs() + 1.0, LAYER_1: rand(C1).abs() + 1.0}, + logit_sq_sum=rand(VOCAB_SIZE).abs() + 1.0, + embed_token_count=torch.randint(100, 1000, (VOCAB_SIZE,), generator=g), + ci_threshold=1e-6, + n_tokens_processed=n_tokens, + ) + + +class TestNComponents: + def test_counts_all_target_layers(self): + storage = _make_storage() + assert storage.n_components == C0 + C1 + + +class TestGetTopSources: + def test_component_target_returns_entries(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + assert all(r.value > 0 for r in results) + assert len(results) <= 5 + + def 
test_component_target_includes_embed(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + layers = {r.layer for r in results} + assert "embed" in layers or LAYER_0 in layers + + def test_output_target(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + + def test_output_target_attr_abs_returns_empty(self): + storage = _make_storage() + results = storage.get_top_sources("output:0", k=5, sign="positive", metric="attr_abs") + assert results == [] + + def test_target_only_in_embed_attr(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_0}:0", k=5, sign="positive", metric="attr") + assert len(results) <= 5 + assert all(r.layer == "embed" for r in results) + + def test_attr_abs_metric(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr_abs") + assert len(results) <= 5 + + def test_no_nan_in_results(self): + storage = _make_storage() + results = storage.get_top_sources(f"{LAYER_1}:0", k=20, sign="positive", metric="attr") + assert all(not torch.isnan(torch.tensor(r.value)) for r in results) + + +class TestGetTopTargets: + def test_component_source(self): + storage = _make_storage() + results = storage.get_top_targets( + f"{LAYER_0}:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + assert all(r.value > 0 for r in results) + + def test_embed_source(self): + storage = _make_storage() + results = storage.get_top_targets( + "embed:0", k=5, sign="positive", metric="attr", include_outputs=False + ) + assert len(results) <= 5 + + def test_include_outputs(self): + storage = _make_storage() + results = storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_embed_source_with_outputs(self): + storage = 
_make_storage() + results = storage.get_top_targets("embed:0", k=20, sign="positive", metric="attr") + assert len(results) > 0 + + def test_attr_abs_skips_output_targets(self): + storage = _make_storage() + results = storage.get_top_targets(f"{LAYER_0}:0", k=20, sign="positive", metric="attr_abs") + assert all(r.layer != "output" for r in results) + + +class TestSaveLoad: + def test_roundtrip(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + + loaded = DatasetAttributionStorage.load(path) + + assert loaded.ci_threshold == original.ci_threshold + assert loaded.n_tokens_processed == original.n_tokens_processed + assert loaded.n_components == original.n_components + + def test_roundtrip_query_consistency(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + loaded = DatasetAttributionStorage.load(path) + + orig_results = original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + load_results = loaded.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + + assert len(orig_results) == len(load_results) + for orig, loaded in zip(orig_results, load_results, strict=True): + assert orig.component_key == loaded.component_key + assert abs(orig.value - loaded.value) < 1e-5 + + +class TestMerge: + def test_two_workers_additive(self, tmp_path: Path): + s1 = _make_storage(seed=0, n_tokens=320) + s2 = _make_storage(seed=42, n_tokens=320) + + p1 = tmp_path / "rank_0.pt" + p2 = tmp_path / "rank_1.pt" + s1.save(p1) + s2.save(p2) + + merged = DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_tokens_processed == 640 + + def test_single_file(self, tmp_path: Path): + original = _make_storage(seed=7, n_tokens=640) + path = tmp_path / "rank_0.pt" + original.save(path) + + merged = DatasetAttributionStorage.merge([path]) + + assert merged.n_tokens_processed == original.n_tokens_processed + + orig_results = 
original.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + merge_results = merged.get_top_sources(f"{LAYER_1}:0", k=5, sign="positive", metric="attr") + for o, m in zip(orig_results, merge_results, strict=True): + assert o.component_key == m.component_key + assert abs(o.value - m.value) < 1e-5 diff --git a/uv.lock b/uv.lock index a750ef28e..becdf51ec 100644 --- a/uv.lock +++ b/uv.lock @@ -824,6 +824,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] +[[package]] +name = "kaleido" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f7/0ccaa596ec341963adbb4f839774c36d5659e75a0812d946732b927d480e/kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl", hash = "sha256:ca6f73e7ff00aaebf2843f73f1d3bacde1930ef5041093fe76b83a15785049a7", size = 85153681, upload-time = "2021-03-08T10:27:34.202Z" }, + { url = "https://files.pythonhosted.org/packages/45/8e/4297556be5a07b713bb42dde0f748354de9a6918dee251c0e6bdcda341e7/kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bb9a5d1f710357d5d432ee240ef6658a6d124c3e610935817b4b42da9c787c05", size = 85808197, upload-time = "2021-03-08T10:27:46.561Z" }, + { url = "https://files.pythonhosted.org/packages/ae/b3/a0f0f4faac229b0011d8c4a7ee6da7c2dca0b6fd08039c95920846f23ca4/kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aa21cf1bf1c78f8fa50a9f7d45e1003c387bd3d6fe0a767cfbbf344b95bdc3a8", size = 79902476, upload-time = "2021-03-08T10:27:57.364Z" }, + { url = "https://files.pythonhosted.org/packages/a1/2b/680662678a57afab1685f0c431c2aba7783ce4344f06ec162074d485d469/kaleido-0.2.1-py2.py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:845819844c8082c9469d9c17e42621fbf85c2b237ef8a86ec8a8527f98b6512a", size = 83711746, upload-time = "2021-03-08T10:28:08.847Z" }, + { url = "https://files.pythonhosted.org/packages/88/89/4b6f8bb3f9ab036fd4ad1cb2d628ab5c81db32ac9aa0641d7b180073ba43/kaleido-0.2.1-py2.py3-none-win32.whl", hash = "sha256:ecc72635860be616c6b7161807a65c0dbd9b90c6437ac96965831e2e24066552", size = 62312480, upload-time = "2021-03-08T10:28:18.204Z" }, + { url = "https://files.pythonhosted.org/packages/f7/9a/0408b02a4bcb3cf8b338a2b074ac7d1b2099e2b092b42473def22f7b625f/kaleido-0.2.1-py2.py3-none-win_amd64.whl", hash = "sha256:4670985f28913c2d063c5734d125ecc28e40810141bdb0a46f15b76c1d45f23c", size = 65945521, upload-time = "2021-03-08T10:28:26.823Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -1950,6 +1963,7 @@ dependencies = [ { name = "httpx" }, { name = "ipykernel" }, { name = "jaxtyping" }, + { name = "kaleido" }, { name = "matplotlib" }, { name = "numpy" }, { name = "openrouter" }, @@ -1991,6 +2005,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.28.0" }, { name = "ipykernel" }, { name = "jaxtyping" }, + { name = "kaleido", specifier = "==0.2.1" }, { name = "matplotlib" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" },