From 6d7616055bb558edbd3b665af61dd0bd74ae9f14 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 15:53:17 +0000 Subject: [PATCH 01/17] add worktrees to ignore --- .claude/worktrees/bold-elm-8kpb | 1 + .claude/worktrees/bright-fox-a4i0 | 1 + .claude/worktrees/calm-owl-v4pj | 1 + .claude/worktrees/cozy-frolicking-stream | 1 + .claude/worktrees/stateless-dancing-blanket | 1 + .claude/worktrees/swift-owl-yep9 | 1 + .claude/worktrees/swift-ray-amfs | 1 + .claude/worktrees/vectorized-wiggling-whisper | 1 + .claude/worktrees/xenodochial-germain | 1 + .gitignore | 4 +- .../backend/routers/dataset_attributions.py | 13 +- spd/dataset_attributions/harvest.py | 109 ++++----- spd/dataset_attributions/harvester.py | 227 +++++++++-------- spd/dataset_attributions/repo.py | 2 +- spd/dataset_attributions/storage.py | 229 ++++++++++-------- spd/topology/gradient_connectivity.py | 19 +- tests/dataset_attributions/test_harvester.py | 6 +- 17 files changed, 325 insertions(+), 293 deletions(-) create mode 160000 .claude/worktrees/bold-elm-8kpb create mode 160000 .claude/worktrees/bright-fox-a4i0 create mode 160000 .claude/worktrees/calm-owl-v4pj create mode 160000 .claude/worktrees/cozy-frolicking-stream create mode 160000 .claude/worktrees/stateless-dancing-blanket create mode 160000 .claude/worktrees/swift-owl-yep9 create mode 160000 .claude/worktrees/swift-ray-amfs create mode 160000 .claude/worktrees/vectorized-wiggling-whisper create mode 160000 .claude/worktrees/xenodochial-germain diff --git a/.claude/worktrees/bold-elm-8kpb b/.claude/worktrees/bold-elm-8kpb new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bold-elm-8kpb @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/bright-fox-a4i0 b/.claude/worktrees/bright-fox-a4i0 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/bright-fox-a4i0 @@ -0,0 +1 @@ +Subproject commit 
356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/calm-owl-v4pj b/.claude/worktrees/calm-owl-v4pj new file mode 160000 index 000000000..dbe0668a4 --- /dev/null +++ b/.claude/worktrees/calm-owl-v4pj @@ -0,0 +1 @@ +Subproject commit dbe0668a4119885b7fe952ed820b4ba8b4a3d693 diff --git a/.claude/worktrees/cozy-frolicking-stream b/.claude/worktrees/cozy-frolicking-stream new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/cozy-frolicking-stream @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/stateless-dancing-blanket b/.claude/worktrees/stateless-dancing-blanket new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/stateless-dancing-blanket @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-owl-yep9 b/.claude/worktrees/swift-owl-yep9 new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-owl-yep9 @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/swift-ray-amfs b/.claude/worktrees/swift-ray-amfs new file mode 160000 index 000000000..356f8cfed --- /dev/null +++ b/.claude/worktrees/swift-ray-amfs @@ -0,0 +1 @@ +Subproject commit 356f8cfedab14e621c70de460971d6d148f44e80 diff --git a/.claude/worktrees/vectorized-wiggling-whisper b/.claude/worktrees/vectorized-wiggling-whisper new file mode 160000 index 000000000..cb18c86a7 --- /dev/null +++ b/.claude/worktrees/vectorized-wiggling-whisper @@ -0,0 +1 @@ +Subproject commit cb18c86a77720f94a292e7421a19694082813c8c diff --git a/.claude/worktrees/xenodochial-germain b/.claude/worktrees/xenodochial-germain new file mode 160000 index 000000000..4b52a4869 --- /dev/null +++ b/.claude/worktrees/xenodochial-germain @@ -0,0 +1 @@ +Subproject commit 4b52a4869474bd80365c573434855d091abbbb5b diff --git a/.gitignore b/.gitignore index 
4780cbd03..b5601daf4 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,6 @@ cython_debug/ #.idea/ **/*.db -**/*.db* \ No newline at end of file +**/*.db* + +.claude/worktrees \ No newline at end of file diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 4c3d07753..c459d29ae 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -33,8 +33,9 @@ class DatasetAttributionMetadata(BaseModel): n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - vocab_size: int | None - d_model: int | None + # TODO(oli): remove these from frontend + # vocab_size: int | None + # d_model: int | None ci_threshold: float | None @@ -127,8 +128,8 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - vocab_size=None, - d_model=None, + # vocab_size=None, + # d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() @@ -137,8 +138,8 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, n_component_layer_keys=storage.n_components, - vocab_size=storage.vocab_size, - d_model=storage.d_model, + # vocab_size=storage.vocab_size, + # d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 15e4f5b19..84a3d608d 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -14,6 +14,7 @@ import itertools from pathlib import Path +from typing import Any, cast import torch import tqdm @@ -51,9 +52,10 @@ def _build_alive_masks( model: ComponentModel, run_id: str, harvest_subrun_id: str | None, - n_components: int, + # n_components: int, 
vocab_size: int, -) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: + # ) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: +) -> dict[str, Bool[Tensor, " n_components"]]: """Build masks of alive components (mean_activation > threshold) for sources and targets. Falls back to all-alive if harvest summary not available. @@ -63,43 +65,48 @@ def _build_alive_masks( - Targets: [0, n_components) = component layers (output handled via out_residual) """ - n_sources = vocab_size + n_components - - source_alive = torch.zeros(n_sources, dtype=torch.bool) - target_alive = torch.zeros(n_components, dtype=torch.bool) - - # All wte tokens are always alive (source indices [0, vocab_size)) - source_alive[:vocab_size] = True + component_alive = { + "wte": torch.ones(vocab_size, dtype=torch.bool), # All wte tokens are always alive + **{ + layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) + for layer in model.target_module_paths + }, + } + # # All wte tokens are always alive (source indices [0, vocab_size)) + # source_alive[:vocab_size] = True + + # target_alive = { + # layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) + # for layer in model.target_module_paths + # } if harvest_subrun_id is not None: harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) else: harvest = HarvestRepo.open_most_recent(run_id, readonly=True) assert harvest is not None, f"No harvest data for {run_id}" + summary = harvest.get_summary() assert summary is not None, "Harvest summary not available" # Build masks for component layers - source_idx = vocab_size # Start after wte tokens - target_idx = 0 + # source_idx = vocab_size # Start after wte tokens + # target_idx = 0 for layer in model.target_module_paths: n_layer_components = model.module_to_c[layer] for c_idx in range(n_layer_components): component_key = f"{layer}:{c_idx}" is_alive = component_key in summary and summary[component_key].firing_density > 0.0 
- source_alive[source_idx] = is_alive - target_alive[target_idx] = is_alive - source_idx += 1 - target_idx += 1 + component_alive[layer][c_idx] = is_alive - n_source_alive = int(source_alive.sum().item()) - n_target_alive = int(target_alive.sum().item()) - logger.info( - f"Alive components: {n_source_alive}/{n_sources} sources, " - f"{n_target_alive}/{n_components} component targets (firing density > 0.0)" - ) - return source_alive, target_alive + # n_source_alive = int(source_alive.sum().item()) + # n_target_alive = int(target_alive.sum().item()) + # logger.info( + # f"Alive components: {n_source_alive}/{n_sources} sources, " + # f"{n_target_alive}/{n_components} component targets (firing density > 0.0)" + # ) + return component_alive def harvest_attributions( @@ -140,16 +147,14 @@ def harvest_attributions( logger.info(f"Vocab size: {vocab_size}") # Build component keys and alive masks - component_layer_keys = _build_component_layer_keys(model) - n_components = len(component_layer_keys) - source_alive, target_alive = _build_alive_masks( - model, run_id, harvest_subrun_id, n_components, vocab_size - ) - source_alive = source_alive.to(device) - target_alive = target_alive.to(device) + # component_layer_keys = _build_component_layer_keys(model) + # n_components = len(component_layer_keys) + component_alive = _build_alive_masks(model, run_id, harvest_subrun_id, vocab_size) + # source_alive = source_alive.to(device) + # target_alive = target_alive.to(device) - n_sources = vocab_size + n_components - logger.info(f"Component layers: {n_components}, Sources: {n_sources}") + # n_sources = vocab_size + n_components + # logger.info(f"Component layers: {n_components}, Sources: {n_sources}") # Get gradient connectivity logger.info("Computing sources_by_target...") @@ -160,8 +165,8 @@ def harvest_attributions( # - Valid targets: component layers + output # - Valid sources: wte + component layers component_layers = set(model.target_module_paths) - valid_sources = 
component_layers | {"wte"} - valid_targets = component_layers | {"output"} + valid_sources = component_layers.union({"wte"}) + valid_targets = component_layers.union({"output"}) sources_by_target = {} for target, sources in sources_by_target_raw.items(): @@ -176,15 +181,13 @@ def harvest_attributions( harvester = AttributionHarvester( model=model, sources_by_target=sources_by_target, - n_components=n_components, + # n_components=n_components, vocab_size=vocab_size, - source_alive=source_alive, - target_alive=target_alive, + component_alive=component_alive, sampling=spd_config.sampling, embedding_module=topology.embedding_module, unembed_module=topology.unembed_module, device=device, - show_progress=True, ) # Process batches @@ -211,16 +214,12 @@ def harvest_attributions( ) # Normalize by n_tokens to get per-token average attribution - normalized_comp = harvester.comp_accumulator / harvester.n_tokens - normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens + # normalized_comp = harvester.comp_accumulator / harvester.n_tokens + # normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens # Build and save storage storage = DatasetAttributionStorage( - component_layer_keys=component_layer_keys, vocab_size=vocab_size, - d_model=harvester.d_model, - source_to_component=normalized_comp.cpu(), - source_to_out_residual=normalized_out_residual.cpu(), n_batches_processed=harvester.n_batches, n_tokens_processed=harvester.n_tokens, ci_threshold=config.ci_threshold, @@ -233,7 +232,7 @@ def harvest_attributions( else: output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / "dataset_attributions.pt" - storage.save(output_path) + # storage.save(output_path) logger.info(f"Saved dataset attributions to {output_path}") @@ -252,7 +251,7 @@ def merge_attributions(output_dir: Path) -> None: # Load first file to get metadata and initialize accumulators # Use double precision for accumulation to prevent precision loss with 
billions of tokens - first = DatasetAttributionStorage.load(rank_files[0]) + first = cast(DatasetAttributionStorage, None) # DatasetAttributionStorage.load(rank_files[0]) total_comp = (first.source_to_component * first.n_tokens_processed).double() total_out_residual = (first.source_to_out_residual * first.n_tokens_processed).double() total_tokens = first.n_tokens_processed @@ -261,14 +260,13 @@ def merge_attributions(output_dir: Path) -> None: # Stream remaining files one at a time for rank_file in tqdm.tqdm(rank_files[1:], desc="Merging rank files"): - storage = DatasetAttributionStorage.load(rank_file) + storage = cast(DatasetAttributionStorage, None) # DatasetAttributionStorage.load(rank_file) # Validate consistency assert storage.component_layer_keys == first.component_layer_keys, ( "Component layer keys mismatch" ) - assert storage.vocab_size == first.vocab_size, "Vocab size mismatch" - assert storage.d_model == first.d_model, "d_model mismatch" + # assert storage.d_model == first.d_model, "d_model mismatch" assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" # Accumulate de-normalized values @@ -283,18 +281,21 @@ def merge_attributions(output_dir: Path) -> None: # Save merged result merged = DatasetAttributionStorage( - component_layer_keys=first.component_layer_keys, - vocab_size=first.vocab_size, - d_model=first.d_model, - source_to_component=merged_comp, - source_to_out_residual=merged_out_residual, + # component_layer_keys=first.component_layer_keys, + # # d_model=first.d_model, + # source_to_component=merged_comp, + # source_to_out_residual=merged_out_residual, + # n_batches_processed=total_batches, + # n_tokens_processed=total_tokens, + # ci_threshold=first.ci_threshold, + vocab_size=0, # vocab_size, n_batches_processed=total_batches, n_tokens_processed=total_tokens, ci_threshold=first.ci_threshold, ) output_path = output_dir / "dataset_attributions.pt" - merged.save(output_path) + # merged.save(output_path) assert 
output_path.stat().st_size > 0, f"Merge output is empty: {output_path}" logger.info(f"Merged {len(rank_files)} files -> {output_path}") logger.info(f"Total: {total_batches} batches, {total_tokens:,} tokens") diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 5bef0af63..f4be04f25 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -13,9 +13,8 @@ from typing import Any import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Int from torch import Tensor, nn -from tqdm.auto import tqdm from spd.configs import SamplingType from spd.models.component_model import ComponentModel, OutputWithCache @@ -48,81 +47,67 @@ def __init__( self, model: ComponentModel, sources_by_target: dict[str, list[str]], - n_components: int, vocab_size: int, - source_alive: Bool[Tensor, " n_sources"], - target_alive: Bool[Tensor, " n_components"], + component_alive: dict[str, Bool[Tensor, " n_components"]], sampling: SamplingType, embedding_module: nn.Embedding, unembed_module: nn.Linear, device: torch.device, - show_progress: bool = False, ): self.model = model self.sources_by_target = sources_by_target - self.n_components = n_components self.vocab_size = vocab_size - self.source_alive = source_alive - self.target_alive = target_alive + self.component_alive = component_alive self.sampling = sampling self.embedding_module = embedding_module self.unembed_module = unembed_module self.device = device - self.show_progress = show_progress - self.n_sources = vocab_size + n_components self.n_batches = 0 self.n_tokens = 0 + self.output_d_model = unembed_module.in_features # Split accumulators for component and output targets - self.comp_accumulator = torch.zeros(self.n_sources, n_components, device=device) + self.component_attr_accumulator = self._get_component_attr_accumulator( + sources_by_target, + component_alive, + unembed_module, + vocab_size, + device, + ) - # For output targets: 
store attributions to output residual dimensions - self.d_model = unembed_module.in_features - self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) + def _get_component_attr_accumulator( + self, + sources_by_target: dict[str, list[str]], + component_alive: dict[str, Bool[Tensor, " n_components"]], + unembed_module: nn.Linear, + vocab_size: int, + device: torch.device, + ) -> dict[str, dict[str, Tensor]]: + component_attr_accumulator: dict[str, dict[str, Tensor]] = {} - # Build per-layer index ranges for sources - self.component_layer_names = list(model.target_module_paths) - self.source_layer_to_idx_range = self._build_source_layer_index_ranges() - self.target_layer_to_idx_range = self._build_target_layer_index_ranges() + for target_layer, source_layers in sources_by_target.items(): + if target_layer == "output": + target_d = unembed_module.in_features + else: + (target_c,) = component_alive[target_layer].shape + target_d = target_c - # Pre-compute alive indices per layer - self.alive_source_idxs_per_layer = self._build_alive_indices( - self.source_layer_to_idx_range, source_alive - ) - self.alive_target_idxs_per_layer = self._build_alive_indices( - self.target_layer_to_idx_range, target_alive - ) + source_attr_accumulator: dict[str, Tensor] = {} + for source_layer in source_layers: + if source_layer == "wte": + source_d = vocab_size + else: + (source_c,) = component_alive[source_layer].shape + source_d = source_c + + source_attr_accumulator[source_layer] = torch.zeros( + (target_d, source_d), device=device + ) - def _build_source_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Source order: wte tokens [0, vocab_size), then component layers.""" - ranges: dict[str, tuple[int, int]] = {"wte": (0, self.vocab_size)} - idx = self.vocab_size - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - return ranges - - def 
_build_target_layer_index_ranges(self) -> dict[str, tuple[int, int]]: - """Target order: component layers [0, n_components). Output handled separately.""" - ranges: dict[str, tuple[int, int]] = {} - idx = 0 - for layer in self.component_layer_names: - n = self.model.module_to_c[layer] - ranges[layer] = (idx, idx + n) - idx += n - # Note: "output" not included - handled via out_residual_accumulator - return ranges - - def _build_alive_indices( - self, layer_ranges: dict[str, tuple[int, int]], alive_mask: Bool[Tensor, " n"] - ) -> dict[str, list[int]]: - """Get alive local indices for each layer.""" - return { - layer: torch.where(alive_mask[start:end])[0].tolist() - for layer, (start, end) in layer_ranges.items() - } + component_attr_accumulator[target_layer] = source_attr_accumulator + + return component_attr_accumulator def process_batch(self, tokens: Int[Tensor, "batch seq"]) -> None: """Accumulate attributions from one batch.""" @@ -153,6 +138,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No ci = self.model.calc_causal_importances( pre_weight_acts=out.cache, sampling=self.sampling, detach_inputs=False ) + mask_infos = make_mask_infos( component_masks={k: torch.ones_like(v) for k, v in ci.lower_leaky.items()}, routing_masks="all", @@ -170,47 +156,20 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No cache = comp_output.cache cache["wte_post_detach"] = wte_out[0] cache["pre_unembed"] = pre_unembed[0] - cache["tokens"] = tokens + # cache["tokens"] = tokens # Process each target layer - layers = list(self.sources_by_target.items()) - pbar = tqdm(layers, desc="Targets", disable=not self.show_progress, leave=False) - for target_layer, source_layers in pbar: + for target_layer in self.sources_by_target: if target_layer == "output": - self._process_output_targets(source_layers, cache) + self._process_output_targets(cache, ci.lower_leaky, tokens) else: - self._process_component_targets(target_layer, 
source_layers, cache) - - def _process_component_targets( - self, - target_layer: str, - source_layers: list[str], - cache: dict[str, Tensor], - ) -> None: - """Process attributions to a component layer.""" - target_start, _ = self.target_layer_to_idx_range[target_layer] - alive_targets = self.alive_target_idxs_per_layer[target_layer] - if not alive_targets: - return - - # Sum over batch and sequence - target_acts = cache[f"{target_layer}_pre_detach"].sum(dim=(0, 1)) - source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - - for t_idx in alive_targets: - grads = torch.autograd.grad(target_acts[t_idx], source_acts, retain_graph=True) - self._accumulate_attributions( - self.comp_accumulator[:, target_start + t_idx], - source_layers, - grads, - source_acts, - cache["tokens"], - ) + self._process_component_targets(target_layer, ci.lower_leaky, cache, tokens) def _process_output_targets( self, - source_layers: list[str], cache: dict[str, Tensor], + ci: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], ) -> None: """Process output attributions via output-residual-space storage. 
@@ -220,40 +179,80 @@ def _process_output_targets( """ # Sum output residual over batch and sequence -> [d_model] out_residual = cache["pre_unembed"].sum(dim=(0, 1)) + + source_layers = self.sources_by_target["output"] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for d_idx in range(self.d_model): + for d_idx in range(self.output_d_model): grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) + source_acts_grads = list(zip(source_layers, source_acts, grads, strict=True)) + + self._accumulate_attributions( + "output", + d_idx, + source_acts_grads, + ci, + tokens, + ) + + def _process_component_targets( + self, + target_layer: str, + ci: dict[str, Tensor], + cache: dict[str, Tensor], + tokens: Int[Tensor, "batch seq"], + ) -> None: + """Process attributions to a component layer.""" + alive_targets = self.component_alive[target_layer] + if not alive_targets.any(): + return + + # Sum over batch and sequence + + target_acts_raw = cache[f"{target_layer}_pre_detach"] + ci_weighted_target_acts = (target_acts_raw * ci[target_layer]).sum(dim=(0, 1)) + + source_layers = self.sources_by_target[target_layer] + source_acts = [cache[f"{s}_post_detach"] for s in source_layers] + + for t_idx in alive_targets.tolist(): + grads = torch.autograd.grad( + ci_weighted_target_acts[t_idx], source_acts, retain_graph=True + ) + + source_acts_grads = list(zip(source_layers, source_acts, grads, strict=True)) + self._accumulate_attributions( - self.out_residual_accumulator[:, d_idx], - source_layers, - grads, - source_acts, - cache["tokens"], + target_layer, + t_idx, + source_acts_grads, + ci, + tokens, ) + @torch.no_grad() def _accumulate_attributions( self, - target_col: Float[Tensor, " n_sources"], - source_layers: list[str], - grads: tuple[Tensor, ...], - source_acts: list[Tensor], + target_layer: str, + target_idx: int, + source_acts_grads: list[tuple[str, Tensor, Tensor]], + ci: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], ) -> None: 
"""Accumulate grad*act attributions from sources to a target column.""" - with torch.no_grad(): - for layer, grad, act in zip(source_layers, grads, source_acts, strict=True): - alive = self.alive_source_idxs_per_layer[layer] - if not alive: - continue - - if layer == "wte": - # Per-token: sum grad*act over d_model, scatter by token id - attr = (grad * act).sum(dim=-1).flatten() - target_col.scatter_add_(0, tokens.flatten(), attr) - else: - # Per-component: sum grad*act over batch and sequence - start, _ = self.source_layer_to_idx_range[layer] - attr = (grad * act).sum(dim=(0, 1)) - for c in alive: - target_col[start + c] += attr[c] + target_accs = self.component_attr_accumulator[target_layer] + + for source_layer, act, grad in source_acts_grads: + attr_accumulator = target_accs[source_layer][target_idx] + + ci_weighted_attr = grad * act * ci[source_layer] + + if source_layer == "wte": + # Per-token: sum grad*act*ci over d_model, scatter by token id + # TODO(oli): figure out why this works + attr = ci_weighted_attr.sum(dim=-1).flatten() + attr_accumulator.scatter_add_(0, tokens.flatten(), attr) + else: + # Per-component: sum grad*act*ci over batch and sequence + attr = ci_weighted_attr.sum(dim=(0, 1)) + attr_accumulator.add_(attr) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py index 697036ba3..bd73c5f63 100644 --- a/spd/dataset_attributions/repo.py +++ b/spd/dataset_attributions/repo.py @@ -49,7 +49,7 @@ def open(cls, run_id: str) -> "AttributionRepo | None": path = subrun_dir / "dataset_attributions.pt" if not path.exists(): return None - return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return None # return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) def get_attributions(self) -> DatasetAttributionStorage: return self._storage diff --git a/spd/dataset_attributions/storage.py b/spd/dataset_attributions/storage.py index 16181201d..3027a6519 100644 --- 
a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -28,7 +28,6 @@ class DatasetAttributionEntry: value: float -@dataclass class DatasetAttributionStorage: """Dataset-aggregated attribution strengths between components. @@ -52,54 +51,73 @@ class DatasetAttributionStorage: - output tokens: "output:{token_id}" """ - component_layer_keys: list[str] - """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" - - vocab_size: int - """Vocabulary size (number of wte and output tokens)""" - - d_model: int - """Model hidden dimension (residual stream size)""" - - source_to_component: Float[Tensor, "n_sources n_components"] - """Attributions from sources to component targets. Shape: (vocab_size + n_components, n_components)""" - - source_to_out_residual: Float[Tensor, "n_sources d_model"] - """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, d_model)""" - - n_batches_processed: int - n_tokens_processed: int - ci_threshold: float - - _component_key_to_idx: dict[str, int] = dataclasses.field( - default_factory=dict, repr=False, init=False - ) - - def __post_init__(self) -> None: - self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} - - n_components = len(self.component_layer_keys) - n_sources = self.vocab_size + n_components + @property + def source_to_component(self) -> Float[Tensor, "n_sources n_components"]: + """Attributions from sources to component targets. 
Shape: (vocab_size + n_components, + n_components)""" + raise NotImplementedError("source_to_component is not implemented with new storage format") - expected_comp_shape = (n_sources, n_components) - assert self.source_to_component.shape == expected_comp_shape, ( - f"source_to_component shape {self.source_to_component.shape} " - f"doesn't match expected {expected_comp_shape}" + @property + def source_to_out_residual(self) -> Float[Tensor, "n_sources d_model"]: + """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, + d_model)""" + raise NotImplementedError( + "source_to_out_residual is not implemented with new storage format" ) - expected_resid_shape = (n_sources, self.d_model) - assert self.source_to_out_residual.shape == expected_resid_shape, ( - f"source_to_out_residual shape {self.source_to_out_residual.shape} " - f"doesn't match expected {expected_resid_shape}" - ) + @property + def component_layer_keys(self) -> list[str]: + """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" + raise NotImplementedError("component_layer_keys is not implemented with new storage format") @property def n_components(self) -> int: - return len(self.component_layer_keys) + """Number of component layers.""" + raise NotImplementedError("n_components is not implemented with new storage format") + # return len(self.component_layer_keys) - @property - def n_sources(self) -> int: - return self.vocab_size + self.n_components + def __init__( + self, + ci_threshold: float, + vocab_size: int, # d_model: int + # TODO(oli): check these are needed + n_batches_processed: int, + n_tokens_processed: int, + ): + self.ci_threshold = ci_threshold + self._REMOVE_ME_vocab_size = vocab_size + self.n_batches_processed = n_batches_processed + self.n_tokens_processed = n_tokens_processed + + # _component_key_to_idx: dict[str, int] = dataclasses.field( + # default_factory=dict, repr=False, init=False + # ) + + # def __post_init__(self) 
-> None: + # self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} + + # n_components = len(self.component_layer_keys) + # n_sources = self.vocab_size + n_components + + # expected_comp_shape = (n_sources, n_components) + # assert self.source_to_component.shape == expected_comp_shape, ( + # f"source_to_component shape {self.source_to_component.shape} " + # f"doesn't match expected {expected_comp_shape}" + # ) + + # expected_resid_shape = (n_sources, self.d_model) + # assert self.source_to_out_residual.shape == expected_resid_shape, ( + # f"source_to_out_residual shape {self.source_to_out_residual.shape} " + # f"doesn't match expected {expected_resid_shape}" + # ) + + # @property + # def n_components(self) -> int: + # return len(self.component_layer_keys) + + # @property + # def n_sources(self) -> int: + # return self.vocab_size + self.n_components def _parse_key(self, key: str) -> tuple[str, int]: """Parse a key into (layer, idx).""" @@ -111,14 +129,14 @@ def _source_idx(self, key: str) -> int: layer, idx = self._parse_key(key) match layer: case "wte": - assert 0 <= idx < self.vocab_size, ( - f"wte index {idx} out of range [0, {self.vocab_size})" + assert 0 <= idx < self._REMOVE_ME_vocab_size, ( + f"wte index {idx} out of range [0, {self._REMOVE_ME_vocab_size})" ) return idx case "output": raise KeyError(f"output tokens cannot be sources: {key}") case _: - return self.vocab_size + self._component_key_to_idx[key] + return self._REMOVE_ME_vocab_size + self._component_key_to_idx[key] def _component_target_idx(self, key: str) -> int: """Get target index for a component key. 
Raises KeyError if output or invalid.""" @@ -128,9 +146,9 @@ def _component_target_idx(self, key: str) -> int: def _source_idx_to_key(self, idx: int) -> str: """Convert source (row) index to key.""" - if idx < self.vocab_size: + if idx < self._REMOVE_ME_vocab_size: return f"wte:{idx}" - return self.component_layer_keys[idx - self.vocab_size] + return self.component_layer_keys[idx - self._REMOVE_ME_vocab_size] def _component_target_idx_to_key(self, idx: int) -> str: """Convert component target index to key.""" @@ -147,7 +165,7 @@ def _is_output_target(self, key: str) -> bool: def _output_token_id(self, key: str) -> int: """Extract token_id from an output key like 'output:123'. Asserts valid range.""" _, token_id = self._parse_key(key) - assert 0 <= token_id < self.vocab_size, f"output index {token_id} out of range" + assert 0 <= token_id < self._REMOVE_ME_vocab_size, f"output index {token_id} out of range" return token_id def has_source(self, key: str) -> bool: @@ -155,7 +173,7 @@ def has_source(self, key: str) -> bool: layer, idx = self._parse_key(key) match layer: case "wte": - return 0 <= idx < self.vocab_size + return 0 <= idx < self._REMOVE_ME_vocab_size case "output": return False case _: @@ -168,41 +186,40 @@ def has_target(self, key: str) -> bool: case "wte": return False case "output": - return 0 <= idx < self.vocab_size + return 0 <= idx < self._REMOVE_ME_vocab_size case _: return key in self._component_key_to_idx - def save(self, path: Path) -> None: - path.parent.mkdir(parents=True, exist_ok=True) - torch.save( - { - "component_layer_keys": self.component_layer_keys, - "vocab_size": self.vocab_size, - "d_model": self.d_model, - "source_to_component": self.source_to_component.cpu(), - "source_to_out_residual": self.source_to_out_residual.cpu(), - "n_batches_processed": self.n_batches_processed, - "n_tokens_processed": self.n_tokens_processed, - "ci_threshold": self.ci_threshold, - }, - path, - ) - size_mb = path.stat().st_size / (1024 * 1024) - 
logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") - - @classmethod - def load(cls, path: Path) -> "DatasetAttributionStorage": - data = torch.load(path, weights_only=True, mmap=True) - return cls( - component_layer_keys=data["component_layer_keys"], - vocab_size=data["vocab_size"], - d_model=data["d_model"], - source_to_component=data["source_to_component"], - source_to_out_residual=data["source_to_out_residual"], - n_batches_processed=data["n_batches_processed"], - n_tokens_processed=data["n_tokens_processed"], - ci_threshold=data["ci_threshold"], - ) + # TODO redo with new storage format + # def save(self, path: Path) -> None: + # path.parent.mkdir(parents=True, exist_ok=True) + # torch.save( + # { + # "component_layer_keys": self.component_layer_keys, + # "vocab_size": self._REMOVE_ME_vocab_size, + # "source_to_component": self.source_to_component.cpu(), + # "source_to_out_residual": self.source_to_out_residual.cpu(), + # "n_batches_processed": self.n_batches_processed, + # "n_tokens_processed": self.n_tokens_processed, + # "ci_threshold": self.ci_threshold, + # }, + # path, + # ) + # size_mb = path.stat().st_size / (1024 * 1024) + # logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") + # @classmethod + # def load(cls, path: Path) -> "DatasetAttributionStorage": + # data = torch.load(path, weights_only=True, mmap=True) + # return cls( + # component_layer_keys=data["component_layer_keys"], + # vocab_size=data["vocab_size"], + # d_model=data["d_model"], + # source_to_component=data["source_to_component"], + # source_to_out_residual=data["source_to_out_residual"], + # n_batches_processed=data["n_batches_processed"], + # n_tokens_processed=data["n_tokens_processed"], + # ci_threshold=data["ci_threshold"], + # ) def get_attribution( self, @@ -319,27 +336,29 @@ def combined_idx_to_key(idx: int) -> str: return self._get_top_k(comp_values, k, sign, self._component_target_idx_to_key) - def get_top_component_targets( - self, - 
source_key: str, - k: int, - sign: Literal["positive", "negative"], - ) -> list[DatasetAttributionEntry]: - """Get top-k component targets (excluding outputs) this source attributes TO. - - Convenience method that doesn't require w_unembed. - """ - return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) - - def get_top_output_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"], - ) -> list[DatasetAttributionEntry]: - """Get top-k output token targets this source attributes TO.""" - src_idx = self._source_idx(source_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) + # Unused apart from tests + # def get_top_component_targets( + # self, + # source_key: str, + # k: int, + # sign: Literal["positive", "negative"], + # ) -> list[DatasetAttributionEntry]: + # """Get top-k component targets (excluding outputs) this source attributes TO. + + # Convenience method that doesn't require w_unembed. 
+ # """ + # return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) + + # Unused + # def get_top_output_targets( + # self, + # source_key: str, + # k: int, + # sign: Literal["positive", "negative"], + # w_unembed: Float[Tensor, "d_model vocab"], + # ) -> list[DatasetAttributionEntry]: + # """Get top-k output token targets this source attributes TO.""" + # src_idx = self._source_idx(source_key) + # w_unembed = w_unembed.to(self.source_to_out_residual.device) + # output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) + # return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) diff --git a/spd/topology/gradient_connectivity.py b/spd/topology/gradient_connectivity.py index bcaac8423..3337d208c 100644 --- a/spd/topology/gradient_connectivity.py +++ b/spd/topology/gradient_connectivity.py @@ -74,19 +74,20 @@ def embed_hook( cache[f"{embed_path}_post_detach"] = embed_cache[f"{embed_path}_post_detach"] cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output - layers = [embed_path, *model.target_module_paths, unembed_path] + source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source + target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target # Test all distinct pairs for gradient flow test_pairs = [] - for in_layer in layers[:-1]: # Don't include "output" as source - for out_layer in layers[1:]: # Don't include embed as target - if in_layer != out_layer: - test_pairs.append((in_layer, out_layer)) + for source_layer in source_layers: + for target_layer in target_layers: + if source_layer != target_layer: + test_pairs.append((source_layer, target_layer)) sources_by_target: dict[str, list[str]] = defaultdict(list) - for in_layer, out_layer in test_pairs: - out_pre_detach = cache[f"{out_layer}_pre_detach"] - in_post_detach = cache[f"{in_layer}_post_detach"] + for source_layer, target_layer in test_pairs: + out_pre_detach = 
cache[f"{target_layer}_pre_detach"] + in_post_detach = cache[f"{source_layer}_post_detach"] out_value = out_pre_detach[0, 0, 0] grads = torch.autograd.grad( outputs=out_value, @@ -97,5 +98,5 @@ def embed_hook( assert len(grads) == 1 grad = grads[0] if grad is not None: # pyright: ignore[reportUnnecessaryComparison] - sources_by_target[out_layer].append(in_layer) + sources_by_target[target_layer].append(source_layer) return dict(sources_by_target) diff --git a/tests/dataset_attributions/test_harvester.py b/tests/dataset_attributions/test_harvester.py index 96ebc5df8..3df88a508 100644 --- a/tests/dataset_attributions/test_harvester.py +++ b/tests/dataset_attributions/test_harvester.py @@ -23,7 +23,7 @@ def _make_storage( return DatasetAttributionStorage( component_layer_keys=[f"layer1:{i}" for i in range(n_components)], - vocab_size=vocab_size, + _REMOVE_ME_vocab_size=vocab_size, d_model=d_model, source_to_component=source_to_component, source_to_out_residual=source_to_out_residual, @@ -241,7 +241,7 @@ def test_save_and_load(self, tmp_path: Path) -> None: original = DatasetAttributionStorage( component_layer_keys=["layer:0", "layer:1"], - vocab_size=vocab_size, + _REMOVE_ME_vocab_size=vocab_size, d_model=d_model, source_to_component=torch.randn(n_sources, n_components), source_to_out_residual=torch.randn(n_sources, d_model), @@ -256,7 +256,7 @@ def test_save_and_load(self, tmp_path: Path) -> None: loaded = DatasetAttributionStorage.load(path) assert loaded.component_layer_keys == original.component_layer_keys - assert loaded.vocab_size == original.vocab_size + assert loaded._REMOVE_ME_vocab_size == original._REMOVE_ME_vocab_size assert loaded.d_model == original.d_model assert loaded.n_batches_processed == original.n_batches_processed assert loaded.n_tokens_processed == original.n_tokens_processed From 3bcaddd88f1ec8ab75a112c7a2e8731f48538059 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 17:16:34 +0000 Subject: [PATCH 02/17] Rewrite 
dataset attribution storage: dict-of-dicts, canonical names, 3 metrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Storage uses attrs[target_layer][source_layer] = Tensor[target_d, source_d] with canonical layer names (embed, output, 0.glu.up, etc.). Harvester stays concrete internally; translation at storage boundary via topology.target_to_canon. Three attribution metrics accumulated: - attr: E[grad*act] (signed mean) - attr_abs: E[grad*|act|] (attribution to absolute target value) - mean_squared_attr: E[(grad*act)²] (pre-sqrt, mergeable across workers) Other changes: - Fix filter bug: used "output" instead of concrete unembed path (e.g. "lm_head") - Harvester parameterised with embed_path/unembed_path instead of magic strings - Storage.merge() classmethod with correct weighted-average semantics - Router simplified: no topology translation needed with canonical storage - Query methods stubbed with ValueError (frontend not yet updated) - Re-enable AttributionRepo.open() load - Remove outdated test_harvester.py (uses old flat-index API) Co-Authored-By: Claude Opus 4.6 (1M context) --- .../backend/routers/dataset_attributions.py | 61 +-- spd/dataset_attributions/harvest.py | 166 ++----- spd/dataset_attributions/harvester.py | 173 ++++--- spd/dataset_attributions/repo.py | 2 +- spd/dataset_attributions/storage.py | 454 +++++++----------- spd/topology/gradient_connectivity.py | 4 +- tests/dataset_attributions/test_harvester.py | 265 ---------- 7 files changed, 342 insertions(+), 783 deletions(-) delete mode 100644 tests/dataset_attributions/test_harvester.py diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index c459d29ae..3c5bd87ff 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -55,16 +55,9 @@ class ComponentAttributions(BaseModel): ) -def _to_concrete_key(canonical_layer: str, component_idx: 
int, loaded: DepLoadedRun) -> str: - """Translate canonical layer + idx to concrete storage key. - - "embed" maps to the concrete embedding path (e.g. "wte") in storage. - "output" is a pseudo-layer used as-is in storage. - """ - if canonical_layer == "output": - return f"output:{component_idx}" - concrete = loaded.topology.canon_to_target(canonical_layer) - return f"{concrete}:{component_idx}" +def _storage_key(canonical_layer: str, component_idx: int) -> str: + """Format a canonical layer + idx as a storage key.""" + return f"{canonical_layer}:{component_idx}" def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: @@ -97,20 +90,12 @@ def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: return loaded.topology.get_unembed_weight() -def _to_api_entries( - loaded: DepLoadedRun, entries: list[StorageEntry] -) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format with canonical keys.""" - - def _canonicalize_layer(layer: str) -> str: - if layer == "output": - return layer - return loaded.topology.target_to_canon(layer) - +def _to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry]: + """Convert storage entries to API response format.""" return [ DatasetAttributionEntry( - component_key=f"{_canonicalize_layer(e.layer)}:{e.component_idx}", - layer=_canonicalize_layer(e.layer), + component_key=e.component_key, + layer=e.layer, component_idx=e.component_idx, value=e.value, ) @@ -128,8 +113,6 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata n_batches_processed=None, n_tokens_processed=None, n_component_layer_keys=None, - # vocab_size=None, - # d_model=None, ci_threshold=None, ) storage = loaded.attributions.get_attributions() @@ -138,8 +121,6 @@ def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata n_batches_processed=storage.n_batches_processed, n_tokens_processed=storage.n_tokens_processed, 
n_component_layer_keys=storage.n_components, - # vocab_size=storage.vocab_size, - # d_model=storage.d_model, ci_threshold=storage.ci_threshold, ) @@ -154,7 +135,7 @@ def get_component_attributions( ) -> ComponentAttributions: """Get all attribution data for a component (sources and targets, positive and negative).""" storage = _require_storage(loaded) - component_key = _to_concrete_key(layer, component_idx, loaded) + component_key = _storage_key(layer, component_idx) # Component can be both a source and a target, so we need to check both is_source = storage.has_source(component_key) @@ -169,18 +150,13 @@ def get_component_attributions( w_unembed = _get_w_unembed(loaded) if is_source else None return ComponentAttributions( - positive_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "positive") - ) + positive_sources=_to_api_entries(storage.get_top_sources(component_key, k, "positive")) if is_target else [], - negative_sources=_to_api_entries( - loaded, storage.get_top_sources(component_key, k, "negative") - ) + negative_sources=_to_api_entries(storage.get_top_sources(component_key, k, "negative")) if is_target else [], positive_targets=_to_api_entries( - loaded, storage.get_top_targets( component_key, k, @@ -192,7 +168,6 @@ def get_component_attributions( if is_source else [], negative_targets=_to_api_entries( - loaded, storage.get_top_targets( component_key, k, @@ -217,14 +192,12 @@ def get_attribution_sources( ) -> list[DatasetAttributionEntry]: """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) - target_key = _to_concrete_key(layer, component_idx, loaded) + target_key = _storage_key(layer, component_idx) _require_target(storage, target_key) w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries( - loaded, storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed) - ) + return _to_api_entries(storage.get_top_sources(target_key, 
k, sign, w_unembed=w_unembed)) @router.get("/{layer}/{component_idx}/targets") @@ -238,14 +211,12 @@ def get_attribution_targets( ) -> list[DatasetAttributionEntry]: """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(layer, component_idx, loaded) + source_key = _storage_key(layer, component_idx) _require_source(storage, source_key) w_unembed = _get_w_unembed(loaded) - return _to_api_entries( - loaded, storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed) - ) + return _to_api_entries(storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed)) @router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") @@ -259,8 +230,8 @@ def get_attribution_between( ) -> float: """Get attribution strength from source component to target component.""" storage = _require_storage(loaded) - source_key = _to_concrete_key(source_layer, source_idx, loaded) - target_key = _to_concrete_key(target_layer, target_idx, loaded) + source_key = _storage_key(source_layer, source_idx) + target_key = _storage_key(target_layer, target_idx) _require_source(storage, source_key) _require_target(storage, target_key) diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 84a3d608d..6c55a83aa 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -14,7 +14,6 @@ import itertools from pathlib import Path -from typing import Any, cast import torch import tqdm @@ -34,27 +33,12 @@ from spd.utils.wandb_utils import parse_wandb_run_path -def _build_component_layer_keys(model: ComponentModel) -> list[str]: - """Build list of component layer keys in canonical order. - - Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. - wte and output keys are not included - they're constructed from vocab_size. 
- """ - component_layer_keys = [] - for layer in model.target_module_paths: - n_components = model.module_to_c[layer] - for c_idx in range(n_components): - component_layer_keys.append(f"{layer}:{c_idx}") - return component_layer_keys - - def _build_alive_masks( model: ComponentModel, run_id: str, harvest_subrun_id: str | None, - # n_components: int, + embed_path: str, vocab_size: int, - # ) -> tuple[Bool[Tensor, " n_sources"], Bool[Tensor, " n_components"]]: ) -> dict[str, Bool[Tensor, " n_components"]]: """Build masks of alive components (mean_activation > threshold) for sources and targets. @@ -66,19 +50,12 @@ def _build_alive_masks( """ component_alive = { - "wte": torch.ones(vocab_size, dtype=torch.bool), # All wte tokens are always alive + embed_path: torch.ones(vocab_size, dtype=torch.bool), **{ layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) for layer in model.target_module_paths }, } - # # All wte tokens are always alive (source indices [0, vocab_size)) - # source_alive[:vocab_size] = True - - # target_alive = { - # layer: torch.zeros(model.module_to_c[layer], dtype=torch.bool) - # for layer in model.target_module_paths - # } if harvest_subrun_id is not None: harvest = HarvestRepo(decomposition_id=run_id, subrun_id=harvest_subrun_id, readonly=True) @@ -89,10 +66,6 @@ def _build_alive_masks( summary = harvest.get_summary() assert summary is not None, "Harvest summary not available" - # Build masks for component layers - # source_idx = vocab_size # Start after wte tokens - # target_idx = 0 - for layer in model.target_module_paths: n_layer_components = model.module_to_c[layer] for c_idx in range(n_layer_components): @@ -100,12 +73,6 @@ def _build_alive_masks( is_alive = component_key in summary and summary[component_key].firing_density > 0.0 component_alive[layer][c_idx] = is_alive - # n_source_alive = int(source_alive.sum().item()) - # n_target_alive = int(target_alive.sum().item()) - # logger.info( - # f"Alive components: 
{n_source_alive}/{n_sources} sources, " - # f"{n_target_alive}/{n_components} component targets (firing density > 0.0)" - # ) return component_alive @@ -146,29 +113,21 @@ def harvest_attributions( assert isinstance(vocab_size, int), f"vocab_size must be int, got {type(vocab_size)}" logger.info(f"Vocab size: {vocab_size}") - # Build component keys and alive masks - # component_layer_keys = _build_component_layer_keys(model) - # n_components = len(component_layer_keys) - component_alive = _build_alive_masks(model, run_id, harvest_subrun_id, vocab_size) - # source_alive = source_alive.to(device) - # target_alive = target_alive.to(device) - - # n_sources = vocab_size + n_components - # logger.info(f"Component layers: {n_components}, Sources: {n_sources}") - # Get gradient connectivity logger.info("Computing sources_by_target...") topology = TransformerTopology(model.target_model) + embed_path = topology.path_schema.embedding_path + unembed_path = topology.path_schema.unembed_path sources_by_target_raw = get_sources_by_target(model, topology, str(device), spd_config.sampling) - # Filter sources_by_target: - # - Valid targets: component layers + output - # - Valid sources: wte + component layers + # Filter to valid source/target pairs: + # - Valid sources: embedding + component layers + # - Valid targets: component layers + unembed component_layers = set(model.target_module_paths) - valid_sources = component_layers.union({"wte"}) - valid_targets = component_layers.union({"output"}) + valid_sources = component_layers | {embed_path} + valid_targets = component_layers | {unembed_path} - sources_by_target = {} + sources_by_target: dict[str, list[str]] = {} for target, sources in sources_by_target_raw.items(): if target not in valid_targets: continue @@ -177,15 +136,19 @@ def harvest_attributions( sources_by_target[target] = filtered_sources logger.info(f"Found {len(sources_by_target)} target layers with gradient connections") - # Create harvester + # Build alive masks + 
component_alive = _build_alive_masks(model, run_id, harvest_subrun_id, embed_path, vocab_size) + + # Create harvester (all concrete paths internally) harvester = AttributionHarvester( model=model, sources_by_target=sources_by_target, - # n_components=n_components, vocab_size=vocab_size, component_alive=component_alive, sampling=spd_config.sampling, + embed_path=embed_path, embedding_module=topology.embedding_module, + unembed_path=unembed_path, unembed_module=topology.unembed_module, device=device, ) @@ -197,15 +160,18 @@ def harvest_attributions( batch_range = range(n_batches) case "whole_dataset": batch_range = itertools.count() + for batch_idx in tqdm.tqdm(batch_range, desc="Attribution batches"): try: batch_data = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break + # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue + batch = extract_batch_data(batch_data).to(device) harvester.process_batch(batch) @@ -213,16 +179,24 @@ def harvest_attributions( f"Processing complete. 
Tokens: {harvester.n_tokens:,}, Batches: {harvester.n_batches}" ) - # Normalize by n_tokens to get per-token average attribution - # normalized_comp = harvester.comp_accumulator / harvester.n_tokens - # normalized_out_residual = harvester.out_residual_accumulator / harvester.n_tokens + # Translate concrete paths to canonical for storage + to_canon = topology.target_to_canon + normalized = harvester.normalized_attrs() + + def canonicalize(d: dict[str, dict[str, Tensor]]) -> dict[str, dict[str, Tensor]]: + return { + to_canon(target): {to_canon(src): tensor for src, tensor in src_attrs.items()} + for target, src_attrs in d.items() + } - # Build and save storage storage = DatasetAttributionStorage( + attr=canonicalize(normalized.attr), + attr_abs=canonicalize(normalized.attr_abs), + mean_squared_attr=canonicalize(normalized.mean_squared_attr), vocab_size=vocab_size, + ci_threshold=config.ci_threshold, n_batches_processed=harvester.n_batches, n_tokens_processed=harvester.n_tokens, - ci_threshold=config.ci_threshold, ) if rank is not None: @@ -232,75 +206,27 @@ def harvest_attributions( else: output_dir.mkdir(parents=True, exist_ok=True) output_path = output_dir / "dataset_attributions.pt" - # storage.save(output_path) - logger.info(f"Saved dataset attributions to {output_path}") + storage.save(output_path) def merge_attributions(output_dir: Path) -> None: - """Merge partial attribution files from parallel workers. - - Looks for worker_states/dataset_attributions_rank_*.pt files and merges them - into dataset_attributions.pt in the output_dir. - - Uses streaming merge to avoid OOM - loads one file at a time instead of all at once. 
- """ + """Merge partial attribution files from parallel workers.""" worker_dir = output_dir / "worker_states" rank_files = sorted(worker_dir.glob("dataset_attributions_rank_*.pt")) assert rank_files, f"No rank files found in {worker_dir}" logger.info(f"Found {len(rank_files)} rank files to merge") - # Load first file to get metadata and initialize accumulators - # Use double precision for accumulation to prevent precision loss with billions of tokens - first = cast(DatasetAttributionStorage, None) # DatasetAttributionStorage.load(rank_files[0]) - total_comp = (first.source_to_component * first.n_tokens_processed).double() - total_out_residual = (first.source_to_out_residual * first.n_tokens_processed).double() - total_tokens = first.n_tokens_processed - total_batches = first.n_batches_processed - logger.info(f"Loaded rank 0: {first.n_tokens_processed:,} tokens") - - # Stream remaining files one at a time - for rank_file in tqdm.tqdm(rank_files[1:], desc="Merging rank files"): - storage = cast(DatasetAttributionStorage, None) # DatasetAttributionStorage.load(rank_file) - - # Validate consistency - assert storage.component_layer_keys == first.component_layer_keys, ( - "Component layer keys mismatch" - ) - # assert storage.d_model == first.d_model, "d_model mismatch" - assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" - - # Accumulate de-normalized values - total_comp += storage.source_to_component * storage.n_tokens_processed - total_out_residual += storage.source_to_out_residual * storage.n_tokens_processed - total_tokens += storage.n_tokens_processed - total_batches += storage.n_batches_processed - - # Normalize by total tokens and convert back to float32 for storage - merged_comp = (total_comp / total_tokens).float() - merged_out_residual = (total_out_residual / total_tokens).float() - - # Save merged result - merged = DatasetAttributionStorage( - # component_layer_keys=first.component_layer_keys, - # # d_model=first.d_model, - # 
source_to_component=merged_comp, - # source_to_out_residual=merged_out_residual, - # n_batches_processed=total_batches, - # n_tokens_processed=total_tokens, - # ci_threshold=first.ci_threshold, - vocab_size=0, # vocab_size, - n_batches_processed=total_batches, - n_tokens_processed=total_tokens, - ci_threshold=first.ci_threshold, - ) + merged = DatasetAttributionStorage.merge(rank_files) output_path = output_dir / "dataset_attributions.pt" - # merged.save(output_path) - assert output_path.stat().st_size > 0, f"Merge output is empty: {output_path}" - logger.info(f"Merged {len(rank_files)} files -> {output_path}") - logger.info(f"Total: {total_batches} batches, {total_tokens:,} tokens") - - for rank_file in rank_files: - rank_file.unlink() - worker_dir.rmdir() - logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") + merged.save(output_path) + logger.info( + f"Total: {merged.n_batches_processed} batches, {merged.n_tokens_processed:,} tokens" + ) + + # TODO(oli): reenable this + # disabled deletion for testing, posterity and retries + # for rank_file in rank_files: + # rank_file.unlink() + # worker_dir.rmdir() + # logger.info(f"Deleted {len(rank_files)} per-rank files and worker_states/") diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index f4be04f25..947937f0c 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -4,12 +4,26 @@ training dataset using gradient x activation formula, summed over all positions and batches. +Three metrics are accumulated: +- attr: E[∂y/∂x · x] (signed mean attribution) +- attr_abs: E[∂|y|/∂x · x] (attribution to absolute value of target) +- squared_attr: E[(∂y/∂x · x)²] (mean squared attribution, for RMS) + +Naming convention: modifier before "attr" applies to the target (e.g. attr_abs = +attribution to |target|), modifier after applies to the attribution itself +(e.g. squared_attr = squared attribution). 
+ Uses residual-based storage for scalability: -- Component targets: accumulated directly to comp_accumulator -- Output targets: accumulated as attributions to output residual stream (source_to_out_residual) - Output attributions computed on-the-fly at query time via w_unembed +- Component targets: accumulated directly +- Output targets: accumulated as attributions to output residual stream, + computed on-the-fly at query time via w_unembed + +All layer keys are concrete module paths (e.g. "wte", "h.0.attn.q_proj", "lm_head"). +Translation to canonical names happens at the storage boundary in harvest.py. """ +from collections import defaultdict +from dataclasses import dataclass from typing import Any import torch @@ -23,7 +37,7 @@ class AttributionHarvester: - """Accumulates attribution strengths across batches. + """Accumulates attribution strengths across batches using concrete module paths. The attribution formula is: attribution[src, tgt] = Σ_batch Σ_pos (∂out[pos, tgt] / ∂in[pos, src]) × in_act[pos, src] @@ -34,11 +48,6 @@ class AttributionHarvester: 2. For output targets, store attributions to the pre-unembed residual (d_model dimensions) instead of vocab tokens. This eliminates the expensive O((V+C) × d_model × V) matmul during harvesting and reduces storage. - - Index structure: - - Sources: wte tokens [0, vocab_size) + component layers [vocab_size, ...) 
- - Component targets: [0, n_components) in comp_accumulator - - Output targets: via out_residual_accumulator (computed on-the-fly at query time) """ sampling: SamplingType @@ -50,7 +59,9 @@ def __init__( vocab_size: int, component_alive: dict[str, Bool[Tensor, " n_components"]], sampling: SamplingType, + embed_path: str, embedding_module: nn.Embedding, + unembed_path: str, unembed_module: nn.Linear, device: torch.device, ): @@ -59,7 +70,9 @@ def __init__( self.vocab_size = vocab_size self.component_alive = component_alive self.sampling = sampling + self.embed_path = embed_path self.embedding_module = embedding_module + self.unembed_path = unembed_path self.unembed_module = unembed_module self.device = device @@ -67,16 +80,17 @@ def __init__( self.n_tokens = 0 self.output_d_model = unembed_module.in_features - # Split accumulators for component and output targets - self.component_attr_accumulator = self._get_component_attr_accumulator( - sources_by_target, - component_alive, - unembed_module, - vocab_size, - device, + self.attr_accumulator = self._build_accumulator( + sources_by_target, component_alive, unembed_module, vocab_size, device + ) + self.attr_abs_accumulator = self._build_accumulator( + sources_by_target, component_alive, unembed_module, vocab_size, device + ) + self.square_attr_accumulator = self._build_accumulator( + sources_by_target, component_alive, unembed_module, vocab_size, device ) - def _get_component_attr_accumulator( + def _build_accumulator( self, sources_by_target: dict[str, list[str]], component_alive: dict[str, Bool[Tensor, " n_components"]], @@ -84,44 +98,71 @@ def _get_component_attr_accumulator( vocab_size: int, device: torch.device, ) -> dict[str, dict[str, Tensor]]: - component_attr_accumulator: dict[str, dict[str, Tensor]] = {} + accumulator: dict[str, dict[str, Tensor]] = {} for target_layer, source_layers in sources_by_target.items(): - if target_layer == "output": + if target_layer == self.unembed_path: target_d = 
unembed_module.in_features else: (target_c,) = component_alive[target_layer].shape target_d = target_c - source_attr_accumulator: dict[str, Tensor] = {} + source_acc: dict[str, Tensor] = {} for source_layer in source_layers: - if source_layer == "wte": + if source_layer == self.embed_path: source_d = vocab_size else: (source_c,) = component_alive[source_layer].shape source_d = source_c - source_attr_accumulator[source_layer] = torch.zeros( - (target_d, source_d), device=device - ) + source_acc[source_layer] = torch.zeros((target_d, source_d), device=device) + + accumulator[target_layer] = source_acc + + return accumulator + + @dataclass + class NormalizedAttrs: + attr: dict[str, dict[str, Tensor]] + attr_abs: dict[str, dict[str, Tensor]] + mean_squared_attr: dict[str, dict[str, Tensor]] - component_attr_accumulator[target_layer] = source_attr_accumulator + def normalized_attrs(self) -> NormalizedAttrs: + """Return the accumulated attributions normalized by n_tokens. - return component_attr_accumulator + mean_squared_attr is pre-sqrt so it can be merged across workers. 
+ """ + attr = defaultdict[str, dict[str, Tensor]](dict) + attr_abs = defaultdict[str, dict[str, Tensor]](dict) + mean_squared_attr = defaultdict[str, dict[str, Tensor]](dict) + + for target in self.attr_accumulator: + for source in self.sources_by_target[target]: + attr[target][source] = self.attr_accumulator[target][source] / self.n_tokens + attr_abs[target][source] = self.attr_abs_accumulator[target][source] / self.n_tokens + mean_squared_attr[target][source] = ( + self.square_attr_accumulator[target][source] / self.n_tokens + ) + + return self.NormalizedAttrs( + attr=attr, + attr_abs=attr_abs, + mean_squared_attr=mean_squared_attr, + ) def process_batch(self, tokens: Int[Tensor, "batch seq"]) -> None: """Accumulate attributions from one batch.""" self.n_batches += 1 self.n_tokens += tokens.numel() - # Setup hooks to capture wte output and pre-unembed residual - wte_out: list[Tensor] = [] + # Setup hooks to capture embedding output and pre-unembed residual + embed_out: list[Tensor] = [] pre_unembed: list[Tensor] = [] - def wte_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: + def embed_hook(_mod: nn.Module, _args: Any, _kwargs: Any, out: Tensor) -> Tensor: out.requires_grad_(True) - wte_out.clear() - wte_out.append(out) + embed_out.clear() + embed_out.append(out) return out def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> None: @@ -129,7 +170,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - h1 = self.embedding_module.register_forward_hook(wte_hook, with_kwargs=True) + h1 = self.embedding_module.register_forward_hook(embed_hook, with_kwargs=True) h2 = self.unembed_module.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) # Get masks with all components active @@ -154,13 +195,11 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No h2.remove() cache = comp_output.cache - 
cache["wte_post_detach"] = wte_out[0] - cache["pre_unembed"] = pre_unembed[0] - # cache["tokens"] = tokens + cache[f"{self.embed_path}_post_detach"] = embed_out[0] + cache[f"{self.unembed_path}_pre_detach"] = pre_unembed[0] - # Process each target layer for target_layer in self.sources_by_target: - if target_layer == "output": + if target_layer == self.unembed_path: self._process_output_targets(cache, ci.lower_leaky, tokens) else: self._process_component_targets(target_layer, ci.lower_leaky, cache, tokens) @@ -171,26 +210,21 @@ def _process_output_targets( ci: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], ) -> None: - """Process output attributions via output-residual-space storage. - - Instead of computing and storing attributions to vocab tokens directly, - we store attributions to output residual dimensions. Output attributions are - computed on-the-fly at query time via: attr[src, token] = out_residual[src] @ w_unembed[:, token] - """ - # Sum output residual over batch and sequence -> [d_model] - out_residual = cache["pre_unembed"].sum(dim=(0, 1)) + """Process output attributions via output-residual-space storage.""" + out_residual = cache[f"{self.unembed_path}_pre_detach"].sum(dim=(0, 1)) - source_layers = self.sources_by_target["output"] + source_layers = self.sources_by_target[self.unembed_path] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] for d_idx in range(self.output_d_model): grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) - source_acts_grads = list(zip(source_layers, source_acts, grads, strict=True)) self._accumulate_attributions( - "output", + self.unembed_path, d_idx, - source_acts_grads, + source_layers, + source_acts, + list(grads), ci, tokens, ) @@ -207,8 +241,6 @@ def _process_component_targets( if not alive_targets.any(): return - # Sum over batch and sequence - target_acts_raw = cache[f"{target_layer}_pre_detach"] ci_weighted_target_acts = (target_acts_raw * 
ci[target_layer]).sum(dim=(0, 1)) @@ -220,39 +252,54 @@ def _process_component_targets( ci_weighted_target_acts[t_idx], source_acts, retain_graph=True ) - source_acts_grads = list(zip(source_layers, source_acts, grads, strict=True)) - self._accumulate_attributions( target_layer, t_idx, - source_acts_grads, + source_layers, + source_acts, + list(grads), ci, tokens, ) - @torch.no_grad() + @torch.no_grad() # pyright: ignore[reportUntypedFunctionDecorator] def _accumulate_attributions( self, target_layer: str, target_idx: int, - source_acts_grads: list[tuple[str, Tensor, Tensor]], + source_layers: list[str], + source_acts: list[Tensor], + source_grads: list[Tensor], ci: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], ) -> None: """Accumulate grad*act attributions from sources to a target column.""" - target_accs = self.component_attr_accumulator[target_layer] - for source_layer, act, grad in source_acts_grads: - attr_accumulator = target_accs[source_layer][target_idx] + attr_accumulator = self.attr_accumulator[target_layer] + attr_abr_accumulator = self.attr_abs_accumulator[target_layer] + square_attr_accumulator = self.square_attr_accumulator[target_layer] + + for source_layer, act, grad in zip(source_layers, source_acts, source_grads, strict=True): + attr_acc = attr_accumulator[source_layer][target_idx] + attr_abs_acc = attr_abr_accumulator[source_layer][target_idx] + square_attr_acc = square_attr_accumulator[source_layer][target_idx] ci_weighted_attr = grad * act * ci[source_layer] + ci_weighted_attr_abs = torch.where(act > 0, ci_weighted_attr, -ci_weighted_attr) + ci_weighted_squared_attr = ci_weighted_attr.square() - if source_layer == "wte": + if source_layer == self.embed_path: # Per-token: sum grad*act*ci over d_model, scatter by token id # TODO(oli): figure out why this works attr = ci_weighted_attr.sum(dim=-1).flatten() - attr_accumulator.scatter_add_(0, tokens.flatten(), attr) + attr_abs = ci_weighted_attr_abs.sum(dim=-1).flatten() + attr_squared = 
ci_weighted_squared_attr.sum(dim=-1).flatten() + + attr_acc.scatter_add_(0, tokens.flatten(), attr) + attr_abs_acc.scatter_add_(0, tokens.flatten(), attr_abs) + square_attr_acc.scatter_add_(0, tokens.flatten(), attr_squared) else: # Per-component: sum grad*act*ci over batch and sequence - attr = ci_weighted_attr.sum(dim=(0, 1)) - attr_accumulator.add_(attr) + attr_acc.add_(ci_weighted_attr.sum(dim=(0, 1))) + attr_abs_acc.add_(ci_weighted_attr_abs.sum(dim=(0, 1))) + square_attr_acc.add_(ci_weighted_squared_attr.sum(dim=(0, 1))) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py index bd73c5f63..697036ba3 100644 --- a/spd/dataset_attributions/repo.py +++ b/spd/dataset_attributions/repo.py @@ -49,7 +49,7 @@ def open(cls, run_id: str) -> "AttributionRepo | None": path = subrun_dir / "dataset_attributions.pt" if not path.exists(): return None - return None # return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) def get_attributions(self) -> DatasetAttributionStorage: return self._storage diff --git a/spd/dataset_attributions/storage.py b/spd/dataset_attributions/storage.py index 3027a6519..ca1c8e59d 100644 --- a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -1,22 +1,28 @@ """Storage classes for dataset attributions. 
-Uses a residual-based storage approach for scalability: -- Component targets: stored directly in source_to_component matrix -- Output targets: stored as attributions to residual stream, computed on-the-fly via w_unembed +Stored as nested dicts: attrs[target_layer][source_layer] = Tensor[target_d, source_d] + +Three attribution metrics are stored: +- attr: mean attribution of source to target (signed) +- attr_abs: mean attribution of source to |target| (always positive for positive activations) +- mean_squared_attr: mean of squared attributions (pre-sqrt, for mergeable RMS) + +For output targets, target_d = d_model (residual stream dimension). +Output token attributions are computed on-the-fly via w_unembed. """ -import dataclasses -from collections.abc import Callable +import bisect from dataclasses import dataclass from pathlib import Path from typing import Literal import torch -from jaxtyping import Float from torch import Tensor from spd.log import logger +AttrDict = dict[str, dict[str, Tensor]] + @dataclass class DatasetAttributionEntry: @@ -31,334 +37,208 @@ class DatasetAttributionEntry: class DatasetAttributionStorage: """Dataset-aggregated attribution strengths between components. - Uses residual-based storage for scalability with large vocabularies: - - source_to_component: direct attributions to component targets - - source_to_out_residual: attributions to output residual stream (for computing output attributions) - - Output attributions are computed on-the-fly: attr[src, output_token] = out_residual[src] @ w_unembed[:, token] - - Source indexing (rows): - - [0, vocab_size): wte tokens - - [vocab_size, vocab_size + n_components): component layers - - Target indexing: - - Component targets: [0, n_components) in source_to_component - - Output targets: computed via source_to_out_residual @ w_unembed + All layer names use canonical addressing (e.g., "embed", "0.glu.up", "output"). 
Key formats: - - wte tokens: "wte:{token_id}" - - component layers: "layer:c_idx" (e.g., "h.0.attn.q_proj:5") + - embed tokens: "embed:{token_id}" + - component layers: "canonical_layer:c_idx" (e.g., "0.glu.up:5") - output tokens: "output:{token_id}" """ - @property - def source_to_component(self) -> Float[Tensor, "n_sources n_components"]: - """Attributions from sources to component targets. Shape: (vocab_size + n_components, - n_components)""" - raise NotImplementedError("source_to_component is not implemented with new storage format") - - @property - def source_to_out_residual(self) -> Float[Tensor, "n_sources d_model"]: - """Attributions from sources to output residual dimensions. Shape: (vocab_size + n_components, - d_model)""" - raise NotImplementedError( - "source_to_out_residual is not implemented with new storage format" - ) - - @property - def component_layer_keys(self) -> list[str]: - """Component layer keys in order: ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...]""" - raise NotImplementedError("component_layer_keys is not implemented with new storage format") - - @property - def n_components(self) -> int: - """Number of component layers.""" - raise NotImplementedError("n_components is not implemented with new storage format") - # return len(self.component_layer_keys) - def __init__( self, + attr: AttrDict, + attr_abs: AttrDict, + mean_squared_attr: AttrDict, + vocab_size: int, ci_threshold: float, - vocab_size: int, # d_model: int - # TODO(oli): check these are needed n_batches_processed: int, n_tokens_processed: int, ): + self.attr = attr + self.attr_abs = attr_abs + self.mean_squared_attr = mean_squared_attr + self.vocab_size = vocab_size self.ci_threshold = ci_threshold - self._REMOVE_ME_vocab_size = vocab_size self.n_batches_processed = n_batches_processed self.n_tokens_processed = n_tokens_processed - # _component_key_to_idx: dict[str, int] = dataclasses.field( - # default_factory=dict, repr=False, init=False - # ) - - # def __post_init__(self) -> 
None: - # self._component_key_to_idx = {k: i for i, k in enumerate(self.component_layer_keys)} - - # n_components = len(self.component_layer_keys) - # n_sources = self.vocab_size + n_components - - # expected_comp_shape = (n_sources, n_components) - # assert self.source_to_component.shape == expected_comp_shape, ( - # f"source_to_component shape {self.source_to_component.shape} " - # f"doesn't match expected {expected_comp_shape}" - # ) - - # expected_resid_shape = (n_sources, self.d_model) - # assert self.source_to_out_residual.shape == expected_resid_shape, ( - # f"source_to_out_residual shape {self.source_to_out_residual.shape} " - # f"doesn't match expected {expected_resid_shape}" - # ) - - # @property - # def n_components(self) -> int: - # return len(self.component_layer_keys) - - # @property - # def n_sources(self) -> int: - # return self.vocab_size + self.n_components - - def _parse_key(self, key: str) -> tuple[str, int]: - """Parse a key into (layer, idx).""" + @property + def n_components(self) -> int: + total = 0 + for target_layer in self.attr: + if target_layer == "output": + continue + first_source = next(iter(self.attr[target_layer].values())) + total += first_source.shape[0] + return total + + @staticmethod + def _parse_key(key: str) -> tuple[str, int]: layer, idx_str = key.rsplit(":", 1) return layer, int(idx_str) - def _source_idx(self, key: str) -> int: - """Get source (row) index for a key. Raises KeyError if not a valid source.""" - layer, idx = self._parse_key(key) - match layer: - case "wte": - assert 0 <= idx < self._REMOVE_ME_vocab_size, ( - f"wte index {idx} out of range [0, {self._REMOVE_ME_vocab_size})" - ) - return idx - case "output": - raise KeyError(f"output tokens cannot be sources: {key}") - case _: - return self._REMOVE_ME_vocab_size + self._component_key_to_idx[key] - - def _component_target_idx(self, key: str) -> int: - """Get target index for a component key. 
Raises KeyError if output or invalid.""" - if key.startswith(("wte:", "output:")): - raise KeyError(f"Not a component target: {key}") - return self._component_key_to_idx[key] - - def _source_idx_to_key(self, idx: int) -> str: - """Convert source (row) index to key.""" - if idx < self._REMOVE_ME_vocab_size: - return f"wte:{idx}" - return self.component_layer_keys[idx - self._REMOVE_ME_vocab_size] - - def _component_target_idx_to_key(self, idx: int) -> str: - """Convert component target index to key.""" - return self.component_layer_keys[idx] - - def _output_target_idx_to_key(self, idx: int) -> str: - """Convert output token index to key.""" - return f"output:{idx}" - - def _is_output_target(self, key: str) -> bool: - """Check if key is an output target.""" - return key.startswith("output:") - - def _output_token_id(self, key: str) -> int: - """Extract token_id from an output key like 'output:123'. Asserts valid range.""" - _, token_id = self._parse_key(key) - assert 0 <= token_id < self._REMOVE_ME_vocab_size, f"output index {token_id} out of range" - return token_id - def has_source(self, key: str) -> bool: - """Check if a key can be a source (wte token or component layer).""" layer, idx = self._parse_key(key) - match layer: - case "wte": - return 0 <= idx < self._REMOVE_ME_vocab_size - case "output": - return False - case _: - return key in self._component_key_to_idx + if layer == "output": + return False + for target_sources in self.attr.values(): + if layer in target_sources: + return 0 <= idx < target_sources[layer].shape[1] + return False def has_target(self, key: str) -> bool: - """Check if a key can be a target (component layer or output token).""" layer, idx = self._parse_key(key) match layer: - case "wte": + case "embed": return False case "output": - return 0 <= idx < self._REMOVE_ME_vocab_size + return 0 <= idx < self.vocab_size case _: - return key in self._component_key_to_idx - - # TODO redo with new storage format - # def save(self, path: Path) -> 
None: - # path.parent.mkdir(parents=True, exist_ok=True) - # torch.save( - # { - # "component_layer_keys": self.component_layer_keys, - # "vocab_size": self._REMOVE_ME_vocab_size, - # "source_to_component": self.source_to_component.cpu(), - # "source_to_out_residual": self.source_to_out_residual.cpu(), - # "n_batches_processed": self.n_batches_processed, - # "n_tokens_processed": self.n_tokens_processed, - # "ci_threshold": self.ci_threshold, - # }, - # path, - # ) - # size_mb = path.stat().st_size / (1024 * 1024) - # logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") - # @classmethod - # def load(cls, path: Path) -> "DatasetAttributionStorage": - # data = torch.load(path, weights_only=True, mmap=True) - # return cls( - # component_layer_keys=data["component_layer_keys"], - # vocab_size=data["vocab_size"], - # d_model=data["d_model"], - # source_to_component=data["source_to_component"], - # source_to_out_residual=data["source_to_out_residual"], - # n_batches_processed=data["n_batches_processed"], - # n_tokens_processed=data["n_tokens_processed"], - # ci_threshold=data["ci_threshold"], - # ) - - def get_attribution( - self, - source_key: str, - target_key: str, - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> float: - """Get attribution strength from source to target. 
- - Args: - source_key: Source component key (wte or component layer) - target_key: Target component key (component layer or output token) - w_unembed: Unembedding matrix, required if target is an output token - """ - src_idx = self._source_idx(source_key) + if layer not in self.attr: + return False + first_source = next(iter(self.attr[layer].values())) + return 0 <= idx < first_source.shape[0] - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - return (self.source_to_out_residual[src_idx] @ w_unembed[:, token_id]).item() + # TODO: these methods need a metric parameter to select which of the 3 attr dicts to query + def get_attribution(self, *_args: object, **_kwargs: object) -> float: + raise ValueError("TODO: get_attribution needs metric selection") - tgt_idx = self._component_target_idx(target_key) - return self.source_to_component[src_idx, tgt_idx].item() + def get_top_sources(self, *_args: object, **_kwargs: object) -> list[DatasetAttributionEntry]: + raise ValueError("TODO: get_top_sources needs metric selection") - def _get_top_k( + def get_top_targets(self, *_args: object, **_kwargs: object) -> list[DatasetAttributionEntry]: + raise ValueError("TODO: get_top_targets needs metric selection") + + def _top_k_from_segments( self, - values: Tensor, + value_segments: list[Tensor], + layer_names: list[str], k: int, sign: Literal["positive", "negative"], - idx_to_key: Callable[[int], str], ) -> list[DatasetAttributionEntry]: - """Get top-k entries from a 1D tensor of attribution values.""" + if not value_segments: + return [] + + all_values = torch.cat(value_segments) + offsets = [0] + for seg in value_segments: + offsets.append(offsets[-1] + len(seg)) + is_positive = sign == "positive" - top_vals, top_idxs = torch.topk(values, min(k, len(values)), largest=is_positive) + top_vals, top_idxs 
= torch.topk(all_values, min(k, len(all_values)), largest=is_positive) - # Filter to only values matching the requested sign mask = top_vals > 0 if is_positive else top_vals < 0 top_vals, top_idxs = top_vals[mask], top_idxs[mask] results = [] - for idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): - key = idx_to_key(idx) - layer, c_idx = self._parse_key(key) + for flat_idx, val in zip(top_idxs.tolist(), top_vals.tolist(), strict=True): + seg_idx = bisect.bisect_right(offsets, flat_idx) - 1 + local_idx = flat_idx - offsets[seg_idx] + layer = layer_names[seg_idx] results.append( DatasetAttributionEntry( - component_key=key, + component_key=f"{layer}:{local_idx}", layer=layer, - component_idx=c_idx, + component_idx=local_idx, value=val, ) ) return results - def get_top_sources( - self, - target_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model vocab"] | None = None, - ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target. 
+ def save(self, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + + def to_cpu(d: AttrDict) -> AttrDict: + return { + target: {source: tensor.cpu() for source, tensor in sources.items()} + for target, sources in d.items() + } + + torch.save( + { + "attr": to_cpu(self.attr), + "attr_abs": to_cpu(self.attr_abs), + "mean_squared_attr": to_cpu(self.mean_squared_attr), + "vocab_size": self.vocab_size, + "ci_threshold": self.ci_threshold, + "n_batches_processed": self.n_batches_processed, + "n_tokens_processed": self.n_tokens_processed, + }, + path, + ) + size_mb = path.stat().st_size / (1024 * 1024) + logger.info(f"Saved dataset attributions to {path} ({size_mb:.1f} MB)") + + @classmethod + def load(cls, path: Path) -> "DatasetAttributionStorage": + data = torch.load(path, weights_only=True) + return cls( + attr=data["attr"], + attr_abs=data["attr_abs"], + mean_squared_attr=data["mean_squared_attr"], + vocab_size=data["vocab_size"], + ci_threshold=data["ci_threshold"], + n_batches_processed=data["n_batches_processed"], + n_tokens_processed=data["n_tokens_processed"], + ) - Args: - target_key: Target component key (component layer or output token) - k: Number of top sources to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if target is an output token - """ - if self._is_output_target(target_key): - assert w_unembed is not None, "w_unembed required for output target queries" - token_id = self._output_token_id(target_key) - w_unembed = w_unembed.to(self.source_to_out_residual.device) - values = self.source_to_out_residual @ w_unembed[:, token_id] # (n_sources,) - else: - tgt_idx = self._component_target_idx(target_key) - values = self.source_to_component[:, tgt_idx] - - return self._get_top_k(values, k, sign, self._source_idx_to_key) - - def get_top_targets( - self, - source_key: str, - k: int, - sign: Literal["positive", "negative"], - w_unembed: Float[Tensor, "d_model 
vocab"] | None = None, - include_outputs: bool = True, - ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO. - - Args: - source_key: Source component key (wte or component layer) - k: Number of top targets to return - sign: "positive" for strongest positive, "negative" for strongest negative - w_unembed: Unembedding matrix, required if include_outputs=True - include_outputs: Whether to include output tokens in results + @classmethod + def merge(cls, paths: list[Path]) -> "DatasetAttributionStorage": + """Merge partial attribution files from parallel workers. + + All three metrics are means, so merge is weighted average by n_tokens. + (mean_squared_attr is E[x²], not sqrt(E[x²]), so this works.) """ - src_idx = self._source_idx(source_key) - comp_values = self.source_to_component[src_idx, :] # (n_components,) - - if include_outputs: - assert w_unembed is not None, "w_unembed required when include_outputs=True" - # Compute attributions to all output tokens - w_unembed = w_unembed.to(self.source_to_out_residual.device) - output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - all_values = torch.cat([comp_values, output_values]) - - def combined_idx_to_key(idx: int) -> str: - if idx < self.n_components: - return self._component_target_idx_to_key(idx) - return self._output_target_idx_to_key(idx - self.n_components) - - return self._get_top_k(all_values, k, sign, combined_idx_to_key) - - return self._get_top_k(comp_values, k, sign, self._component_target_idx_to_key) - - # Unused apart from tests - # def get_top_component_targets( - # self, - # source_key: str, - # k: int, - # sign: Literal["positive", "negative"], - # ) -> list[DatasetAttributionEntry]: - # """Get top-k component targets (excluding outputs) this source attributes TO. - - # Convenience method that doesn't require w_unembed. 
- # """ - # return self.get_top_targets(source_key, k, sign, w_unembed=None, include_outputs=False) - - # Unused - # def get_top_output_targets( - # self, - # source_key: str, - # k: int, - # sign: Literal["positive", "negative"], - # w_unembed: Float[Tensor, "d_model vocab"], - # ) -> list[DatasetAttributionEntry]: - # """Get top-k output token targets this source attributes TO.""" - # src_idx = self._source_idx(source_key) - # w_unembed = w_unembed.to(self.source_to_out_residual.device) - # output_values = self.source_to_out_residual[src_idx, :] @ w_unembed # (vocab,) - # return self._get_top_k(output_values, k, sign, self._output_target_idx_to_key) + assert paths, "No files to merge" + + first = cls.load(paths[0]) + n = first.n_tokens_processed + + def denormalize(d: AttrDict, n_tokens: int) -> AttrDict: + return { + target: {source: (tensor * n_tokens).double() for source, tensor in sources.items()} + for target, sources in d.items() + } + + total_attr = denormalize(first.attr, n) + total_attr_abs = denormalize(first.attr_abs, n) + total_mean_squared_attr = denormalize(first.mean_squared_attr, n) + total_tokens = n + total_batches = first.n_batches_processed + + for path in paths[1:]: + storage = cls.load(path) + assert storage.ci_threshold == first.ci_threshold, "CI threshold mismatch" + assert storage.attr.keys() == first.attr.keys(), "Target layer mismatch" + n = storage.n_tokens_processed + + for target, sources in storage.attr.items(): + for source, tensor in sources.items(): + total_attr[target][source] += (tensor * n).double() + total_attr_abs[target][source] += ( + storage.attr_abs[target][source] * n + ).double() + total_mean_squared_attr[target][source] += ( + storage.mean_squared_attr[target][source] * n + ).double() + total_tokens += n + total_batches += storage.n_batches_processed + + def normalize(d: AttrDict) -> AttrDict: + return { + target: { + source: (tensor / total_tokens).float() for source, tensor in sources.items() + } + for target, 
sources in d.items() + } + + return cls( + attr=normalize(total_attr), + attr_abs=normalize(total_attr_abs), + mean_squared_attr=normalize(total_mean_squared_attr), + vocab_size=first.vocab_size, + ci_threshold=first.ci_threshold, + n_batches_processed=total_batches, + n_tokens_processed=total_tokens, + ) diff --git a/spd/topology/gradient_connectivity.py b/spd/topology/gradient_connectivity.py index 3337d208c..31ba61b5a 100644 --- a/spd/topology/gradient_connectivity.py +++ b/spd/topology/gradient_connectivity.py @@ -74,8 +74,8 @@ def embed_hook( cache[f"{embed_path}_post_detach"] = embed_cache[f"{embed_path}_post_detach"] cache[f"{unembed_path}_pre_detach"] = comp_output_with_cache.output - source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source - target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target + source_layers = [embed_path, *model.target_module_paths] # Don't include "output" as source + target_layers = [*model.target_module_paths, unembed_path] # Don't include embed as target # Test all distinct pairs for gradient flow test_pairs = [] diff --git a/tests/dataset_attributions/test_harvester.py b/tests/dataset_attributions/test_harvester.py deleted file mode 100644 index 3df88a508..000000000 --- a/tests/dataset_attributions/test_harvester.py +++ /dev/null @@ -1,265 +0,0 @@ -"""Tests for dataset attribution harvester logic.""" - -from pathlib import Path - -import torch - -from spd.dataset_attributions.storage import DatasetAttributionStorage - - -def _make_storage( - n_components: int = 2, - vocab_size: int = 3, - d_model: int = 4, - source_to_component: torch.Tensor | None = None, - source_to_out_residual: torch.Tensor | None = None, -) -> DatasetAttributionStorage: - """Helper to create storage with default values.""" - n_sources = vocab_size + n_components - if source_to_component is None: - source_to_component = torch.zeros(n_sources, n_components) - if source_to_out_residual is 
None: - source_to_out_residual = torch.zeros(n_sources, d_model) - - return DatasetAttributionStorage( - component_layer_keys=[f"layer1:{i}" for i in range(n_components)], - _REMOVE_ME_vocab_size=vocab_size, - d_model=d_model, - source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - n_batches_processed=10, - n_tokens_processed=1000, - ci_threshold=0.0, - ) - - -class TestDatasetAttributionStorage: - """Tests for DatasetAttributionStorage. - - Storage structure: - - source_to_component: (n_sources, n_components) for component target attributions - - source_to_out_residual: (n_sources, d_model) for output target attributions (via w_unembed) - """ - - def test_has_source_and_target(self) -> None: - """Test has_source and has_target methods.""" - storage = _make_storage(n_components=2, vocab_size=3) - - # wte tokens can only be sources - assert storage.has_source("wte:0") - assert storage.has_source("wte:2") - assert not storage.has_source("wte:3") # Out of vocab - assert not storage.has_target("wte:0") # wte can't be target - - # Component layers can be both sources and targets - assert storage.has_source("layer1:0") - assert storage.has_source("layer1:1") - assert storage.has_target("layer1:0") - assert storage.has_target("layer1:1") - assert not storage.has_source("layer1:2") - assert not storage.has_target("layer1:2") - - # output tokens can only be targets - assert storage.has_target("output:0") - assert storage.has_target("output:2") - assert not storage.has_target("output:3") # Out of vocab - assert not storage.has_source("output:0") # output can't be source - - def test_get_attribution_component_target(self) -> None: - """Test get_attribution for component targets (no w_unembed needed).""" - # 2 component layers: layer1:0, layer1:1 - # vocab_size=2, d_model=4 - # n_sources = 2 + 2 = 4 - # source_to_component shape: (4, 2) - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 -> components - [3.0, 4.0], # wte:1 -> 
components - [5.0, 6.0], # layer1:0 -> components - [7.0, 8.0], # layer1:1 -> components - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - # wte:0 -> layer1:0 - assert storage.get_attribution("wte:0", "layer1:0") == 1.0 - # wte:1 -> layer1:1 - assert storage.get_attribution("wte:1", "layer1:1") == 4.0 - # layer1:0 -> layer1:1 - assert storage.get_attribution("layer1:0", "layer1:1") == 6.0 - - def test_get_attribution_output_target(self) -> None: - """Test get_attribution for output targets (requires w_unembed).""" - # source_to_out_residual shape: (4, 4) for n_sources=4, d_model=4 - source_to_out_residual = torch.tensor( - [ - [1.0, 0.0, 0.0, 0.0], # wte:0 -> out_residual - [0.0, 1.0, 0.0, 0.0], # wte:1 -> out_residual - [0.0, 0.0, 1.0, 0.0], # layer1:0 -> out_residual - [0.0, 0.0, 0.0, 1.0], # layer1:1 -> out_residual - ] - ) - # w_unembed shape: (d_model=4, vocab=2) - w_unembed = torch.tensor( - [ - [1.0, 2.0], # d0 -> outputs - [3.0, 4.0], # d1 -> outputs - [5.0, 6.0], # d2 -> outputs - [7.0, 8.0], # d3 -> outputs - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, d_model=4, source_to_out_residual=source_to_out_residual - ) - - # wte:0 -> output:0 = out_residual[0] @ w_unembed[:, 0] = [1,0,0,0] @ [1,3,5,7] = 1.0 - assert storage.get_attribution("wte:0", "output:0", w_unembed=w_unembed) == 1.0 - # wte:1 -> output:1 = [0,1,0,0] @ [2,4,6,8] = 4.0 - assert storage.get_attribution("wte:1", "output:1", w_unembed=w_unembed) == 4.0 - # layer1:0 -> output:0 = [0,0,1,0] @ [1,3,5,7] = 5.0 - assert storage.get_attribution("layer1:0", "output:0", w_unembed=w_unembed) == 5.0 - - def test_get_top_sources_component_target(self) -> None: - """Test get_top_sources for component targets.""" - source_to_component = torch.tensor( - [ - [1.0, 2.0], # wte:0 - [5.0, 3.0], # wte:1 - [2.0, 4.0], # layer1:0 - [3.0, 1.0], # layer1:1 - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, 
source_to_component=source_to_component - ) - - # Top sources TO layer1:0 (column 0): wte:0=1.0, wte:1=5.0, layer1:0=2.0, layer1:1=3.0 - sources = storage.get_top_sources("layer1:0", k=2, sign="positive") - assert len(sources) == 2 - assert sources[0].component_key == "wte:1" - assert sources[0].value == 5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == 3.0 - - def test_get_top_sources_negative(self) -> None: - """Test get_top_sources with negative sign.""" - source_to_component = torch.tensor( - [ - [-1.0, 2.0], - [-5.0, 3.0], - [-2.0, 4.0], - [-3.0, 1.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - sources = storage.get_top_sources("layer1:0", k=2, sign="negative") - assert len(sources) == 2 - # wte:1 has most negative (-5.0), then layer1:1 (-3.0) - assert sources[0].component_key == "wte:1" - assert sources[0].value == -5.0 - assert sources[1].component_key == "layer1:1" - assert sources[1].value == -3.0 - - def test_get_top_component_targets(self) -> None: - """Test get_top_component_targets (no w_unembed needed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - storage = _make_storage( - n_components=2, vocab_size=2, source_to_component=source_to_component - ) - - targets = storage.get_top_component_targets("layer1:0", k=2, sign="positive") - assert len(targets) == 2 - assert targets[0].component_key == "layer1:1" - assert targets[0].value == 4.0 - assert targets[1].component_key == "layer1:0" - assert targets[1].value == 2.0 - - def test_get_top_targets_with_outputs(self) -> None: - """Test get_top_targets including outputs (requires w_unembed).""" - source_to_component = torch.tensor( - [ - [0.0, 0.0], - [0.0, 0.0], - [2.0, 4.0], # layer1:0 -> components - [0.0, 0.0], - ] - ) - # Make out_residual attribution that produces high output values - source_to_out_residual = 
torch.tensor( - [ - [0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0], # layer1:0 -> out_residual (sum=4 per output) - [0.0, 0.0, 0.0, 0.0], - ] - ) - # w_unembed that gives output:0=10, output:1=5 - w_unembed = torch.tensor( - [ - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - [2.5, 1.25], - ] - ) - storage = _make_storage( - n_components=2, - vocab_size=2, - d_model=4, - source_to_component=source_to_component, - source_to_out_residual=source_to_out_residual, - ) - - targets = storage.get_top_targets("layer1:0", k=3, sign="positive", w_unembed=w_unembed) - assert len(targets) == 3 - # output:0 = 10.0, output:1 = 5.0, layer1:1 = 4.0 - assert targets[0].component_key == "output:0" - assert targets[0].value == 10.0 - assert targets[1].component_key == "output:1" - assert targets[1].value == 5.0 - assert targets[2].component_key == "layer1:1" - assert targets[2].value == 4.0 - - def test_save_and_load(self, tmp_path: Path) -> None: - """Test save and load roundtrip.""" - n_components = 2 - vocab_size = 3 - d_model = 4 - n_sources = vocab_size + n_components - - original = DatasetAttributionStorage( - component_layer_keys=["layer:0", "layer:1"], - _REMOVE_ME_vocab_size=vocab_size, - d_model=d_model, - source_to_component=torch.randn(n_sources, n_components), - source_to_out_residual=torch.randn(n_sources, d_model), - n_batches_processed=100, - n_tokens_processed=10000, - ci_threshold=0.01, - ) - - path = tmp_path / "test_attributions.pt" - original.save(path) - - loaded = DatasetAttributionStorage.load(path) - - assert loaded.component_layer_keys == original.component_layer_keys - assert loaded._REMOVE_ME_vocab_size == original._REMOVE_ME_vocab_size - assert loaded.d_model == original.d_model - assert loaded.n_batches_processed == original.n_batches_processed - assert loaded.n_tokens_processed == original.n_tokens_processed - assert loaded.ci_threshold == original.ci_threshold - assert torch.allclose(loaded.source_to_component, 
original.source_to_component) - assert torch.allclose(loaded.source_to_out_residual, original.source_to_out_residual) From 9d9a4a31ddadc662ce9d16b556c63c5dd0840f76 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 17:32:53 +0000 Subject: [PATCH 03/17] Fix alive_targets iteration: use torch.where for indices, not bool tolist alive_targets is a bool tensor; .tolist() gives [True, False, ...] not indices. torch.autograd.grad needs a scalar output, so index with actual int indices. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/harvester.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 947937f0c..c0997f095 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -247,7 +247,7 @@ def _process_component_targets( source_layers = self.sources_by_target[target_layer] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] - for t_idx in alive_targets.tolist(): + for t_idx in torch.where(alive_targets)[0].tolist(): grads = torch.autograd.grad( ci_weighted_target_acts[t_idx], source_acts, retain_graph=True ) From 3c0ba4b56c38ef68ddd6530bcdd9b454af324916 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 17:51:14 +0000 Subject: [PATCH 04/17] Fix KeyError for embed source: CI dict doesn't include embedding layer Embed tokens have no CI (always active), so skip CI weighting for embed sources. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/harvester.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index c0997f095..27c6d5d43 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -284,7 +284,9 @@ def _accumulate_attributions( attr_abs_acc = attr_abr_accumulator[source_layer][target_idx] square_attr_acc = square_attr_accumulator[source_layer][target_idx] - ci_weighted_attr = grad * act * ci[source_layer] + # Embed has no CI (all tokens always active) + source_ci = ci[source_layer] if source_layer != self.embed_path else 1.0 + ci_weighted_attr = grad * act * source_ci ci_weighted_attr_abs = torch.where(act > 0, ci_weighted_attr, -ci_weighted_attr) ci_weighted_squared_attr = ci_weighted_attr.square() From d736f425d4e950ef0f47faf54a8a147223c25b40 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 18:05:54 +0000 Subject: [PATCH 05/17] Fix scatter_add OOB: use embedding num_embeddings instead of tokenizer vocab_size tokenizer.vocab_size (50254) < len(tokenizer) (50277) due to added tokens. Token IDs >= vocab_size cause scatter_add_ index out of bounds in the embed accumulator. Use embedding_module.num_embeddings which matches the actual token ID space. Also add Path type annotation to test tmp_path params. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/harvest.py | 7 +- tests/dataset_attributions/test_storage.py | 201 +++++++++++++++++++++ 2 files changed, 204 insertions(+), 4 deletions(-) create mode 100644 tests/dataset_attributions/test_storage.py diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 6c55a83aa..d55fc52c9 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -108,16 +108,15 @@ def harvest_attributions( model.eval() spd_config = run_info.config - train_loader, tokenizer = train_loader_and_tokenizer(spd_config, config.batch_size) - vocab_size = tokenizer.vocab_size - assert isinstance(vocab_size, int), f"vocab_size must be int, got {type(vocab_size)}" - logger.info(f"Vocab size: {vocab_size}") + train_loader, _ = train_loader_and_tokenizer(spd_config, config.batch_size) # Get gradient connectivity logger.info("Computing sources_by_target...") topology = TransformerTopology(model.target_model) embed_path = topology.path_schema.embedding_path unembed_path = topology.path_schema.unembed_path + vocab_size = topology.embedding_module.num_embeddings + logger.info(f"Vocab size: {vocab_size}") sources_by_target_raw = get_sources_by_target(model, topology, str(device), spd_config.sampling) # Filter to valid source/target pairs: diff --git a/tests/dataset_attributions/test_storage.py b/tests/dataset_attributions/test_storage.py new file mode 100644 index 000000000..4f394f509 --- /dev/null +++ b/tests/dataset_attributions/test_storage.py @@ -0,0 +1,201 @@ +"""Tests for DatasetAttributionStorage.""" + +from pathlib import Path + +import torch +from torch import Tensor + +from spd.dataset_attributions.storage import DatasetAttributionStorage + +VOCAB_SIZE = 4 +D_MODEL = 4 +LAYER_0 = "0.glu.up" +LAYER_1 = "1.glu.up" +C0 = 3 # components in layer 0 +C1 = 2 # components in layer 1 + + +def _make_attr_dict(seed: int = 0) -> dict[str, dict[str, Tensor]]: + 
"""Build attr dict for the test topology. + + Sources by target: + "0.glu.up": ["embed"] -> shape (C0, VOCAB_SIZE) + "1.glu.up": ["embed", "0.glu.up"] -> shape (C1, VOCAB_SIZE), (C1, C0) + "output": ["0.glu.up", "1.glu.up"] -> shape (D_MODEL, C0), (D_MODEL, C1) + """ + g = torch.Generator().manual_seed(seed) + + def rand(*shape: int) -> Tensor: + return torch.randn(*shape, generator=g) + + return { + LAYER_0: {"embed": rand(C0, VOCAB_SIZE)}, + LAYER_1: {"embed": rand(C1, VOCAB_SIZE), LAYER_0: rand(C1, C0)}, + "output": {LAYER_0: rand(D_MODEL, C0), LAYER_1: rand(D_MODEL, C1)}, + } + + +def _make_storage( + seed: int = 0, n_batches: int = 10, n_tokens: int = 640 +) -> DatasetAttributionStorage: + return DatasetAttributionStorage( + attr=_make_attr_dict(seed), + attr_abs=_make_attr_dict(seed + 100), + mean_squared_attr=_make_attr_dict(seed + 200), + vocab_size=VOCAB_SIZE, + ci_threshold=1e-6, + n_batches_processed=n_batches, + n_tokens_processed=n_tokens, + ) + + +class TestNComponents: + def test_counts_non_output_targets(self): + storage = _make_storage() + assert storage.n_components == C0 + C1 + + +class TestHasSource: + def test_embed_token(self): + storage = _make_storage() + assert storage.has_source("embed:0") + assert storage.has_source(f"embed:{VOCAB_SIZE - 1}") + + def test_embed_oob(self): + storage = _make_storage() + assert not storage.has_source(f"embed:{VOCAB_SIZE}") + assert not storage.has_source("embed:-1") + + def test_component_source(self): + storage = _make_storage() + assert storage.has_source(f"{LAYER_0}:0") + assert storage.has_source(f"{LAYER_0}:{C0 - 1}") + + def test_component_source_oob(self): + storage = _make_storage() + assert not storage.has_source(f"{LAYER_0}:{C0}") + + def test_output_never_source(self): + storage = _make_storage() + assert not storage.has_source("output:0") + + def test_layer_not_present(self): + storage = _make_storage() + assert not storage.has_source("nonexistent:0") + + +class TestHasTarget: + def 
test_component_target(self): + storage = _make_storage() + assert storage.has_target(f"{LAYER_0}:0") + assert storage.has_target(f"{LAYER_1}:{C1 - 1}") + + def test_component_target_oob(self): + storage = _make_storage() + assert not storage.has_target(f"{LAYER_0}:{C0}") + assert not storage.has_target(f"{LAYER_1}:{C1}") + + def test_output_target(self): + storage = _make_storage() + assert storage.has_target("output:0") + assert storage.has_target(f"output:{VOCAB_SIZE - 1}") + + def test_output_target_oob(self): + storage = _make_storage() + assert not storage.has_target(f"output:{VOCAB_SIZE}") + + def test_embed_never_target(self): + storage = _make_storage() + assert not storage.has_target("embed:0") + + def test_layer_not_present(self): + storage = _make_storage() + assert not storage.has_target("nonexistent:0") + + +class TestSaveLoad: + def test_roundtrip(self, tmp_path: Path): + original = _make_storage() + path = tmp_path / "attrs.pt" + original.save(path) + + loaded = DatasetAttributionStorage.load(path) + + assert loaded.vocab_size == original.vocab_size + assert loaded.ci_threshold == original.ci_threshold + assert loaded.n_batches_processed == original.n_batches_processed + assert loaded.n_tokens_processed == original.n_tokens_processed + assert loaded.n_components == original.n_components + + for attr_name in ("attr", "attr_abs", "mean_squared_attr"): + orig_dict = getattr(original, attr_name) + load_dict = getattr(loaded, attr_name) + assert orig_dict.keys() == load_dict.keys() + for target in orig_dict: + assert orig_dict[target].keys() == load_dict[target].keys() + for source in orig_dict[target]: + torch.testing.assert_close(load_dict[target][source], orig_dict[target][source]) + + +class TestMerge: + def test_two_workers_weighted_average(self, tmp_path: Path): + s1 = _make_storage(seed=0, n_batches=5, n_tokens=320) + s2 = _make_storage(seed=42, n_batches=5, n_tokens=320) + + p1 = tmp_path / "rank_0.pt" + p2 = tmp_path / "rank_1.pt" + s1.save(p1) + 
s2.save(p2) + + merged = DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_batches_processed == 10 + assert merged.n_tokens_processed == 640 + assert merged.vocab_size == VOCAB_SIZE + assert merged.ci_threshold == s1.ci_threshold + + n1, n2 = s1.n_tokens_processed, s2.n_tokens_processed + total = n1 + n2 + for target in s1.attr: + for source in s1.attr[target]: + expected = (s1.attr[target][source] * n1 + s2.attr[target][source] * n2) / total + torch.testing.assert_close( + merged.attr[target][source], expected, atol=1e-5, rtol=1e-5 + ) + + def test_unequal_token_counts(self, tmp_path: Path): + s1 = _make_storage(seed=0, n_batches=3, n_tokens=192) + s2 = _make_storage(seed=42, n_batches=7, n_tokens=448) + + p1 = tmp_path / "rank_0.pt" + p2 = tmp_path / "rank_1.pt" + s1.save(p1) + s2.save(p2) + + merged = DatasetAttributionStorage.merge([p1, p2]) + + assert merged.n_tokens_processed == 640 + assert merged.n_batches_processed == 10 + + n1, n2 = s1.n_tokens_processed, s2.n_tokens_processed + total = n1 + n2 + for target in s1.attr: + for source in s1.attr[target]: + expected = (s1.attr[target][source] * n1 + s2.attr[target][source] * n2) / total + torch.testing.assert_close( + merged.attr[target][source], expected, atol=1e-5, rtol=1e-5 + ) + + def test_single_file(self, tmp_path: Path): + original = _make_storage(seed=7, n_batches=10, n_tokens=640) + path = tmp_path / "rank_0.pt" + original.save(path) + + merged = DatasetAttributionStorage.merge([path]) + + assert merged.n_tokens_processed == original.n_tokens_processed + for target in original.attr: + for source in original.attr[target]: + torch.testing.assert_close( + merged.attr[target][source], original.attr[target][source] + ) From 82719058bfec40862a2f5c3270ec72bb9ee91eba Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 18:19:17 +0000 Subject: [PATCH 06/17] Split run.py into run_worker.py and run_merge.py Merge doesn't need config_json, worker does. 
Separate entrypoints avoid the issue where Fire requires config_json for both paths. Cherry-picked from feature/faster-dataset-attributions. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/scripts/run_merge.py | 37 ++++++++++++++ spd/dataset_attributions/scripts/run_slurm.py | 6 +-- .../scripts/{run.py => run_worker.py} | 49 ++++++------------- 3 files changed, 55 insertions(+), 37 deletions(-) create mode 100644 spd/dataset_attributions/scripts/run_merge.py rename spd/dataset_attributions/scripts/{run.py => run_worker.py} (60%) diff --git a/spd/dataset_attributions/scripts/run_merge.py b/spd/dataset_attributions/scripts/run_merge.py new file mode 100644 index 000000000..913ea5374 --- /dev/null +++ b/spd/dataset_attributions/scripts/run_merge.py @@ -0,0 +1,37 @@ +"""Merge script for dataset attribution rank files. + +Combines per-rank attribution files into a single merged result. + +Usage: + python -m spd.dataset_attributions.scripts.run_merge --wandb_path <wandb_path> --subrun_id da-xxx +""" + +from spd.dataset_attributions.harvest import merge_attributions +from spd.dataset_attributions.repo import get_attributions_subrun_dir +from spd.log import logger +from spd.utils.wandb_utils import parse_wandb_run_path + + +def main( + *, + wandb_path: str, + subrun_id: str, +) -> None: + _, _, run_id = parse_wandb_run_path(wandb_path) + output_dir = get_attributions_subrun_dir(run_id, subrun_id) + logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") + merge_attributions(output_dir) + + +def get_command(wandb_path: str, subrun_id: str) -> str: + return ( + f"python -m spd.dataset_attributions.scripts.run_merge " + f'--wandb_path "{wandb_path}" ' + f"--subrun_id {subrun_id}" + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(main) diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 3fdba505e..8420db56e 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py 
+++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -14,7 +14,7 @@ from datetime import datetime from spd.dataset_attributions.config import AttributionsSlurmConfig -from spd.dataset_attributions.scripts import run as attribution_run +from spd.dataset_attributions.scripts import run_merge, run_worker from spd.log import logger from spd.utils.git_utils import create_git_snapshot from spd.utils.slurm import ( @@ -85,7 +85,7 @@ def submit_attributions( # SLURM arrays are 1-indexed, so task ID 1 -> rank 0, etc. worker_commands = [] for rank in range(n_gpus): - cmd = attribution_run.get_worker_command( + cmd = run_worker.get_command( wandb_path, config_json, rank=rank, @@ -115,7 +115,7 @@ def submit_attributions( ) # Submit merge job with dependency on array completion - merge_cmd = attribution_run.get_merge_command(wandb_path, subrun_id) + merge_cmd = run_merge.get_command(wandb_path, subrun_id) merge_config = SlurmConfig( job_name="spd-attr-merge", partition=partition, diff --git a/spd/dataset_attributions/scripts/run.py b/spd/dataset_attributions/scripts/run_worker.py similarity index 60% rename from spd/dataset_attributions/scripts/run.py rename to spd/dataset_attributions/scripts/run_worker.py index 5d060767e..995a9f22a 100644 --- a/spd/dataset_attributions/scripts/run.py +++ b/spd/dataset_attributions/scripts/run_worker.py @@ -4,19 +4,20 @@ Usage: # Single GPU - python -m spd.dataset_attributions.scripts.run --config_json '...' + python -m spd.dataset_attributions.scripts.run_worker + + # Single GPU with config + python -m spd.dataset_attributions.scripts.run_worker --config_json '{"n_batches": 500}' # Multi-GPU (run in parallel) - python -m spd.dataset_attributions.scripts.run --config_json '...' --rank 0 --world_size 4 --subrun_id da-20260211_120000 - ... 
- python -m spd.dataset_attributions.scripts.run --merge --subrun_id da-20260211_120000 + python -m spd.dataset_attributions.scripts.run_worker --rank 0 --world_size 4 --subrun_id da-xxx """ from datetime import datetime from typing import Any from spd.dataset_attributions.config import DatasetAttributionConfig -from spd.dataset_attributions.harvest import harvest_attributions, merge_attributions +from spd.dataset_attributions.harvest import harvest_attributions from spd.dataset_attributions.repo import get_attributions_subrun_dir from spd.log import logger from spd.utils.wandb_utils import parse_wandb_run_path @@ -24,31 +25,24 @@ def main( wandb_path: str, - config_json: dict[str, Any], + config_json: dict[str, Any] | None = None, rank: int | None = None, world_size: int | None = None, - merge: bool = False, subrun_id: str | None = None, harvest_subrun_id: str | None = None, ) -> None: - assert isinstance(config_json, dict), f"Expected dict from fire, got {type(config_json)}" _, _, run_id = parse_wandb_run_path(wandb_path) if subrun_id is None: subrun_id = "da-" + datetime.now().strftime("%Y%m%d_%H%M%S") + config = ( + DatasetAttributionConfig.model_validate(config_json) + if config_json + else DatasetAttributionConfig() + ) output_dir = get_attributions_subrun_dir(run_id, subrun_id) - if merge: - assert rank is None and world_size is None, "Cannot specify rank/world_size with --merge" - logger.info(f"Merging attribution results for {wandb_path} (subrun {subrun_id})") - merge_attributions(output_dir) - return - - assert (rank is None) == (world_size is None), "rank and world_size must both be set or unset" - - config = DatasetAttributionConfig.model_validate(config_json) - if world_size is not None: logger.info( f"Distributed harvest: {wandb_path} (rank {rank}/{world_size}, subrun {subrun_id})" @@ -66,7 +60,7 @@ def main( ) -def get_worker_command( +def get_command( wandb_path: str, config_json: str, rank: int, @@ -75,7 +69,7 @@ def get_worker_command( 
harvest_subrun_id: str | None = None, ) -> str: cmd = ( - f"python -m spd.dataset_attributions.scripts.run " + f"python -m spd.dataset_attributions.scripts.run_worker " f'"{wandb_path}" ' f"--config_json '{config_json}' " f"--rank {rank} " @@ -87,20 +81,7 @@ def get_worker_command( return cmd -def get_merge_command(wandb_path: str, subrun_id: str) -> str: - return ( - f"python -m spd.dataset_attributions.scripts.run " - f'"{wandb_path}" ' - "--merge " - f"--subrun_id {subrun_id}" - ) - - -def cli() -> None: +if __name__ == "__main__": import fire fire.Fire(main) - - -if __name__ == "__main__": - cli() From 7650495a19f194edefb794582347a50b1fe91b62 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 18:53:03 +0000 Subject: [PATCH 07/17] Correct attr_abs via backprop through |target|, reorganise method signatures attr_abs now computed by backpropping through target_acts.abs() instead of flipping by source activation sign. Requires 2 backward passes per target component but is mathematically correct for cross-position (attention) interactions. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/harvester.py | 53 ++++++++++++++++++++------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 27c6d5d43..82141c879 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -200,24 +200,39 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No for target_layer in self.sources_by_target: if target_layer == self.unembed_path: - self._process_output_targets(cache, ci.lower_leaky, tokens) + self._process_output_targets( + cache, + tokens, + ci.lower_leaky, + ) else: - self._process_component_targets(target_layer, ci.lower_leaky, cache, tokens) + self._process_component_targets( + cache, + tokens, + ci.lower_leaky, + target_layer, + ) def _process_output_targets( self, cache: dict[str, Tensor], - ci: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], ) -> None: """Process output attributions via output-residual-space storage.""" - out_residual = cache[f"{self.unembed_path}_pre_detach"].sum(dim=(0, 1)) + out_residual = cache[f"{self.unembed_path}_pre_detach"] + + out_residual_sum = out_residual.sum(dim=(0, 1)) + out_residual_sum_abs = out_residual.abs().sum(dim=(0, 1)) source_layers = self.sources_by_target[self.unembed_path] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] for d_idx in range(self.output_d_model): - grads = torch.autograd.grad(out_residual[d_idx], source_acts, retain_graph=True) + grads = torch.autograd.grad(out_residual_sum[d_idx], source_acts, retain_graph=True) + abs_grads = torch.autograd.grad( + out_residual_sum_abs[d_idx], source_acts, retain_graph=True + ) self._accumulate_attributions( self.unembed_path, @@ -225,16 +240,17 @@ def _process_output_targets( source_layers, source_acts, list(grads), + list(abs_grads), ci, tokens, ) def _process_component_targets( self, - 
target_layer: str, - ci: dict[str, Tensor], cache: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], + ci: dict[str, Tensor], + target_layer: str, ) -> None: """Process attributions to a component layer.""" alive_targets = self.component_alive[target_layer] @@ -242,7 +258,10 @@ def _process_component_targets( return target_acts_raw = cache[f"{target_layer}_pre_detach"] - ci_weighted_target_acts = (target_acts_raw * ci[target_layer]).sum(dim=(0, 1)) + + target_ci_detached = ci[target_layer].detach() + ci_weighted_target_acts = (target_acts_raw * target_ci_detached).sum(dim=(0, 1)) + ci_weighted_target_acts_abs = (target_acts_raw.abs() * target_ci_detached).sum(dim=(0, 1)) source_layers = self.sources_by_target[target_layer] source_acts = [cache[f"{s}_post_detach"] for s in source_layers] @@ -252,12 +271,17 @@ def _process_component_targets( ci_weighted_target_acts[t_idx], source_acts, retain_graph=True ) + abs_grads = torch.autograd.grad( + ci_weighted_target_acts_abs[t_idx], source_acts, retain_graph=True + ) + self._accumulate_attributions( target_layer, t_idx, source_layers, source_acts, list(grads), + list(abs_grads), ci, tokens, ) @@ -270,29 +294,32 @@ def _accumulate_attributions( source_layers: list[str], source_acts: list[Tensor], source_grads: list[Tensor], + source_abs_grads: list[Tensor], ci: dict[str, Tensor], tokens: Int[Tensor, "batch seq"], ) -> None: """Accumulate grad*act attributions from sources to a target column.""" attr_accumulator = self.attr_accumulator[target_layer] - attr_abr_accumulator = self.attr_abs_accumulator[target_layer] + attr_abs_accumulator = self.attr_abs_accumulator[target_layer] square_attr_accumulator = self.square_attr_accumulator[target_layer] - for source_layer, act, grad in zip(source_layers, source_acts, source_grads, strict=True): + for source_layer, act, grad, abs_grad in zip( + source_layers, source_acts, source_grads, source_abs_grads, strict=True + ): attr_acc = attr_accumulator[source_layer][target_idx] - 
attr_abs_acc = attr_abr_accumulator[source_layer][target_idx] + attr_abs_acc = attr_abs_accumulator[source_layer][target_idx] square_attr_acc = square_attr_accumulator[source_layer][target_idx] # Embed has no CI (all tokens always active) source_ci = ci[source_layer] if source_layer != self.embed_path else 1.0 + ci_weighted_attr = grad * act * source_ci - ci_weighted_attr_abs = torch.where(act > 0, ci_weighted_attr, -ci_weighted_attr) + ci_weighted_attr_abs = abs_grad * act * source_ci ci_weighted_squared_attr = ci_weighted_attr.square() if source_layer == self.embed_path: # Per-token: sum grad*act*ci over d_model, scatter by token id - # TODO(oli): figure out why this works attr = ci_weighted_attr.sum(dim=-1).flatten() attr_abs = ci_weighted_attr_abs.sum(dim=-1).flatten() attr_squared = ci_weighted_squared_attr.sum(dim=-1).flatten() From ccf713ff72479d7fabbd3635230d36a1e50c983d Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 19:03:03 +0000 Subject: [PATCH 08/17] Add merge_mem config (default 200G) to prevent merge OOM MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 3 metrics × dict-of-dicts makes rank files ~15GB each. Merge loads all in double precision, needs much more than the default 10GB. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/config.py | 1 + spd/dataset_attributions/scripts/run_slurm.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/spd/dataset_attributions/config.py b/spd/dataset_attributions/config.py index a1de165fb..3d84fcbd8 100644 --- a/spd/dataset_attributions/config.py +++ b/spd/dataset_attributions/config.py @@ -26,3 +26,4 @@ class AttributionsSlurmConfig(BaseConfig): partition: str = DEFAULT_PARTITION_NAME time: str = "48:00:00" merge_time: str = "01:00:00" + merge_mem: str = "200G" diff --git a/spd/dataset_attributions/scripts/run_slurm.py b/spd/dataset_attributions/scripts/run_slurm.py index 8420db56e..6adc4bd52 100644 --- a/spd/dataset_attributions/scripts/run_slurm.py +++ b/spd/dataset_attributions/scripts/run_slurm.py @@ -119,8 +119,9 @@ def submit_attributions( merge_config = SlurmConfig( job_name="spd-attr-merge", partition=partition, - n_gpus=0, # No GPU needed for merge + n_gpus=0, time=config.merge_time, + mem=config.merge_mem, snapshot_branch=snapshot_branch, dependency_job_id=array_result.job_id, comment=wandb_url, From 8139fb1ff05fdebcc3c34d47a412bce39fb35e79 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 19:18:02 +0000 Subject: [PATCH 09/17] Add 3-metric selection to dataset attributions in app Backend: implement storage query methods with AttrMetric parameter, bulk endpoint returns all 3 metrics (attr, attr_abs, mean_squared_attr), other endpoints accept optional ?metric= query param. Frontend: 3-way radio toggle (Signed / Abs Target / RMS) in DatasetAttributionsSection. All metrics fetched at once, selection is local state that switches which ComponentAttributions to display. 
Co-Authored-By: Claude Opus 4.6 (1M context) --- .../backend/routers/dataset_attributions.py | 130 ++++++++++-------- .../ui/DatasetAttributionsSection.svelte | 105 +++++++++++--- .../src/lib/api/datasetAttributions.ts | 16 ++- .../src/lib/useComponentData.svelte.ts | 6 +- .../useComponentDataExpectCached.svelte.ts | 6 +- spd/dataset_attributions/storage.py | 100 +++++++++++++- 6 files changed, 275 insertions(+), 88 deletions(-) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 3c5bd87ff..19df369b8 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -13,13 +13,13 @@ from spd.app.backend.dependencies import DepLoadedRun from spd.app.backend.utils import log_errors +from spd.dataset_attributions.storage import AttrMetric, DatasetAttributionStorage from spd.dataset_attributions.storage import DatasetAttributionEntry as StorageEntry -from spd.dataset_attributions.storage import DatasetAttributionStorage +ATTR_METRICS: list[AttrMetric] = ["attr", "attr_abs", "mean_squared_attr"] -class DatasetAttributionEntry(BaseModel): - """A single entry in attribution results.""" +class DatasetAttributionEntry(BaseModel): component_key: str layer: str component_idx: int @@ -27,27 +27,26 @@ class DatasetAttributionEntry(BaseModel): class DatasetAttributionMetadata(BaseModel): - """Metadata about dataset attributions availability.""" - available: bool n_batches_processed: int | None n_tokens_processed: int | None n_component_layer_keys: int | None - # TODO(oli): remove these from frontend - # vocab_size: int | None - # d_model: int | None ci_threshold: float | None class ComponentAttributions(BaseModel): - """All attribution data for a single component (sources and targets, positive and negative).""" - positive_sources: list[DatasetAttributionEntry] negative_sources: list[DatasetAttributionEntry] positive_targets: list[DatasetAttributionEntry] 
negative_targets: list[DatasetAttributionEntry] +class AllMetricAttributions(BaseModel): + attr: ComponentAttributions + attr_abs: ComponentAttributions + mean_squared_attr: ComponentAttributions + + router = APIRouter(prefix="/api/dataset_attributions", tags=["dataset_attributions"]) NOT_AVAILABLE_MSG = ( @@ -56,19 +55,16 @@ class ComponentAttributions(BaseModel): def _storage_key(canonical_layer: str, component_idx: int) -> str: - """Format a canonical layer + idx as a storage key.""" return f"{canonical_layer}:{component_idx}" def _require_storage(loaded: DepLoadedRun) -> DatasetAttributionStorage: - """Get storage or raise 404.""" if loaded.attributions is None: raise HTTPException(status_code=404, detail=NOT_AVAILABLE_MSG) return loaded.attributions.get_attributions() def _require_source(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a source or raise 404.""" if not storage.has_source(component_key): raise HTTPException( status_code=404, @@ -77,7 +73,6 @@ def _require_source(storage: DatasetAttributionStorage, component_key: str) -> N def _require_target(storage: DatasetAttributionStorage, component_key: str) -> None: - """Validate component exists as a target or raise 404.""" if not storage.has_target(component_key): raise HTTPException( status_code=404, @@ -86,12 +81,10 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: - """Get the unembedding matrix from the loaded model.""" return loaded.topology.get_unembed_weight() def _to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry]: - """Convert storage entries to API response format.""" return [ DatasetAttributionEntry( component_key=e.component_key, @@ -103,10 +96,56 @@ def _to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry ] +def _get_component_attributions_for_metric( + storage: 
DatasetAttributionStorage, + component_key: str, + k: int, + metric: AttrMetric, + is_source: bool, + is_target: bool, + w_unembed: Float[Tensor, "d_model vocab"] | None, +) -> ComponentAttributions: + return ComponentAttributions( + positive_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "positive", metric) + ) + if is_target + else [], + negative_sources=_to_api_entries( + storage.get_top_sources(component_key, k, "negative", metric) + ) + if is_target + else [], + positive_targets=_to_api_entries( + storage.get_top_targets( + component_key, + k, + "positive", + metric, + w_unembed=w_unembed, + include_outputs=w_unembed is not None, + ), + ) + if is_source + else [], + negative_targets=_to_api_entries( + storage.get_top_targets( + component_key, + k, + "negative", + metric, + w_unembed=w_unembed, + include_outputs=w_unembed is not None, + ), + ) + if is_source + else [], + ) + + @router.get("/metadata") @log_errors def get_attribution_metadata(loaded: DepLoadedRun) -> DatasetAttributionMetadata: - """Get metadata about dataset attributions availability.""" if loaded.attributions is None: return DatasetAttributionMetadata( available=False, @@ -132,12 +171,11 @@ def get_component_attributions( component_idx: int, loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, -) -> ComponentAttributions: - """Get all attribution data for a component (sources and targets, positive and negative).""" +) -> AllMetricAttributions: + """Get all attribution data for a component across all 3 metrics.""" storage = _require_storage(loaded) component_key = _storage_key(layer, component_idx) - # Component can be both a source and a target, so we need to check both is_source = storage.has_source(component_key) is_target = storage.has_target(component_key) @@ -149,35 +187,13 @@ def get_component_attributions( w_unembed = _get_w_unembed(loaded) if is_source else None - return ComponentAttributions( - 
positive_sources=_to_api_entries(storage.get_top_sources(component_key, k, "positive")) - if is_target - else [], - negative_sources=_to_api_entries(storage.get_top_sources(component_key, k, "negative")) - if is_target - else [], - positive_targets=_to_api_entries( - storage.get_top_targets( - component_key, - k, - "positive", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], - negative_targets=_to_api_entries( - storage.get_top_targets( - component_key, - k, - "negative", - w_unembed=w_unembed, - include_outputs=w_unembed is not None, - ), - ) - if is_source - else [], + return AllMetricAttributions( + **{ + metric: _get_component_attributions_for_metric( + storage, component_key, k, metric, is_source, is_target, w_unembed + ) + for metric in ATTR_METRICS + } ) @@ -189,15 +205,17 @@ def get_attribution_sources( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k source components that attribute TO this target over the dataset.""" storage = _require_storage(loaded) target_key = _storage_key(layer, component_idx) _require_target(storage, target_key) w_unembed = _get_w_unembed(loaded) if layer == "output" else None - return _to_api_entries(storage.get_top_sources(target_key, k, sign, w_unembed=w_unembed)) + return _to_api_entries( + storage.get_top_sources(target_key, k, sign, metric, w_unembed=w_unembed) + ) @router.get("/{layer}/{component_idx}/targets") @@ -208,15 +226,17 @@ def get_attribution_targets( loaded: DepLoadedRun, k: Annotated[int, Query(ge=1)] = 10, sign: Literal["positive", "negative"] = "positive", + metric: AttrMetric = "attr", ) -> list[DatasetAttributionEntry]: - """Get top-k target components this source attributes TO over the dataset.""" storage = _require_storage(loaded) source_key = _storage_key(layer, component_idx) _require_source(storage, source_key) 
w_unembed = _get_w_unembed(loaded) - return _to_api_entries(storage.get_top_targets(source_key, k, sign, w_unembed=w_unembed)) + return _to_api_entries( + storage.get_top_targets(source_key, k, sign, metric, w_unembed=w_unembed) + ) @router.get("/between/{source_layer}/{source_idx}/{target_layer}/{target_idx}") @@ -227,8 +247,8 @@ def get_attribution_between( target_layer: str, target_idx: int, loaded: DepLoadedRun, + metric: AttrMetric = "attr", ) -> float: - """Get attribution strength from source component to target component.""" storage = _require_storage(loaded) source_key = _storage_key(source_layer, source_idx) target_key = _storage_key(target_layer, target_idx) @@ -237,4 +257,4 @@ def get_attribution_between( w_unembed = _get_w_unembed(loaded) if target_layer == "output" else None - return storage.get_attribution(source_key, target_key, w_unembed=w_unembed) + return storage.get_attribution(source_key, target_key, metric, w_unembed=w_unembed) diff --git a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte index cd86c9af1..3ec394eba 100644 --- a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte +++ b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte @@ -9,14 +9,22 @@ import { COMPONENT_CARD_CONSTANTS } from "../../lib/componentCardConstants"; import type { EdgeAttribution } from "../../lib/promptAttributionsTypes"; import type { DatasetAttributions } from "../../lib/useComponentData.svelte"; + import type { AttrMetric } from "../../lib/api/datasetAttributions"; import EdgeAttributionGrid from "./EdgeAttributionGrid.svelte"; + const METRIC_LABELS: Record = { + attr: "Signed", + attr_abs: "Abs Target", + mean_squared_attr: "RMS", + }; + type Props = { attributions: DatasetAttributions; onComponentClick?: (componentKey: string) => void; }; let { attributions, onComponentClick }: Props = $props(); + let selectedMetric = $state("attr"); 
function handleClick(key: string) { if (onComponentClick) { @@ -24,6 +32,8 @@ } } + const active = $derived(attributions[selectedMetric]); + function toEdgeAttribution( entries: { component_key: string; value: number }[], maxAbsValue: number, @@ -36,26 +46,87 @@ } const maxSourceVal = $derived( - Math.max(attributions.positive_sources[0]?.value ?? 0, Math.abs(attributions.negative_sources[0]?.value ?? 0)), + Math.max( + active.positive_sources[0]?.value ?? 0, + Math.abs(active.negative_sources[0]?.value ?? 0), + ), ); const maxTargetVal = $derived( - Math.max(attributions.positive_targets[0]?.value ?? 0, Math.abs(attributions.negative_targets[0]?.value ?? 0)), + Math.max( + active.positive_targets[0]?.value ?? 0, + Math.abs(active.negative_targets[0]?.value ?? 0), + ), ); - const positiveSources = $derived(toEdgeAttribution(attributions.positive_sources, maxSourceVal)); - const negativeSources = $derived(toEdgeAttribution(attributions.negative_sources, maxSourceVal)); - const positiveTargets = $derived(toEdgeAttribution(attributions.positive_targets, maxTargetVal)); - const negativeTargets = $derived(toEdgeAttribution(attributions.negative_targets, maxTargetVal)); + const positiveSources = $derived(toEdgeAttribution(active.positive_sources, maxSourceVal)); + const negativeSources = $derived(toEdgeAttribution(active.negative_sources, maxSourceVal)); + const positiveTargets = $derived(toEdgeAttribution(active.positive_targets, maxTargetVal)); + const negativeTargets = $derived(toEdgeAttribution(active.negative_targets, maxTargetVal)); - +
+
+ {#each Object.entries(METRIC_LABELS) as [metric, label] (metric)} + + {/each} +
+ + +
+ + diff --git a/spd/app/frontend/src/lib/api/datasetAttributions.ts b/spd/app/frontend/src/lib/api/datasetAttributions.ts index f995a33f6..9916439fd 100644 --- a/spd/app/frontend/src/lib/api/datasetAttributions.ts +++ b/spd/app/frontend/src/lib/api/datasetAttributions.ts @@ -18,20 +18,30 @@ export type ComponentAttributions = { negative_targets: DatasetAttributionEntry[]; }; +export type AttrMetric = "attr" | "attr_abs" | "mean_squared_attr"; + +export type AllMetricAttributions = { + attr: ComponentAttributions; + attr_abs: ComponentAttributions; + mean_squared_attr: ComponentAttributions; +}; + export type DatasetAttributionsMetadata = { available: boolean; }; export async function getDatasetAttributionsMetadata(): Promise { - return fetchJson(apiUrl("/api/dataset_attributions/metadata").toString()); + return fetchJson( + apiUrl("/api/dataset_attributions/metadata").toString(), + ); } export async function getComponentAttributions( layer: string, componentIdx: number, k: number = 10, -): Promise { +): Promise { const url = apiUrl(`/api/dataset_attributions/${encodeURIComponent(layer)}/${componentIdx}`); url.searchParams.set("k", String(k)); - return fetchJson(url.toString()); + return fetchJson(url.toString()); } diff --git a/spd/app/frontend/src/lib/useComponentData.svelte.ts b/spd/app/frontend/src/lib/useComponentData.svelte.ts index b9041f83e..d2af449ef 100644 --- a/spd/app/frontend/src/lib/useComponentData.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentData.svelte.ts @@ -8,7 +8,7 @@ import { getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; +import type { AllMetricAttributions, InterpretationDetail } from "./api"; import type { SubcomponentCorrelationsResponse, SubcomponentActivationContexts, @@ -23,7 +23,7 @@ const TOKEN_STATS_TOP_K = 200; /** Dataset attributions top-k */ const DATASET_ATTRIBUTIONS_TOP_K = 20; -export type { ComponentAttributions as 
DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; @@ -43,7 +43,7 @@ export function useComponentData() { // null inside Loadable means "no data for this component" (404) let correlations = $state>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = $state>({ status: "uninitialized" }); diff --git a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts index f32dab70a..29110c327 100644 --- a/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts +++ b/spd/app/frontend/src/lib/useComponentDataExpectCached.svelte.ts @@ -17,7 +17,7 @@ import { getInterpretationDetail, requestComponentInterpretation, } from "./api"; -import type { ComponentAttributions, InterpretationDetail } from "./api"; +import type { AllMetricAttributions, InterpretationDetail } from "./api"; import type { SubcomponentCorrelationsResponse, SubcomponentActivationContexts, @@ -29,7 +29,7 @@ const DATASET_ATTRIBUTIONS_TOP_K = 20; /** Fetch more activation examples in background after initial cached load */ const ACTIVATION_EXAMPLES_FULL_LIMIT = 200; -export type { ComponentAttributions as DatasetAttributions }; +export type { AllMetricAttributions as DatasetAttributions }; export type ComponentCoords = { layer: string; cIdx: number }; @@ -39,7 +39,7 @@ export function useComponentDataExpectCached() { let componentDetail = $state>({ status: "uninitialized" }); let correlations = $state>({ status: "uninitialized" }); let tokenStats = $state>({ status: "uninitialized" }); - let datasetAttributions = $state>({ status: "uninitialized" }); + let datasetAttributions = $state>({ status: "uninitialized" }); let interpretationDetail = 
$state>({ status: "uninitialized" }); let currentCoords = $state(null); diff --git a/spd/dataset_attributions/storage.py b/spd/dataset_attributions/storage.py index ca1c8e59d..22dae2e6e 100644 --- a/spd/dataset_attributions/storage.py +++ b/spd/dataset_attributions/storage.py @@ -22,6 +22,7 @@ from spd.log import logger AttrDict = dict[str, dict[str, Tensor]] +AttrMetric = Literal["attr", "attr_abs", "mean_squared_attr"] @dataclass @@ -100,15 +101,100 @@ def has_target(self, key: str) -> bool: first_source = next(iter(self.attr[layer].values())) return 0 <= idx < first_source.shape[0] - # TODO: these methods need a metric parameter to select which of the 3 attr dicts to query - def get_attribution(self, *_args: object, **_kwargs: object) -> float: - raise ValueError("TODO: get_attribution needs metric selection") + def _get_attr_dict(self, metric: AttrMetric) -> AttrDict: + match metric: + case "attr": + return self.attr + case "attr_abs": + return self.attr_abs + case "mean_squared_attr": + return self.mean_squared_attr - def get_top_sources(self, *_args: object, **_kwargs: object) -> list[DatasetAttributionEntry]: - raise ValueError("TODO: get_top_sources needs metric selection") + def get_attribution( + self, + source_key: str, + target_key: str, + metric: AttrMetric, + w_unembed: Tensor | None = None, + ) -> float: + source_layer, source_idx = self._parse_key(source_key) + target_layer, target_idx = self._parse_key(target_key) + assert source_layer != "output", f"output tokens cannot be sources: {source_key}" + + attrs = self._get_attr_dict(metric) + attr_matrix = attrs[target_layer][source_layer] + + if target_layer == "output": + assert w_unembed is not None, "w_unembed required for output target queries" + w_unembed = w_unembed.to(attr_matrix.device) + return (attr_matrix[:, source_idx] @ w_unembed[:, target_idx]).item() + + return attr_matrix[target_idx, source_idx].item() + + def get_top_sources( + self, + target_key: str, + k: int, + sign: 
Literal["positive", "negative"], + metric: AttrMetric, + w_unembed: Tensor | None = None, + ) -> list[DatasetAttributionEntry]: + target_layer, target_idx = self._parse_key(target_key) + attrs = self._get_attr_dict(metric) + + if target_layer == "output": + assert w_unembed is not None, "w_unembed required for output target queries" + + value_segments: list[Tensor] = [] + layer_names: list[str] = [] + + for source_layer, attr_matrix in attrs[target_layer].items(): + if target_layer == "output": + assert w_unembed is not None + w = w_unembed.to(attr_matrix.device) + values = w[:, target_idx] @ attr_matrix + else: + values = attr_matrix[target_idx, :] + + value_segments.append(values) + layer_names.append(source_layer) + + return self._top_k_from_segments(value_segments, layer_names, k, sign) - def get_top_targets(self, *_args: object, **_kwargs: object) -> list[DatasetAttributionEntry]: - raise ValueError("TODO: get_top_targets needs metric selection") + def get_top_targets( + self, + source_key: str, + k: int, + sign: Literal["positive", "negative"], + metric: AttrMetric, + w_unembed: Tensor | None = None, + include_outputs: bool = True, + ) -> list[DatasetAttributionEntry]: + source_layer, source_idx = self._parse_key(source_key) + attrs = self._get_attr_dict(metric) + + value_segments: list[Tensor] = [] + layer_names: list[str] = [] + + for target_layer, sources in attrs.items(): + if source_layer not in sources: + continue + + attr_matrix = sources[source_layer] + + if target_layer == "output": + if not include_outputs: + continue + assert w_unembed is not None, "w_unembed required when include_outputs=True" + w = w_unembed.to(attr_matrix.device) + values = attr_matrix[:, source_idx] @ w + else: + values = attr_matrix[:, source_idx] + + value_segments.append(values) + layer_names.append(target_layer) + + return self._top_k_from_segments(value_segments, layer_names, k, sign) def _top_k_from_segments( self, From 1bf98772272b6ca784972854871c87f121e8bb02 Mon Sep 17 
00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 19:22:52 +0000 Subject: [PATCH 10/17] Allow bare s-prefixed run IDs everywhere (e.g. "s-17805b61") parse_wandb_run_path now accepts "s-xxxxxxxx" and expands to goodfire/spd. Handled in backend so it works for CLI, app, and any other consumer. Frontend placeholder updated. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/app/frontend/src/components/RunSelector.svelte | 2 +- spd/utils/wandb_utils.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/spd/app/frontend/src/components/RunSelector.svelte b/spd/app/frontend/src/components/RunSelector.svelte index aa4728bd8..f174ee635 100644 --- a/spd/app/frontend/src/components/RunSelector.svelte +++ b/spd/app/frontend/src/components/RunSelector.svelte @@ -87,7 +87,7 @@
diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 6b1d85813..7dd49a730 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -31,7 +31,11 @@ # Regex patterns for parsing W&B run references # Run IDs can be 8 chars (e.g., "d2ec3bfe") or prefixed with char-dash (e.g., "s-d2ec3bfe") +DEFAULT_WANDB_ENTITY = "goodfire" +DEFAULT_WANDB_PROJECT = "spd" + _RUN_ID_PATTERN = r"(?:[a-z0-9]-)?[a-z0-9]{8}" +_BARE_RUN_ID_RE = re.compile(r"^(s-[a-z0-9]{8})$") _WANDB_PATH_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/({_RUN_ID_PATTERN})$") _WANDB_PATH_WITH_RUNS_RE = re.compile(rf"^([^/\s]+)/([^/\s]+)/runs/({_RUN_ID_PATTERN})$") _WANDB_URL_RE = re.compile( @@ -169,6 +173,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: """Parse various W&B run reference formats into (entity, project, run_id). Accepts: + - "s-xxxxxxxx" (bare SPD run ID, assumes goodfire/spd) - "entity/project/runId" (compact form) - "entity/project/runs/runId" (with /runs/) - "wandb:entity/project/runId" (with wandb: prefix) @@ -187,6 +192,10 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: if s.startswith("wandb:"): s = s[6:] + # Bare run ID (e.g. "s-17805b61") → default entity/project + if m := _BARE_RUN_ID_RE.match(s): + return DEFAULT_WANDB_ENTITY, DEFAULT_WANDB_PROJECT, m.group(1) + # Try compact form: entity/project/runid if m := _WANDB_PATH_RE.match(s): return m.group(1), m.group(2), m.group(3) @@ -201,6 +210,7 @@ def parse_wandb_run_path(input_path: str) -> tuple[str, str, str]: raise ValueError( f"Invalid W&B run reference. 
Expected one of:\n" + f' - "s-xxxxxxxx" (bare run ID)\n' f' - "entity/project/xxxxxxxx"\n' f' - "entity/project/runs/xxxxxxxx"\n' f' - "wandb:entity/project/runs/xxxxxxxx"\n' From 627df2b7bd4a22ce86592e3ee56ea78efa06f3e9 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 19:31:15 +0000 Subject: [PATCH 11/17] Fix AttributionRepo.open skipping valid subruns due to old-format dirs Old subruns (da-timing-*, da-overnight-*) sort after da-YYYYMMDD_* and have no dataset_attributions.pt. The old code only checked the last candidate and returned None. Now iterates in reverse until finding one with the file. Co-Authored-By: Claude Opus 4.6 (1M context) --- spd/dataset_attributions/repo.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/spd/dataset_attributions/repo.py b/spd/dataset_attributions/repo.py index 697036ba3..1175d584e 100644 --- a/spd/dataset_attributions/repo.py +++ b/spd/dataset_attributions/repo.py @@ -42,14 +42,13 @@ def open(cls, run_id: str) -> "AttributionRepo | None": candidates = sorted( [d for d in base_dir.iterdir() if d.is_dir() and d.name.startswith("da-")], key=lambda d: d.name, + reverse=True, ) - if not candidates: - return None - subrun_dir = candidates[-1] - path = subrun_dir / "dataset_attributions.pt" - if not path.exists(): - return None - return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + for subrun_dir in candidates: + path = subrun_dir / "dataset_attributions.pt" + if path.exists(): + return cls(DatasetAttributionStorage.load(path), subrun_id=subrun_dir.name) + return None def get_attributions(self) -> DatasetAttributionStorage: return self._storage From 73a7ba01ef039c3e420771664fe884f2f7bbb47e Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 20:17:01 +0000 Subject: [PATCH 12/17] Fix 3s lag on attribution metric toggle: O(V) linear scan per pill getTokenText did .find() over the full 50K vocab array for every embed/output pill on 
each render. Build a Map once via $derived, making lookups O(1). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../components/ui/EdgeAttributionList.svelte | 25 +++++++++++++------ 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte index aa98848b5..5f06da323 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionList.svelte @@ -7,6 +7,20 @@ const runState = getContext(RUN_KEY); + function assert(condition: boolean, msg: string): asserts condition { + if (!condition) throw new Error(msg); + } + + const vocabIdToText = $derived.by(() => { + const map = new Map(); + if (runState.allTokens.status === "loaded") { + for (const t of runState.allTokens.data) { + map.set(t.id, t.string); + } + } + return map; + }); + type Props = { items: EdgeAttribution[]; onClick: (key: string) => void; @@ -81,13 +95,10 @@ if (layer === "embed" || layer === "output") { const vocabIdx = parseInt(cIdx); - // Tokens are guaranteed loaded when run is loaded (see useRun.svelte.ts) - if (runState.allTokens.status !== "loaded") { - throw new Error(`allTokens not loaded (status: ${runState.allTokens.status})`); - } - const tokenInfo = runState.allTokens.data.find((t) => t.id === vocabIdx); - if (!tokenInfo) throw new Error(`Token not found for vocab index ${vocabIdx}`); - return tokenInfo.string; + assert(runState.allTokens.status === "loaded", `allTokens not loaded`); + const text = vocabIdToText.get(vocabIdx); + assert(text !== undefined, `Token not found for vocab index ${vocabIdx}`); + return text; } } From 5798178fc6ab5f8fbe7161a3a8f8d997bb88a8ec Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 20:34:45 +0000 Subject: [PATCH 13/17] Ship token strings from backend instead of resolving vocab IDs in frontend MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit Backend resolves embed/output token strings via tokenizer.decode() and includes token_str in DatasetAttributionEntry. Frontend uses it directly instead of scanning a 50K vocab array per pill. Removes tokens/outputProbs passthrough from EdgeAttributionGrid/List — token strings now flow through EdgeAttribution.tokenStr from the source. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../backend/routers/dataset_attributions.py | 21 +++-- .../prompt-attr/ComponentNodeCard.svelte | 13 ++- .../ui/DatasetAttributionsSection.svelte | 3 +- .../components/ui/EdgeAttributionGrid.svelte | 15 +--- .../components/ui/EdgeAttributionList.svelte | 79 ++----------------- .../src/lib/api/datasetAttributions.ts | 1 + .../src/lib/promptAttributionsTypes.ts | 1 + 7 files changed, 38 insertions(+), 95 deletions(-) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 19df369b8..bf8ee501a 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -24,6 +24,7 @@ class DatasetAttributionEntry(BaseModel): layer: str component_idx: int value: float + token_str: str | None = None class DatasetAttributionMetadata(BaseModel): @@ -84,13 +85,18 @@ def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: return loaded.topology.get_unembed_weight() -def _to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry]: +def _to_api_entries( + entries: list[StorageEntry], loaded: DepLoadedRun +) -> list[DatasetAttributionEntry]: return [ DatasetAttributionEntry( component_key=e.component_key, layer=e.layer, component_idx=e.component_idx, value=e.value, + token_str=loaded.tokenizer.decode([e.component_idx]) + if e.layer in ("embed", "output") + else None, ) for e in entries ] @@ -98,6 +104,7 @@ def _to_api_entries(entries: list[StorageEntry]) -> list[DatasetAttributionEntry def 
_get_component_attributions_for_metric( storage: DatasetAttributionStorage, + loaded: DepLoadedRun, component_key: str, k: int, metric: AttrMetric, @@ -107,12 +114,12 @@ def _get_component_attributions_for_metric( ) -> ComponentAttributions: return ComponentAttributions( positive_sources=_to_api_entries( - storage.get_top_sources(component_key, k, "positive", metric) + storage.get_top_sources(component_key, k, "positive", metric), loaded ) if is_target else [], negative_sources=_to_api_entries( - storage.get_top_sources(component_key, k, "negative", metric) + storage.get_top_sources(component_key, k, "negative", metric), loaded ) if is_target else [], @@ -125,6 +132,7 @@ def _get_component_attributions_for_metric( w_unembed=w_unembed, include_outputs=w_unembed is not None, ), + loaded, ) if is_source else [], @@ -137,6 +145,7 @@ def _get_component_attributions_for_metric( w_unembed=w_unembed, include_outputs=w_unembed is not None, ), + loaded, ) if is_source else [], @@ -190,7 +199,7 @@ def get_component_attributions( return AllMetricAttributions( **{ metric: _get_component_attributions_for_metric( - storage, component_key, k, metric, is_source, is_target, w_unembed + storage, loaded, component_key, k, metric, is_source, is_target, w_unembed ) for metric in ATTR_METRICS } @@ -214,7 +223,7 @@ def get_attribution_sources( w_unembed = _get_w_unembed(loaded) if layer == "output" else None return _to_api_entries( - storage.get_top_sources(target_key, k, sign, metric, w_unembed=w_unembed) + storage.get_top_sources(target_key, k, sign, metric, w_unembed=w_unembed), loaded ) @@ -235,7 +244,7 @@ def get_attribution_targets( w_unembed = _get_w_unembed(loaded) return _to_api_entries( - storage.get_top_targets(source_key, k, sign, metric, w_unembed=w_unembed) + storage.get_top_targets(source_key, k, sign, metric, w_unembed=w_unembed), loaded ) diff --git a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte 
b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte index a0d663208..640135c76 100644 --- a/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte +++ b/spd/app/frontend/src/components/prompt-attr/ComponentNodeCard.svelte @@ -130,6 +130,16 @@ const currentNodeKey = $derived(`${layer}:${seqIdx}:${cIdx}`); const N_EDGES_TO_DISPLAY = 20; + function resolveTokenStr(nodeKey: string): string | null { + const parts = nodeKey.split(":"); + if (parts.length !== 3) return null; + const [layer, seqStr, cIdx] = parts; + const seqIdx = parseInt(seqStr); + if (layer === "embed") return tokens[seqIdx] ?? null; + if (layer === "output") return outputProbs[`${seqIdx}:${cIdx}`]?.token ?? null; + return null; + } + function getTopEdgeAttributions( edges: EdgeData[], isPositive: boolean, @@ -144,6 +154,7 @@ key: getKey(e), value: e.val, normalizedMagnitude: Math.abs(e.val) / maxAbsVal, + tokenStr: resolveTokenStr(getKey(e)), })); } @@ -249,8 +260,6 @@ {outgoingNegative} pageSize={COMPONENT_CARD_CONSTANTS.PROMPT_ATTRIBUTIONS_PAGE_SIZE} onClick={handleEdgeNodeClick} - {tokens} - {outputProbs} /> {/if} diff --git a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte index 3ec394eba..d47668bc3 100644 --- a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte +++ b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte @@ -35,13 +35,14 @@ const active = $derived(attributions[selectedMetric]); function toEdgeAttribution( - entries: { component_key: string; value: number }[], + entries: { component_key: string; value: number; token_str: string | null }[], maxAbsValue: number, ): EdgeAttribution[] { return entries.map((e) => ({ key: e.component_key, value: e.value, normalizedMagnitude: Math.abs(e.value) / (maxAbsValue || 1), + tokenStr: e.token_str, })); } diff --git a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte 
b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte index 844cb7c04..c90bfc33e 100644 --- a/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte +++ b/spd/app/frontend/src/components/ui/EdgeAttributionGrid.svelte @@ -1,5 +1,5 @@
From 3b8bf4ec2e0490222f45eaa94754d800bf2c5a3b Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Mon, 23 Feb 2026 20:55:28 +0000 Subject: [PATCH 15/17] Narrow frontend types: SignedAttributions vs UnsignedAttributions mean_squared_attr only has positive_sources/positive_targets. Hardcoded three paths in DatasetAttributionsSection matching each type. Slow benchmark result: per-element loops >14x slower than scatter_add_ (>60 min vs 4.3 min per batch on s-17805b61, timed out at 1hr). Co-Authored-By: Claude Opus 4.6 (1M context) --- .../ui/DatasetAttributionsSection.svelte | 117 ++++++++++-------- .../src/lib/api/datasetAttributions.ts | 13 +- 2 files changed, 75 insertions(+), 55 deletions(-) diff --git a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte index 434ecfe62..1d8799d63 100644 --- a/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte +++ b/spd/app/frontend/src/components/ui/DatasetAttributionsSection.svelte @@ -9,15 +9,9 @@ import { COMPONENT_CARD_CONSTANTS } from "../../lib/componentCardConstants"; import type { EdgeAttribution } from "../../lib/promptAttributionsTypes"; import type { DatasetAttributions } from "../../lib/useComponentData.svelte"; - import type { AttrMetric } from "../../lib/api/datasetAttributions"; + import type { AttrMetric, DatasetAttributionEntry } from "../../lib/api/datasetAttributions"; import EdgeAttributionGrid from "./EdgeAttributionGrid.svelte"; - const METRIC_LABELS: Record = { - attr: "Signed", - attr_abs: "Abs Target", - mean_squared_attr: "RMS", - }; - type Props = { attributions: DatasetAttributions; onComponentClick?: (componentKey: string) => void; @@ -32,10 +26,8 @@ } } - const active = $derived(attributions[selectedMetric]); - function toEdgeAttribution( - entries: { component_key: string; value: number; token_str: string | null }[], + entries: DatasetAttributionEntry[], maxAbsValue: number, ): 
EdgeAttribution[] { return entries.map((e) => ({ @@ -46,53 +38,76 @@ })); } - const maxSourceVal = $derived( - Math.max( - active.positive_sources[0]?.value ?? 0, - Math.abs(active.negative_sources[0]?.value ?? 0), - ), - ); - const maxTargetVal = $derived( - Math.max( - active.positive_targets[0]?.value ?? 0, - Math.abs(active.negative_targets[0]?.value ?? 0), - ), - ); - - const hasSigned = $derived(selectedMetric === "attr"); - - const positiveSources = $derived(toEdgeAttribution(active.positive_sources, maxSourceVal)); - const negativeSources = $derived(hasSigned ? toEdgeAttribution(active.negative_sources, maxSourceVal) : []); - const positiveTargets = $derived(toEdgeAttribution(active.positive_targets, maxTargetVal)); - const negativeTargets = $derived(hasSigned ? toEdgeAttribution(active.negative_targets, maxTargetVal) : []); + function maxAbs(...vals: number[]): number { + return Math.max(...vals.map(Math.abs)); + } + + // attr: signed + const attrMaxSource = $derived(maxAbs(attributions.attr.positive_sources[0]?.value ?? 0, attributions.attr.negative_sources[0]?.value ?? 0)); + const attrMaxTarget = $derived(maxAbs(attributions.attr.positive_targets[0]?.value ?? 0, attributions.attr.negative_targets[0]?.value ?? 0)); + + // attr_abs: signed + const absMaxSource = $derived(maxAbs(attributions.attr_abs.positive_sources[0]?.value ?? 0, attributions.attr_abs.negative_sources[0]?.value ?? 0)); + const absMaxTarget = $derived(maxAbs(attributions.attr_abs.positive_targets[0]?.value ?? 0, attributions.attr_abs.negative_targets[0]?.value ?? 0)); + + // mean_squared_attr: unsigned (positive only) + const rmsMaxSource = $derived(attributions.mean_squared_attr.positive_sources[0]?.value ?? 0); + const rmsMaxTarget = $derived(attributions.mean_squared_attr.positive_targets[0]?.value ?? 0);
- {#each Object.entries(METRIC_LABELS) as [metric, label] (metric)} - - {/each} + + +
- + {#if selectedMetric === "attr"} + + {:else if selectedMetric === "attr_abs"} + + {:else} + + {/if}