diff --git a/spd/paper_vis/__init__.py b/spd/paper_vis/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/paper_vis/build_dashboard.py b/spd/paper_vis/build_dashboard.py new file mode 100644 index 000000000..e883c964b --- /dev/null +++ b/spd/paper_vis/build_dashboard.py @@ -0,0 +1,115 @@ +"""Build a dashboard with per-component JSON files for incremental loading. + +Outputs: + / + index.html — self-contained dashboard shell + vpd/manifest.json — VPD metadata (no heavy component data) + vpd/components/ — per-component JSON files + tc/manifest.json — transcoder metadata + tc/components/ — per-component JSON files + +Usage: + python -m spd.paper_vis.build_dashboard \ + --vpd_id s-55ea3f9b --tc_id tc-3f297233 \ + --out_dir dashboard_out --limit 50 +""" + +from pathlib import Path + +import fire +import orjson + +from spd.paper_vis.data import DecompositionData +from spd.paper_vis.generate import build_decomposition_data + + +def _component_index(data: DecompositionData) -> list[dict[str, object]]: + """Extract lightweight component index from full data (for inline embedding).""" + return [ + { + "component_key": c.component_key, + "layer": c.layer, + "layer_display": c.layer_display, + "component_idx": c.component_idx, + "firing_density": c.firing_density, + "mean_activation": c.mean_activation, + "label": c.label, + "confidence": c.confidence, + "detection_score": c.detection_score.model_dump() if c.detection_score else None, + "fuzzing_score": c.fuzzing_score.model_dump() if c.fuzzing_score else None, + } + for c in data.components + ] + + +def build( + out_dir: str = "dashboard_out", + vpd_id: str | None = None, + tc_id: str | None = None, + limit: int | None = None, +) -> None: + assert vpd_id or tc_id, "Provide at least one of --vpd_id or --tc_id" + + out = Path(out_dir) + out.mkdir(parents=True, exist_ok=True) + + manifest: dict[str, object] = {"vpd": None, "transcoder": None} + + if vpd_id: + print(f"Loading VPD data: {vpd_id}") + 
vpd_data = build_decomposition_data(vpd_id, "vpd", limit, out / "vpd") + manifest["vpd"] = { + **vpd_data.model_dump(exclude={"components"}), + "component_index": _component_index(vpd_data), + "components_path": "vpd/components", + } + + if tc_id: + print(f"Loading transcoder data: {tc_id}") + tc_data = build_decomposition_data(tc_id, "transcoder", limit, out / "tc") + manifest["transcoder"] = { + **tc_data.model_dump(exclude={"components"}), + "component_index": _component_index(tc_data), + "components_path": "tc/components", + } + + manifest_json = orjson.dumps(manifest).decode() + + # Standalone dashboard + dashboard_template = Path(__file__).parent / "dashboard.html" + dashboard_html = dashboard_template.read_text() + dashboard_html = dashboard_html.replace("/*DATA_JSON*/null", manifest_json) + (out / "index.html").write_text(dashboard_html) + print(f"Wrote dashboard to {out}/index.html") + + # Research post with dashboard inlined (no iframe) + post_template = Path(__file__).parent / "research_post.html" + if post_template.exists(): + post_html = post_template.read_text() + + # Extract dashboard body content (between tags), style, and script + import re + + style_match = re.search(r"", dashboard_html, re.DOTALL) + script_match = re.search(r"", dashboard_html, re.DOTALL) + body_match = re.search(r"(.*?)", dashboard_html, re.DOTALL) + + assert style_match and script_match and body_match + + dashboard_inline = ( + f"\n" + f"{body_match.group(1)}\n" + f"" + ) + + post_html = post_html.replace( + '
\n \n
', + f'
\n{dashboard_inline}\n
', + ) + + (out / "research_post.html").write_text(post_html) + print(f"Wrote research post to {out}/research_post.html") + + +if __name__ == "__main__": + fire.Fire(build) diff --git a/spd/paper_vis/dashboard.html b/spd/paper_vis/dashboard.html new file mode 100644 index 000000000..b0373954d --- /dev/null +++ b/spd/paper_vis/dashboard.html @@ -0,0 +1,746 @@ + + + + + + + Component Interpretability — VPD vs Transcoder + + + + + +
+ + + +
+ + + + + + + \ No newline at end of file diff --git a/spd/paper_vis/data.py b/spd/paper_vis/data.py new file mode 100644 index 000000000..3725576b7 --- /dev/null +++ b/spd/paper_vis/data.py @@ -0,0 +1,80 @@ +"""Data types for paper visualisation dashboards. + +JSON-serializable types that bridge the harvest/autointerp pipeline outputs +into static dashboard data. Each DecompositionData bundles everything needed +to render a component-level comparison dashboard for one decomposition method. +""" + +from pydantic import BaseModel + + +class TokenSpan(BaseModel): + """A token with its firing/activation state in context.""" + + token: str + is_firing: bool + activation: float + + +class ActivationExampleData(BaseModel): + """One activation example: a window of tokens around a firing.""" + + tokens: list[TokenSpan] + center_idx: int + + +class TokenPMIData(BaseModel): + """Top tokens by PMI for a component.""" + + top: list[tuple[str, float]] + bottom: list[tuple[str, float]] + + +class ScoreData(BaseModel): + """Autointerp eval score for a component.""" + + score: float + n_trials: int + + +class ComponentDashboardData(BaseModel): + """Everything we know about a single component, ready for the dashboard.""" + + component_key: str + layer: str + layer_display: str + component_idx: int + + # Harvest data + firing_density: float + mean_activation: float + activation_examples: list[ActivationExampleData] + input_token_pmi: TokenPMIData + output_token_pmi: TokenPMIData + + # Autointerp data (None if not yet interpreted) + label: str | None + confidence: str | None + reasoning: str | None + + # Scoring data (None if not yet scored) + detection_score: ScoreData | None + fuzzing_score: ScoreData | None + + +class DecompositionData(BaseModel): + """All dashboard data for one decomposition method.""" + + decomposition_id: str + method: str # "vpd" or "transcoder" + base_model: str + n_components: int + n_layers: int + components: list[ComponentDashboardData] + + +class 
ComparisonDashboardData(BaseModel): + """Top-level data for a VPD-vs-transcoder comparison dashboard.""" + + vpd: DecompositionData + transcoder: DecompositionData diff --git a/spd/paper_vis/generate.py b/spd/paper_vis/generate.py new file mode 100644 index 000000000..660a931c1 --- /dev/null +++ b/spd/paper_vis/generate.py @@ -0,0 +1,212 @@ +"""Generate dashboard JSON from harvest + autointerp data. + +Outputs: + - manifest.json: lightweight metadata + component list (inline in HTML) + - components/{component_key}.json: per-component full data (loaded on demand) + +Usage: + python -m spd.paper_vis.generate --decomposition_id s-55ea3f9b --method vpd --out_dir out/vpd +""" + +import contextlib +import json +from pathlib import Path + +import fire +import orjson + +from spd.adapters import adapter_from_id +from spd.app.backend.app_tokenizer import AppTokenizer +from spd.autointerp.prompt_helpers import human_layer_desc +from spd.autointerp.repo import InterpRepo +from spd.harvest.repo import HarvestRepo +from spd.harvest.schemas import ActivationExample, ComponentData +from spd.paper_vis.data import ( + ActivationExampleData, + ComponentDashboardData, + DecompositionData, + ScoreData, + TokenPMIData, + TokenSpan, +) + + +def _convert_pmi(pmi_data: list[tuple[int, float]], tok: AppTokenizer) -> list[tuple[str, float]]: + return [(tok.get_tok_display(tid), score) for tid, score in pmi_data] + + +def _convert_example( + example: "ActivationExample", tok: AppTokenizer, act_type: str +) -> ActivationExampleData: + activations = example.activations.get(act_type, [0.0] * len(example.token_ids)) + tokens = [ + TokenSpan( + token=tok.get_tok_display(tid), + is_firing=f, + activation=a, + ) + for tid, f, a in zip(example.token_ids, example.firings, activations, strict=True) + ] + center_idx = len(tokens) // 2 + return ActivationExampleData(tokens=tokens, center_idx=center_idx) + + +def _primary_activation_type(component: ComponentData) -> str: + if not 
component.activation_examples: + return "activation" + act_types = list(component.activation_examples[0].activations.keys()) + if "causal_importance" in act_types: + return "causal_importance" + return act_types[0] if act_types else "activation" + + +def _build_component( + comp: ComponentData, + tok: AppTokenizer, + interp_repo: InterpRepo | None, + detection_scores: dict[str, float], + fuzzing_scores: dict[str, float], + layer_descriptions: dict[str, str], + n_blocks: int, +) -> ComponentDashboardData: + act_type = _primary_activation_type(comp) + mean_act = comp.mean_activations.get(act_type, 0.0) + + examples = [_convert_example(ex, tok, act_type) for ex in comp.activation_examples[:20]] + + input_pmi = TokenPMIData( + top=_convert_pmi(comp.input_token_pmi.top, tok), + bottom=_convert_pmi(comp.input_token_pmi.bottom, tok), + ) + output_pmi = TokenPMIData( + top=_convert_pmi(comp.output_token_pmi.top, tok), + bottom=_convert_pmi(comp.output_token_pmi.bottom, tok), + ) + + label = None + confidence = None + reasoning = None + detection_score = None + fuzzing_score = None + + if interp_repo is not None: + interp = interp_repo.get_interpretation(comp.component_key) + if interp is not None: + label = interp.label + confidence = interp.confidence + reasoning = interp.reasoning + + det_val = detection_scores.get(comp.component_key) + if det_val is not None: + detection_score = ScoreData(score=det_val, n_trials=0) + + fuz_val = fuzzing_scores.get(comp.component_key) + if fuz_val is not None: + fuzzing_score = ScoreData(score=fuz_val, n_trials=0) + + canonical = layer_descriptions.get(comp.layer, comp.layer) + layer_display = human_layer_desc(canonical, n_blocks) + + return ComponentDashboardData( + component_key=comp.component_key, + layer=comp.layer, + layer_display=layer_display, + component_idx=comp.component_idx, + firing_density=comp.firing_density, + mean_activation=mean_act, + activation_examples=examples, + input_token_pmi=input_pmi, + 
output_token_pmi=output_pmi, + label=label, + confidence=confidence, + reasoning=reasoning, + detection_score=detection_score, + fuzzing_score=fuzzing_score, + ) + + +def build_decomposition_data( + decomposition_id: str, + method: str, + limit: int | None, + out_dir: Path, +) -> DecompositionData: + adapter = adapter_from_id(decomposition_id) + + harvest = HarvestRepo.open_most_recent(decomposition_id) + assert harvest is not None, f"No harvest data for {decomposition_id}" + + tok = AppTokenizer.from_pretrained(adapter.tokenizer_name) + + interp_repo: InterpRepo | None = None + with contextlib.suppress(Exception): + interp_repo = InterpRepo.open(decomposition_id) + + detection_scores: dict[str, float] = {} + fuzzing_scores: dict[str, float] = {} + if interp_repo is not None: + detection_scores = interp_repo.get_scores("detection") + fuzzing_scores = interp_repo.get_scores("fuzzing") + + metadata = adapter.model_metadata + layer_descriptions = metadata.layer_descriptions + n_blocks = metadata.n_blocks + + summaries = harvest.get_summary() + keys = list(summaries.keys()) + if limit is not None: + keys = keys[:limit] + + comp_dir = out_dir / "components" + comp_dir.mkdir(parents=True, exist_ok=True) + + dashboard_components: list[ComponentDashboardData] = [] + for i, key in enumerate(keys): + comp = harvest.get_component(key) + assert comp is not None + dash_comp = _build_component( + comp, + tok, + interp_repo, + detection_scores, + fuzzing_scores, + layer_descriptions, + n_blocks, + ) + dashboard_components.append(dash_comp) + + safe_key = key.replace(":", "_").replace("/", "_") + comp_path = comp_dir / f"{safe_key}.json" + comp_path.write_bytes(orjson.dumps(dash_comp.model_dump())) + + if (i + 1) % 50 == 0: + print(f" {i + 1}/{len(keys)} components", flush=True) + + print(f" {len(keys)}/{len(keys)} components done", flush=True) + + layers = adapter.layer_activation_sizes + return DecompositionData( + decomposition_id=decomposition_id, + method=method, + 
base_model=adapter.tokenizer_name, + n_components=sum(n for _, n in layers), + n_layers=len(layers), + components=dashboard_components, + ) + + +def main( + decomposition_id: str, + method: str, + out_dir: str, + limit: int | None = None, +) -> None: + out_path = Path(out_dir) + data = build_decomposition_data(decomposition_id, method, limit, out_path) + manifest = out_path / "manifest.json" + manifest.write_text(json.dumps(data.model_dump(exclude={"components"}), indent=2)) + print(f"Wrote {len(data.components)} components to {out_path}/") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/paper_vis/research_post.html b/spd/paper_vis/research_post.html new file mode 100644 index 000000000..a760b4809 --- /dev/null +++ b/spd/paper_vis/research_post.html @@ -0,0 +1,322 @@ + + + + + +Interpreting Language Model Parameters — Goodfire Research + + + + + + + + + +
+

Interpreting Language Model Parameters Draft

+ + + + +

Introduction

+ +

Structure in the parameters of language models is responsible for their remarkable intelligence. The trainable parameters of these neural networks, in interaction with the architecture and dataset, learn to implement algorithms that we do not know how to design directly. On the one hand, deep learning thus affords us the ability to build machines to solve tasks that otherwise resist engineering solutions, and incidentally creates objects that are of great scientific interest in their own right. On the other hand, it means that an increasing portion of our daily lives depends on systems that we do not deeply understand.

+ +

A key barrier to understanding the computations that these systems perform is how best to decompose them into simpler parts that we can study in relative isolation. Naive choices of these parts—such as neurons, attention heads, or whole layers—don't always map to individual, interpretable computations.

+ +

Alternative approaches, such as transcoders or mixtures of linear transforms (MOLTs), typically involve fitting a set of simple functions to the transitions between activations at different layers in the network, and linearly combining the outputs of these simple functions. The idea is to approximate the complex, nonlinear function implemented by the network's layers using a simpler, easier-to-understand function. Unfortunately, because these simpler functions are of a different functional form from the original network, it is hard to relate their accounts of network function to the actual objects that are doing the computations—the network's parameters and its nonlinearities.

+ +

These issues motivate parameter decomposition methods, which give accounts of network function in terms of the components of the network's parameters that are used by the network on a given datapoint. Ablation-based parameter decomposition methods identify a set of parameter components such that as few components as possible are necessary to perform the same computations as the original network on any datapoint, while the components sum to the parameters of the target network and are as computationally simple as possible.

+ +
+

In this work, we introduce Adversarial Parameter Decomposition (VPD), which builds on SPD but with several important modifications that make it more mechanistically faithful and scalable. We decompose a small language model (67M parameters) trained on The Pile and find parameter components that are highly interpretable, compare favorably to transcoder and CLT latents, and enable novel circuit analysis.

+
+ +

Method — Adversarial Parameter Decomposition

+ +

Our method, VPD, builds heavily on SPD, but we do not assume familiarity with that work. Our goal is to decompose a neural network's parameters into the mechanisms that it uses to compute its behavior. Networks appear not to use all of their parameters simultaneously on every datapoint. If particular parameters are unused by the network on a particular datapoint, then we should be able to ablate them without adversely affecting the network's output.

+ +

Ablation-based parameter decomposition methods thus aim to decompose network parameters into a set of vectors in parameter space called parameter components that sum to the network's total parameter vector and are:

+ + +

Parameter components are vectors in parameter space

+ +

We decompose individual weight matrices into sums of rank-one matrices called subcomponents, each parameterized as an outer product of two vectors. Although a single subcomponent parameterizes only a single weight matrix, it implicitly parameterizes a full parameter vector if we assume it takes values of 0 in every other weight matrix. It is therefore possible to combine these subcomponents into full parameter components by clustering them together.

+ +

Optimizing for minimality

+ +

We train a causal importance function to predict how ablatable each subcomponent is on a given datapoint. We want causal importance values to take minimal values, leading to the importance minimality loss:

+ +
+ $$\mathcal{L}_{\text{importance-minimality}} = \sum_{l=1}^{L} \sum_{c=1}^{C} |g^l_c(x)|^p$$ +
+ +

Optimizing for mechanistic faithfulness

+ +

We create stochastic masks $m^l_c(x, r) := g^l_c(x) + (1 - g^l_c(x))r^l_c(x)$ where $r \sim \mathcal{U}(0,1)$, and minimize:

+ +
+ $$\mathcal{L}_{\text{stochastic-recon}} = \frac{1}{S}\sum_{s=1}^{S} D\!\left(f(x \mid W'(x, r^{(s)})),\; f(x \mid W)\right)$$ +
+ +

VPD optimizes for a stricter criterion than SPD: adversarial ablatability. The stochastic reconstruction loss approximates our desired condition on average, but not for worst-case values. VPD therefore introduces an adversarial loss:

+ +
+ $$\mathcal{L}_{\text{adversarial-recon}} = \frac{1}{S}\sum_{s=1}^{S} \max_{r^{(s)}(x)} D\!\left(f(x \mid W'(x, r^{(s)}(x))),\; f(x \mid W)\right)$$ +
+ +

Optimizing for simplicity

+ +

VPD introduces a frequency-minimality loss that encourages subcomponents to activate on as few datapoints as possible, complementing the importance minimality loss which encourages datapoints to activate as few subcomponents as possible:

+ +
+ $$\mathcal{L}_{\text{frequency-minimality}} = \sum_{l=1}^{L}\sum_{c=1}^{C} |g^l_c(x)|^p \log_2\!\left(1 + \sum_{x'} |g^l_c(x')|^p\right)$$ +
+ +

Where importance minimality encourages each datapoint to activate few components, frequency minimality encourages each component to activate on few datapoints. This creates a useful tradeoff during training.

+ +

Results

+ +

The target language model

+ +

We trained a four-layer, 67M parameter decoder-only transformer model on an uncopyrighted subset of The Pile. It uses standard multihead attention with RoPE positional encoding and MLPs with GELU activation. The model achieves a final validation cross-entropy loss of approximately 2.71.

+ + + + + + + + + + +
PropertyValue
Layers4
Residual stream dmodel768
MLP intermediate dimension3,072
Attention heads6
Context length512
Vocabulary size50,277
Total parameters~67M
+ +

VPD achieves a better sparsity–accuracy tradeoff

+ +

We compare VPD's reconstruction quality against transcoders and cross-layer transcoders (CLTs). VPD achieves lower CE degradation than both at comparable sparsity levels, consistent across multiple normalizations of component count.

+ +

End-to-end trained activation-based methods exhibit severe brittleness to evaluation mode mismatch. When evaluated in a mismatched setting (cascading vs parallel), performance degrades by 5–20×. VPD's CE degradation, by contrast, is relatively consistent across all evaluation protocols, because its stochastic and adversarial masking during training naturally covers both patterns.

+ +

Parameter components are highly interpretable

+ +

Automated interpretation of VPD's parameter components shows that they correspond to recognizable linguistic and computational functions. Below, we show example components from our decomposition alongside transcoder latents trained on the same architecture, demonstrating that VPD components are at least as interpretable as transcoder features.

+ +

Example components

+ +

The interactive viewer below shows VPD parameter components (left) and transcoder latents (right) from the same model architecture. Each card displays the component's automatically generated label, activation examples (tokens where the component fires are highlighted), and token correlation statistics. Use the arrow keys or navigation buttons to browse.

+ +
+ +
+ +

Case studies

+ +

Case study 1: Gender for possessive pronouns

+ +

On the prompt The princess lost her crown. the target model correctly predicts that her follows lost, assigning probability 0.586. This requires recognizing that a possessive pronoun is coming up, remembering that the previous token was princess, and realizing that princesses use feminine-gendered pronouns.

+ +

The attribution graph reveals two core mechanisms: one which moves the femaleness attribute of "princess" over to the next token in attention layer 3, and another which suggests that a possessive pronoun might follow the verb "lost". A subset of just six components proves sufficient to predict "her" under causal importance masking—but fails under adversarial masking, revealing that many more components play important computational roles.

+ +

Case study 2: Distributed attention behaviors

+ +

Previous token behavior—attention from timestep t to t−1—is typically associated with "previous token heads." We find that in our model, a single pair of VPD rank-one components, whose weight spans all heads in that layer, is responsible for a greater amount of previous token behavior than the model's dedicated previous token head (L1H1). This demonstrates VPD's ability to identify attention computations distributed across heads—something activation-based methods struggle with.

+ +

Case study 3: Bracket closing

+ +

On the prompt <u,v>, the model predicts the closing bracket >. The attribution graph reveals a rich multi-layer computation: layer 1 attention carries forward general delimiter information from <, layer 2 attention distinguishes which specific bracket type was opened, and layer 3 MLPs produce the final output. Ablation experiments confirm that each layer's contribution is necessary—but only adversarial evaluation reveals this, as standard causal masking dramatically underestimates the number of components involved.

+ +

Discussion

+ +

VPD and other ablation-based parameter decomposition methods obey a principle of correspondence: ablations in the decomposed model have exactly corresponding ablations in the original model, making it straightforward to use insights from the decomposition for model editing. For example, we can remove any component from the original model by subtracting it from the total parameter vector.

+ +

Our results suggest that parameter decomposition methods offer a promising alternative to activation-based approaches. VPD components are more mechanistically faithful, exhibit less feature splitting, and are comparably or more interpretable than transcoder and SAE latents. The rarity of complex nonlinear interactions between components further suggests an underlying computational simplicity in the target model itself.

+ +
+ + +