From c0c3a173ce9469101a3bfbb3fbb29eade328b42a Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Thu, 29 Jan 2026 06:16:30 +0000 Subject: [PATCH 01/16] wip: Introduce generic types for batch and output in ComponentModel - Add `BatchT` and `OutputT` type parameters to ComponentModel and related functions - Replace `pretrained_model_output_attr` with `run_batch` and `reconstruction_loss` callables - Remove `extract_batch_data` and `AliveComponentsTracker` utilities - --- .claude/.nfs2f6abdf93653d08500002cba | 22 ++ spd/app/backend/compute.py | 20 +- spd/app/backend/optim_cis.py | 12 +- spd/app/backend/routers/prompts.py | 3 +- spd/app/backend/state.py | 2 +- spd/clustering/activations.py | 6 +- spd/data.py | 8 +- spd/dataset_attributions/harvest.py | 8 +- spd/dataset_attributions/harvester.py | 4 +- spd/eval.py | 79 +++--- spd/experiments/ih/ih_decomposition.py | 8 +- spd/experiments/lm/lm_decomposition.py | 3 + .../resid_mlp/resid_mlp_decomposition.py | 3 + spd/experiments/tms/tms_decomposition.py | 3 + spd/harvest/harvest.py | 7 +- spd/losses.py | 26 +- spd/metrics/alive_components.py | 76 ------ spd/metrics/base.py | 8 +- spd/metrics/ce_and_kl_losses.py | 4 +- spd/metrics/ci_histograms.py | 4 +- spd/metrics/ci_l0.py | 4 +- spd/metrics/ci_masked_recon_layerwise_loss.py | 63 ++--- spd/metrics/ci_masked_recon_loss.py | 38 ++- spd/metrics/ci_masked_recon_subset_loss.py | 40 ++- spd/metrics/ci_mean_per_component.py | 4 +- spd/metrics/component_activation_density.py | 6 +- spd/metrics/faithfulness_loss.py | 4 +- spd/metrics/identity_ci_error.py | 4 +- spd/metrics/importance_minimality_loss.py | 4 +- spd/metrics/permuted_ci_plots.py | 4 +- .../pgd_masked_recon_layerwise_loss.py | 33 +-- spd/metrics/pgd_masked_recon_loss.py | 25 +- spd/metrics/pgd_masked_recon_subset_loss.py | 30 +- spd/metrics/pgd_utils.py | 100 +++---- .../stochastic_hidden_acts_recon_loss.py | 18 +- .../stochastic_recon_layerwise_loss.py | 61 ++--- spd/metrics/stochastic_recon_loss.py | 83 +++--- .../stochastic_recon_subset_ce_and_kl.py | 4 +- spd/metrics/stochastic_recon_subset_loss.py | 42 ++- spd/metrics/unmasked_recon_loss.py | 40 ++- spd/metrics/uv_plots.py | 4 +- spd/models/component_model.py | 185 +++++++++---- spd/plotting.py | 5 +- spd/run_spd.py | 70 ++--- spd/scripts/compare_models/compare_models.py | 9 +- spd/utils/general_utils.py | 61 +---- spd/utils/logging_utils.py | 2 +- tests/app/test_server_api.py | 5 +- tests/metrics/fixtures.py | 18 +- tests/metrics/test_alive_components.py | 174 ------------ .../test_alive_components_distributed.py | 257 ------------------ .../test_ci_masked_recon_layerwise_loss.py | 8 +- tests/metrics/test_ci_masked_recon_loss.py | 8 +- .../test_ci_masked_recon_subset_loss.py | 1 - tests/metrics/test_faithfulness_loss.py | 4 +- .../test_stochastic_recon_layerwise_loss.py | 3 - tests/metrics/test_stochastic_recon_loss.py | 1 - .../test_stochastic_recon_subset_loss.py | 1 - tests/test_component_model.py | 26 +- tests/test_gpt2.py | 3 + tests/test_ih_transformer.py | 3 + tests/test_resid_mlp.py | 3 + tests/test_spd_losses.py | 40 +-- tests/test_tms.py | 3 + tests/test_wandb_run_loading.py | 6 +- 65 files changed, 618 insertions(+), 1195 deletions(-) create mode 100644 .claude/.nfs2f6abdf93653d08500002cba delete mode 100644 spd/metrics/alive_components.py delete mode 100644 tests/metrics/test_alive_components.py delete mode 100644 tests/metrics/test_alive_components_distributed.py diff --git a/.claude/.nfs2f6abdf93653d08500002cba b/.claude/.nfs2f6abdf93653d08500002cba new file mode 100644 index 000000000..dbd210f9b --- /dev/null +++ b/.claude/.nfs2f6abdf93653d08500002cba @@ -0,0 +1,22 @@ +{ + "permissions": { + "allow": [ + "Bash(source:*)", + "Bash(npm run check:*)", + "Bash(make check-app:*)", + "Bash(npm run lint:*)", + "Bash(git stash push:*)", + "Bash(grep:*)", + "Bash(npm run format:*)", + "Bash(npm run build:*)", + "Bash(npx eslint:*)", + "Bash(npx prettier:*)", + "Bash(git add:*)", + "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nAdd \"Use as Prompt\" popup for selected text in dataset explorer\n\n- Select text within story content to show floating popup\n- \"Use as Prompt\" button creates a custom prompt from selection\n- Text is cleaned: newlines → spaces, whitespace collapsed, trimmed\n- Shows hint when no run is loaded\n- Only triggers on .story-text elements \\(not headers/tags\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", + "Bash(git revert:*)", + "Bash(python:*)", + "Bash(make:*)", + "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nOptimize random sampling and hide zero occurrence badges\n\n- Use random indices instead of shuffling entire dataset \\(~100x faster\\)\n- Hide occurrence badge when count is 0 \\(for random samples\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")" + ] + } +} diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 6d05f56d1..bdea2163b 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -126,7 +126,7 @@ def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: def get_sources_by_target( - model: ComponentModel, + model: ComponentModel[Any, Any], device: str, sampling: SamplingType, ) -> dict[str, list[str]]: @@ -141,7 +141,7 @@ def get_sources_by_target( batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) with torch.no_grad(): - output_with_cache: OutputWithCache = model(batch, cache_type="input") + output_with_cache: OutputWithCache[Any] = model(batch, cache_type="input") with torch.no_grad(): ci = model.calc_causal_importances( @@ -170,7 +170,7 @@ def wte_hook( wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) with torch.enable_grad(): - comp_output_with_cache: OutputWithCache = model( + comp_output_with_cache: OutputWithCache[Any] = model( batch, mask_infos=mask_infos, cache_type="component_acts", @@ -305,7 +305,7 @@ def _compute_edges_for_target( def compute_edges_from_ci( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], pre_weight_acts: dict[str, Float[Tensor, "1 seq d_in"]], @@ -354,7 +354,7 @@ def compute_edges_from_ci( weight_deltas_and_masks=weight_deltas_and_masks, ) with torch.enable_grad(): - comp_output_with_cache: OutputWithCache = model( + comp_output_with_cache: OutputWithCache[Any] = model( tokens, mask_infos=unmasked_masks, cache_type="component_acts" ) @@ -490,7 +490,7 @@ def filter_ci_to_included_nodes( def compute_prompt_attributions( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], output_prob_threshold: float, @@ -540,7 +540,7 @@ def compute_prompt_attributions( def compute_prompt_attributions_optimized( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, @@ -624,7 +624,7 @@ class CIOnlyResult: def compute_ci_only( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Float[Tensor, "1 seq"], sampling: SamplingType, ) -> CIOnlyResult: @@ -642,7 +642,7 @@ def compute_ci_only( CIOnlyResult containing CI values per layer, target model output probabilities, pre-weight activations, and component activations. """ with torch.no_grad(): - output_with_cache: OutputWithCache = model(tokens, cache_type="input") + output_with_cache: OutputWithCache[Any] = model(tokens, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=sampling, @@ -788,7 +788,7 @@ class InterventionResult: def compute_intervention_forward( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] top_k: int, diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 8135b3ade..3bacf0aef 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -4,7 +4,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Literal +from typing import Any, Literal import torch import torch.nn.functional as F @@ -72,7 +72,7 @@ class OptimizableCIParams: ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values alive_info: AliveComponentInfo - def create_ci_outputs(self, model: ComponentModel, device: str) -> CIOutputs: + def create_ci_outputs(self, model: ComponentModel[Any, Any], device: str) -> CIOutputs: """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" pre_sigmoid: dict[str, Tensor] = {} @@ -139,7 +139,7 @@ def create_optimizable_ci_params( def compute_label_prob( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Tensor, ci_lower_leaky: dict[str, Tensor], label_token: int, @@ -165,7 +165,7 @@ def compute_l0_stats( def compute_final_token_ce_kl( - model: ComponentModel, + model: ComponentModel[Any, Any], batch: Tensor, target_out: Tensor, ci: dict[str, Tensor], @@ -267,7 +267,7 @@ class OptimCIConfig: def optimize_ci_values( - model: ComponentModel, + model: ComponentModel[Any, Any], tokens: Tensor, config: OptimCIConfig, device: str, @@ -292,7 +292,7 @@ def optimize_ci_values( # Get initial CI values from the model with torch.no_grad(): - output_with_cache: OutputWithCache = model(tokens, cache_type="input") + output_with_cache: OutputWithCache[Any] = model(tokens, cache_type="input") initial_ci_outputs = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=config.sampling, diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 8002aa11c..d2099c6ca 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -16,7 +16,6 @@ from spd.data import DatasetConfig, create_data_loader from spd.log import logger from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data # ============================================================================= # Schemas @@ -120,7 +119,7 @@ def generate() -> Generator[str]: if added_count >= n_prompts: break - tokens = extract_batch_data(batch).to(DEVICE) + tokens = batch["input_ids"].to(DEVICE) batch_size, n_seq = tokens.shape # Compute CI for the whole batch diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 47dacfe51..1a86a06e0 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -110,7 +110,7 @@ class RunState: """Runtime state for a loaded run (model, tokenizer, etc.)""" run: Run - model: ComponentModel + model: ComponentModel[Any, Any] tokenizer: PreTrainedTokenizerBase sources_by_target: dict[str, list[str]] config: Config diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index cd6a2b742..2999efe21 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import cached_property -from typing import Literal, NamedTuple +from typing import Any, Literal, NamedTuple import torch from jaxtyping import Bool, Float, Float16, Int @@ -17,14 +17,14 @@ def component_activations( - model: ComponentModel, + model: ComponentModel[Any, Any], device: torch.device | str, batch: Int[Tensor, "batch_size n_ctx"], ) -> dict[str, ActivationsTensor]: """Get the component activations over a **single** batch.""" causal_importances: dict[str, ActivationsTensor] with torch.no_grad(): - model_output: OutputWithCache = model( + model_output: OutputWithCache[Any] = model( batch.to(device), cache_type="input", ) diff --git a/spd/data.py b/spd/data.py index 840cb0543..2519f9fda 100644 --- a/spd/data.py +++ b/spd/data.py @@ -1,8 +1,10 @@ +from collections.abc import Generator from typing import Any import numpy as np import torch from datasets import Dataset, IterableDataset, load_dataset +from jaxtyping import Int from numpy.typing import NDArray from torch import Tensor from torch.utils.data import DataLoader, DistributedSampler @@ -151,7 +153,7 @@ def create_data_loader( dist_state: DistributedState | None = None, global_seed: int = 0, to_lower: bool = True, -) -> tuple[DataLoader[Any], PreTrainedTokenizer]: +) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]: """Create a DataLoader for the given dataset. Uses PyTorch's DistributedSampler to ensure each rank gets the correct @@ -252,7 +254,7 @@ def create_data_loader( generator = torch.Generator(device="cpu") generator.manual_seed(seed) - loader = DataLoader[Dataset | IterableDataset]( + loader = DataLoader[Int[Tensor, "batch seq"]]( torch_dataset, # pyright: ignore[reportArgumentType] batch_size=batch_size, sampler=sampler, @@ -265,7 +267,7 @@ def create_data_loader( return loader, tokenizer -def loop_dataloader[T](dl: DataLoader[T]): +def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]: """Loop over a dataloader, resetting the iterator when it is exhausted. Ensures that each epoch gets different data, even when using a distributed sampler. diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 0651a3a7b..3318b71d7 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -15,6 +15,7 @@ import itertools from dataclasses import dataclass from pathlib import Path +from typing import Any import torch import tqdm @@ -30,7 +31,6 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data from spd.utils.wandb_utils import parse_wandb_run_path @@ -42,7 +42,7 @@ class DatasetAttributionConfig: ci_threshold: float -def _build_component_layer_keys(model: ComponentModel) -> list[str]: +def _build_component_layer_keys(model: ComponentModel[Any, Any]) -> list[str]: """Build list of component layer keys in canonical order. Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. @@ -57,7 +57,7 @@ def _build_component_layer_keys(model: ComponentModel) -> list[str]: def _build_alive_masks( - model: ComponentModel, + model: ComponentModel[Any, Any], run_id: str, ci_threshold: float, n_components: int, @@ -206,7 +206,7 @@ def harvest_attributions( # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue - batch = extract_batch_data(batch_data).to(device) + batch = batch_data["input_ids"].to(device) harvester.process_batch(batch) logger.info( diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 2f6ba6973..3c1d18a64 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -45,7 +45,7 @@ class AttributionHarvester: def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], sources_by_target: dict[str, list[str]], n_components: int, vocab_size: int, @@ -161,7 +161,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No # Forward pass with gradients with torch.enable_grad(): - comp_output: OutputWithCache = self.model( + comp_output: OutputWithCache[Any] = self.model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) diff --git a/spd/eval.py b/spd/eval.py index c6f0b47ff..9184eff1e 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -3,7 +3,6 @@ from collections.abc import Iterator from typing import Any -from jaxtyping import Float, Int from PIL import Image from torch import Tensor from torch.types import Number @@ -39,7 +38,6 @@ ) from spd.metrics import UnmaskedReconLoss from spd.metrics.base import Metric -from spd.metrics.ce_and_kl_losses import CEandKLLosses from spd.metrics.ci_histograms import CIHistograms from spd.metrics.ci_l0 import CI_L0 from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss @@ -58,13 +56,12 @@ from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss from spd.metrics.stochastic_recon_loss import StochasticReconLoss -from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss from spd.metrics.uv_plots import UVPlots from spd.models.component_model import ComponentModel, OutputWithCache from spd.routing import AllLayersRouter, get_subset_router from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed -from spd.utils.general_utils import dict_safe_update_, extract_batch_data +from spd.utils.general_utils import dict_safe_update_ MetricOutType = dict[str, str | Number | Image.Image | CustomChart] DistMetricOutType = dict[str, str | float | Image.Image | CustomChart] @@ -116,12 +113,12 @@ def avg_eval_metrics_across_ranks(metrics: MetricOutType, device: str) -> DistMe return {**metrics, **avg_metrics} -def init_metric( +def init_metric[BatchT, OutputT]( cfg: MetricConfigType, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], run_config: Config, device: str, -) -> Metric: +) -> Metric[BatchT, OutputT]: match cfg: case ImportanceMinimalityLossConfig(): metric = ImportanceMinimalityLoss( @@ -139,12 +136,13 @@ def init_metric( device=device, ) case CEandKLLossesConfig(): - metric = CEandKLLosses( - model=model, - device=device, - sampling=run_config.sampling, - rounding_threshold=cfg.rounding_threshold, - ) + raise ValueError("fix this typing!") + # metric = CEandKLLosses( + # model=model, + # device=device, + # sampling=run_config.sampling, + # rounding_threshold=cfg.rounding_threshold, + # ) case CIHistogramsConfig(): metric = CIHistograms(model=model, n_batches_accum=cfg.n_batches_accum) case CI_L0Config(): @@ -158,16 +156,17 @@ def init_metric( metric = CIMaskedReconSubsetLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, routing=cfg.routing, ) case CIMaskedReconLayerwiseLossConfig(): metric = CIMaskedReconLayerwiseLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, + device=device, ) case CIMaskedReconLossConfig(): metric = CIMaskedReconLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, + device=device, ) case CIMeanPerComponentConfig(): metric = CIMeanPerComponent(model=model, device=device) @@ -196,7 +195,6 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, ) case StochasticReconLossConfig(): metric = StochasticReconLoss( @@ -205,7 +203,6 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, ) case StochasticReconSubsetLossConfig(): metric = StochasticReconSubsetLoss( @@ -214,7 +211,6 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, routing=cfg.routing, ) case PGDReconLossConfig(): @@ -222,7 +218,6 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, ) case PGDReconSubsetLossConfig(): @@ -230,7 +225,6 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, routing=cfg.routing, ) @@ -239,19 +233,19 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, pgd_config=cfg, ) case StochasticReconSubsetCEAndKLConfig(): - metric = StochasticReconSubsetCEAndKL( - model=model, - device=device, - sampling=run_config.sampling, - use_delta_component=run_config.use_delta_component, - n_mask_samples=run_config.n_mask_samples, - include_patterns=cfg.include_patterns, - exclude_patterns=cfg.exclude_patterns, - ) + raise ValueError("fix this typing!") + # metric = StochasticReconSubsetCEAndKL( + # model=model, + # device=device, + # sampling=run_config.sampling, + # use_delta_component=run_config.use_delta_component, + # n_mask_samples=run_config.n_mask_samples, + # include_patterns=cfg.include_patterns, + # exclude_patterns=cfg.exclude_patterns, + # ) case StochasticHiddenActsReconLossConfig(): metric = StochasticHiddenActsReconLoss( model=model, @@ -271,7 +265,6 @@ def init_metric( metric = UnmaskedReconLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, ) case _: @@ -281,10 +274,10 @@ def init_metric( return metric -def evaluate( +def evaluate[BatchT, OutputT]( eval_metric_configs: list[MetricConfigType], - model: ComponentModel, - eval_iterator: Iterator[Int[Tensor, "..."] | tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + model: ComponentModel[BatchT, OutputT], + eval_iterator: Iterator[BatchT], device: str, run_config: Config, slow_step: bool, @@ -293,7 +286,7 @@ def evaluate( ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" - metrics: list[Metric] = [] + metrics: list[Metric[BatchT, OutputT]] = [] for cfg in eval_metric_configs: metric = init_metric(cfg=cfg, model=model, run_config=run_config, device=device) if metric.slow and not slow_step: @@ -304,10 +297,9 @@ def evaluate( weight_deltas = model.calc_weight_deltas() for _ in range(n_eval_steps): - batch_raw = next(eval_iterator) - batch = extract_batch_data(batch_raw).to(device) + batch = next(eval_iterator) - target_output: OutputWithCache = model(batch, cache_type="input") + target_output: OutputWithCache[OutputT] = model.__call__(batch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=target_output.cache, detach_inputs=False, @@ -337,14 +329,13 @@ def evaluate( return outputs -def evaluate_multibatch_pgd( +def evaluate_multibatch_pgd[BatchT, OutputT]( multibatch_pgd_eval_configs: list[ PGDMultiBatchReconLossConfig | PGDMultiBatchReconSubsetLossConfig ], - model: ComponentModel, - create_data_iter: CreateDataIter, + model: ComponentModel[BatchT, OutputT], + create_data_iter: CreateDataIter[BatchT], config: Config, - batch_dims: tuple[int, ...], device: str, ) -> dict[str, float]: """Calculate multibatch PGD metrics.""" @@ -367,11 +358,9 @@ def evaluate_multibatch_pgd( model=model, weight_deltas=weight_deltas, create_data_iter=create_data_iter, - output_loss_type=config.output_loss_type, router=router, sampling=config.sampling, use_delta_component=config.use_delta_component, - batch_dims=batch_dims, device=device, ).item() return metrics diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 1b0b268fc..1bbc45e69 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,13 +7,11 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger +from spd.models.component_model import make_run_batch_lm, recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import ( - save_pre_run_info, - set_seed, -) +from spd.utils.general_utils import save_pre_run_info, set_seed from spd.utils.run_utils import ExecutionStamp from spd.utils.wandb_utils import init_wandb @@ -100,6 +98,8 @@ def main( train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 3131398cf..fcb6c596a 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -10,6 +10,7 @@ from spd.configs import Config, LMTaskConfig from spd.data import DatasetConfig, create_data_loader from spd.log import logger +from spd.models.component_model import make_run_batch_lm, recon_loss_kl from spd.run_spd import optimize from spd.utils.distributed_utils import ( DistributedState, @@ -181,6 +182,8 @@ def main( train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ln_stds=ln_stds, ) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 75e423099..d27b742ae 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,6 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger +from spd.models.component_model import pass_first_tuple_element_to_model, recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -109,6 +110,8 @@ def main( train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=pass_first_tuple_element_to_model, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 18c437a68..48a51abef 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,6 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger +from spd.models.component_model import pass_first_tuple_element_to_model, recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -105,6 +106,8 @@ def main( train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=pass_first_tuple_element_to_model, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, ) diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index 6c9b7858c..a246bc5e2 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -16,6 +16,7 @@ import time from dataclasses import asdict, dataclass from pathlib import Path +from typing import Any import torch import tqdm @@ -34,10 +35,9 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data -def _compute_u_norms(model: ComponentModel) -> dict[str, Float[Tensor, " C"]]: +def _compute_u_norms(model: ComponentModel[Any, Any]) -> dict[str, Float[Tensor, " C"]]: """Compute ||U[c,:]|| for each component c in each layer. Component activations (v_i^T @ a) have a scale invariance: scaling V by α and U by 1/α @@ -231,7 +231,7 @@ def harvest_activation_contexts( batch_range = range(config.n_batches) if config.n_batches is not None else itertools.count() for batch_idx in tqdm.tqdm(batch_range, desc="Harvesting", disable=rank is not None): try: - batch_data = extract_batch_data(next(train_iter)) + batch = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break @@ -240,7 +240,6 @@ def harvest_activation_contexts( if world_size is not None and batch_idx % world_size != rank: continue - batch = batch_data.to(device) with torch.no_grad(): out = model(batch, cache_type="input") probs = torch.softmax(out.output, dim=-1) diff --git a/spd/losses.py b/spd/losses.py index daef1773c..6e36ea2d3 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -1,7 +1,5 @@ -from typing import Literal - import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from spd.configs import ( @@ -37,27 +35,27 @@ unmasked_recon_loss, ) from spd.models.component_model import CIOutputs, ComponentModel +from spd.utils.general_utils import get_obj_device -def compute_total_loss( +def compute_total_loss[BatchT, OutputT]( loss_metric_configs: list[LossMetricConfigType], - model: ComponentModel, - batch: Int[Tensor, "..."], + model: ComponentModel[BatchT, OutputT], + batch: BatchT, ci: CIOutputs, - target_out: Tensor, + target_out: OutputT, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], pre_weight_acts: dict[str, Float[Tensor, "..."]], current_frac_of_training: float, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], ) -> tuple[Float[Tensor, ""], dict[str, float]]: """Compute weighted total loss and per-term raw values using new loss primitives. Returns (total, terms_dict). terms_dict contains raw per-term values (no coeffs) and a weighted total. """ - total = torch.tensor(0.0, device=batch.device) + total = torch.tensor(0.0, device=get_obj_device(model)) terms: dict[str, float] = {} for cfg in loss_metric_configs: @@ -79,14 +77,12 @@ def compute_total_loss( case UnmaskedReconLossConfig(): loss = unmasked_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ) case CIMaskedReconSubsetLossConfig(): loss = ci_masked_recon_subset_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -95,7 +91,6 @@ def compute_total_loss( case CIMaskedReconLayerwiseLossConfig(): loss = ci_masked_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -103,7 +98,6 @@ def compute_total_loss( case CIMaskedReconLossConfig(): loss = ci_masked_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -113,7 +107,6 @@ def compute_total_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -124,7 +117,6 @@ def compute_total_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -135,7 +127,6 @@ def compute_total_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -145,7 +136,6 @@ def compute_total_loss( case PGDReconLossConfig(): loss = pgd_recon_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -155,7 +145,6 @@ def compute_total_loss( case PGDReconSubsetLossConfig(): loss = pgd_recon_subset_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -166,7 +155,6 @@ def compute_total_loss( case PGDReconLayerwiseLossConfig(): loss = pgd_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/alive_components.py b/spd/metrics/alive_components.py deleted file mode 100644 index 4e85fe954..000000000 --- a/spd/metrics/alive_components.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Track which components are alive based on their firing frequency.""" - -import torch -from einops import reduce -from jaxtyping import Bool, Float, Int -from torch import Tensor -from torch.distributed import ReduceOp - -from spd.utils.distributed_utils import all_reduce - - -class AliveComponentsTracker: - """Track which components are considered alive based on their firing frequency. - - A component is considered alive if it has fired (importance > threshold) within - the last n_examples_until_dead examples. - - NOTE: This does not directly inherit from spd.metrics.base.Metric, but its update and compute - methods have a similar signature to the Metric interface. - """ - - def __init__( - self, - module_to_c: dict[str, int], - device: str, - n_examples_until_dead: int, - ci_alive_threshold: float, - global_n_examples_per_batch: int, - ) -> None: - """Initialize the tracker. - - Args: - module_to_c: Dictionary mapping module names to their C values - device: Device to store tensors on - n_examples_until_dead: Number of examples without firing before component is considered dead - ci_alive_threshold: Causal importance threshold above which a component is considered 'firing' - global_n_examples_per_batch: Number of examples per batch across all ranks (including - batch and sequence dimensions) - """ - self.n_examples_until_dead = n_examples_until_dead - self.ci_alive_threshold = ci_alive_threshold - self.n_batches_until_dead = self.n_examples_until_dead // global_n_examples_per_batch - - self.n_batches_since_fired: dict[str, Int[Tensor, " C"]] = { - m: torch.zeros(c, dtype=torch.int64, device=device) for m, c in module_to_c.items() - } - - def update(self, ci: dict[str, Float[Tensor, "... C"]]) -> None: - """Update tracking based on importance values from a batch. - - Args: - ci: Dict mapping module names to causal importance tensors with shape (..., C) - """ - for module_name, importance_vals in ci.items(): - firing: Bool[Tensor, " C"] = reduce( - importance_vals > self.ci_alive_threshold, "... C -> C", torch.any - ) - self.n_batches_since_fired[module_name] = torch.where( - firing, - 0, - self.n_batches_since_fired[module_name] + 1, - ) - - def compute(self) -> dict[str, int]: - """Compute the number of alive components per module. - - Returns: - Dict mapping module names to number of alive components - """ - out: dict[str, int] = {} - for module_name, n_batches_since_fired in self.n_batches_since_fired.items(): - # Use MIN reduction so that a component is alive if it fired on ANY rank - batches_since_fired_reduced = all_reduce(n_batches_since_fired, op=ReduceOp.MIN) - n_alive = int((batches_since_fired_reduced < self.n_batches_until_dead).sum().item()) - out[module_name] = n_alive - return out diff --git a/spd/metrics/base.py b/spd/metrics/base.py index 97665464b..a7903c860 100644 --- a/spd/metrics/base.py +++ b/spd/metrics/base.py @@ -6,13 +6,13 @@ from typing import Any, ClassVar, Protocol -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from spd.models.component_model import CIOutputs -class Metric(Protocol): +class Metric[BatchT, OutputT](Protocol): """Interface for metrics that can be used in training and/or evaluation.""" slow: ClassVar[bool] = False @@ -21,8 +21,8 @@ class Metric(Protocol): def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, current_frac_of_training: float, diff --git a/spd/metrics/ce_and_kl_losses.py b/spd/metrics/ce_and_kl_losses.py index d93dcbc86..872336954 100644 --- a/spd/metrics/ce_and_kl_losses.py +++ b/spd/metrics/ce_and_kl_losses.py @@ -17,7 +17,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class CEandKLLosses(Metric): +class CEandKLLosses(Metric[Tensor, Tensor]): """CE and KL losses for different masking strategies. NOTE: Assumes all batches and sequences are the same size. @@ -47,7 +47,7 @@ class CEandKLLosses(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Tensor, Tensor], device: str, sampling: SamplingType, rounding_threshold: float, diff --git a/spd/metrics/ci_histograms.py b/spd/metrics/ci_histograms.py index fcf6fb2ac..22e6f7386 100644 --- a/spd/metrics/ci_histograms.py +++ b/spd/metrics/ci_histograms.py @@ -12,13 +12,13 @@ from spd.utils.distributed_utils import gather_all_tensors -class CIHistograms(Metric): +class CIHistograms(Metric[Any, Any]): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], n_batches_accum: int | None = None, ): self.n_batches_accum = n_batches_accum diff --git a/spd/metrics/ci_l0.py b/spd/metrics/ci_l0.py index 9a2047ff8..b534f5c9b 100644 --- a/spd/metrics/ci_l0.py +++ b/spd/metrics/ci_l0.py @@ -12,7 +12,7 @@ from spd.utils.distributed_utils import all_reduce -class CI_L0(Metric): +class CI_L0(Metric[Any, Any]): """L0 metric for CI values. NOTE: Assumes all batches and sequences are the same size. @@ -22,7 +22,7 @@ class CI_L0(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], device: str, ci_alive_threshold: float, groups: dict[str, list[str]] | None = None, diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index b7ff12be9..5db109845 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -9,84 +9,81 @@ from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_layerwise_loss_update( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def _ci_masked_recon_layerwise_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], ) -> tuple[Float[Tensor, ""], int]: - sum_loss = torch.tensor(0.0, device=batch.device) - n_examples = 0 + sum_loss = torch.tensor(0.0, device=get_obj_device(model)) + sum_n_examples = 0 mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) for module_name, mask_info in mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, n_examples = model.reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += n_examples + return sum_loss, sum_n_examples def _ci_masked_recon_layerwise_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples -def ci_masked_recon_layerwise_loss( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def ci_masked_recon_layerwise_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], ) -> Float[Tensor, ""]: - sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, ) - return _ci_masked_recon_layerwise_loss_compute(sum_loss, n_examples) + return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class CIMaskedReconLayerwiseLoss(Metric): +class CIMaskedReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with CI values directly one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, + model: ComponentModel[BatchT, OutputT], + device: str, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, **_: Any, ) -> None: - sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _ci_masked_recon_layerwise_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index a11c11469..c085eb0e9 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -9,21 +9,17 @@ from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm -def _ci_masked_recon_loss_update( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def _ci_masked_recon_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], ) -> tuple[Float[Tensor, ""], int]: mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + return model.reconstruction_loss(out, target_out) def _ci_masked_recon_loss_compute( @@ -32,16 +28,14 @@ def _ci_masked_recon_loss_compute( return sum_loss / n_examples -def ci_masked_recon_loss( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def ci_masked_recon_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, @@ -49,16 +43,17 @@ def ci_masked_recon_loss( return _ci_masked_recon_loss_compute(sum_loss, n_examples) -class CIMaskedReconLoss(Metric): +class CIMaskedReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with CI values directly on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, + model: ComponentModel[BatchT, OutputT], + device: str, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -66,14 +61,13 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index 0a2e83441..bbd40a81b 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -11,14 +11,13 @@ from spd.models.components import make_mask_infos from spd.routing import Router, get_subset_router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_subset_loss_update( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def _ci_masked_recon_subset_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], router: Router, ) -> tuple[Float[Tensor, ""], int]: @@ -32,9 +31,7 @@ def _ci_masked_recon_subset_loss_update( weight_deltas_and_masks=None, ) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + return model.reconstruction_loss(out, target_out) def _ci_masked_recon_subset_loss_compute( @@ -43,39 +40,35 @@ def _ci_masked_recon_subset_loss_compute( return sum_loss / n_examples -def ci_masked_recon_subset_loss( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def ci_masked_recon_subset_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=model, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), ) return _ci_masked_recon_subset_loss_compute(sum_loss, n_examples) -class CIMaskedReconSubsetLoss(Metric): +class CIMaskedReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with raw CI values and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, - output_loss_type: Literal["mse", "kl"], routing: SubsetRoutingType, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.router = get_subset_router(routing, device) self.sum_loss = torch.tensor(0.0, device=device) @@ -85,14 +78,13 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/ci_mean_per_component.py b/spd/metrics/ci_mean_per_component.py index 88800e7a4..fb4373727 100644 --- a/spd/metrics/ci_mean_per_component.py +++ b/spd/metrics/ci_mean_per_component.py @@ -11,11 +11,11 @@ from spd.utils.distributed_utils import all_reduce -class CIMeanPerComponent(Metric): +class CIMeanPerComponent(Metric[Any, Any]): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__(self, model: ComponentModel, device: str) -> None: + def __init__(self, model: ComponentModel[Any, Any], device: str) -> None: self.components = model.components self.component_ci_sums: dict[str, Tensor] = { module_name: torch.zeros(model.module_to_c[module_name], device=device) diff --git a/spd/metrics/component_activation_density.py b/spd/metrics/component_activation_density.py index 5f10a86cc..eb56f83fd 100644 --- a/spd/metrics/component_activation_density.py +++ b/spd/metrics/component_activation_density.py @@ -13,13 +13,15 @@ from spd.utils.distributed_utils import all_reduce -class ComponentActivationDensity(Metric): +class ComponentActivationDensity(Metric[Any, Any]): """Activation density for each component.""" slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__(self, model: ComponentModel, device: str, ci_alive_threshold: float) -> None: + def __init__( + self, model: ComponentModel[Any, Any], device: str, ci_alive_threshold: float + ) -> None: self.model = model self.ci_alive_threshold = ci_alive_threshold diff --git a/spd/metrics/faithfulness_loss.py b/spd/metrics/faithfulness_loss.py index d2b02e0d4..3c902b65c 100644 --- a/spd/metrics/faithfulness_loss.py +++ b/spd/metrics/faithfulness_loss.py @@ -35,12 +35,12 @@ def faithfulness_loss(weight_deltas: dict[str, Float[Tensor, "d_out d_in"]]) -> return _faithfulness_loss_compute(sum_loss, total_params) -class FaithfulnessLoss(Metric): +class FaithfulnessLoss(Metric[Any, Any]): """MSE between the target weights and the sum of the components.""" metric_section: ClassVar[str] = "loss" - def __init__(self, model: ComponentModel, device: str) -> None: + def __init__(self, model: ComponentModel[Any, Any], device: str) -> None: self.model = model self.sum_loss = torch.tensor(0.0, device=device) self.total_params = torch.tensor(0, device=device) diff --git a/spd/metrics/identity_ci_error.py b/spd/metrics/identity_ci_error.py index d619c5082..b5771d7e9 100644 --- a/spd/metrics/identity_ci_error.py +++ b/spd/metrics/identity_ci_error.py @@ -9,7 +9,7 @@ from spd.utils.target_ci_solutions import compute_target_metrics, make_target_ci_solution -class IdentityCIError(Metric): +class IdentityCIError(Metric[Any, Any]): """Error between the CI values and an Identity or Dense CI pattern.""" slow: ClassVar[bool] = True @@ -19,7 +19,7 @@ class IdentityCIError(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], sampling: SamplingType, identity_ci: list[dict[str, str | int]] | None = None, dense_ci: list[dict[str, str | int]] | None = None, diff --git a/spd/metrics/importance_minimality_loss.py b/spd/metrics/importance_minimality_loss.py index 5bd6ca31f..d06f9e47e 100644 --- a/spd/metrics/importance_minimality_loss.py +++ b/spd/metrics/importance_minimality_loss.py @@ -144,7 +144,7 @@ def importance_minimality_loss( ) -class ImportanceMinimalityLoss(Metric): +class ImportanceMinimalityLoss(Metric[Any, Any]): """L_p loss on the sum of CI values. NOTE: We don't normalize over the number of layers because a change in the number of layers @@ -165,7 +165,7 @@ class ImportanceMinimalityLoss(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], device: str, pnorm: float, beta: float, diff --git a/spd/metrics/permuted_ci_plots.py b/spd/metrics/permuted_ci_plots.py index d5baa8b28..f0b340b70 100644 --- a/spd/metrics/permuted_ci_plots.py +++ b/spd/metrics/permuted_ci_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals -class PermutedCIPlots(Metric): +class PermutedCIPlots(Metric[Any, Any]): slow: ClassVar[bool] = True input_magnitude: ClassVar[float] = 0.75 @@ -17,7 +17,7 @@ class PermutedCIPlots(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index 787cad8a2..1c21fa175 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -13,12 +13,11 @@ from spd.utils.distributed_utils import all_reduce -def _pgd_recon_layerwise_loss_update( +def _pgd_recon_layerwise_loss_update[BatchT, OutputT]( *, - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -33,7 +32,6 @@ def _pgd_recon_layerwise_loss_update( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, router=LayerRouter(device=device, layer_name=layer), pgd_config=pgd_config, ) @@ -42,12 +40,11 @@ def _pgd_recon_layerwise_loss_update( return sum_loss, n_examples -def pgd_recon_layerwise_loss( +def pgd_recon_layerwise_loss[BatchT, OutputT]( *, - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -56,7 +53,6 @@ def pgd_recon_layerwise_loss( model=model, batch=batch, target_out=target_out, - output_loss_type=output_loss_type, ci=ci, weight_deltas=weight_deltas, pgd_config=pgd_config, @@ -64,7 +60,7 @@ def pgd_recon_layerwise_loss( return sum_loss / n_examples -class PGDReconLayerwiseLoss(Metric): +class PGDReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with adversarially-optimized values and routing to one layer at a time.""" @@ -72,15 +68,13 @@ class PGDReconLayerwiseLoss(Metric): def __init__( self, - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], + model: ComponentModel[BatchT, OutputT], pgd_config: PGDConfig, device: str, use_delta_component: bool, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -89,8 +83,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -99,7 +93,6 @@ def update( model=self.model, batch=batch, target_out=target_out, - output_loss_type=self.output_loss_type, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, pgd_config=self.pgd_config, diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index 7d35e149f..b02413526 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -1,7 +1,7 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp @@ -13,12 +13,11 @@ from spd.utils.distributed_utils import all_reduce -def pgd_recon_loss( +def pgd_recon_loss[BatchT, OutputT]( *, - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -29,14 +28,13 @@ def pgd_recon_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, router=AllLayersRouter(), pgd_config=pgd_config, ) return sum_loss / n_examples -class PGDReconLoss(Metric): +class PGDReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with adversarially-optimized values and routing to all component layers.""" @@ -44,15 +42,13 @@ class PGDReconLoss(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, - output_loss_type: Literal["mse", "kl"], pgd_config: PGDConfig, use_delta_component: bool, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -61,8 +57,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -73,7 +69,6 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, router=AllLayersRouter(), pgd_config=self.pgd_config, ) diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index a9e8a7eac..a34c3677a 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -1,7 +1,7 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp @@ -11,14 +11,14 @@ from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import get_subset_router from spd.utils.distributed_utils import all_reduce +from spd.utils.general_utils import get_obj_device -def pgd_recon_subset_loss( +def pgd_recon_subset_loss[BatchT, OutputT]( *, - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -30,14 +30,13 @@ def pgd_recon_subset_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), pgd_config=pgd_config, ) return sum_loss / n_examples -class PGDReconSubsetLoss(Metric): +class PGDReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when masking with adversarially-optimized values and routing to subsets of component layers.""" @@ -45,18 +44,16 @@ class PGDReconSubsetLoss(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, - output_loss_type: Literal["mse", "kl"], use_delta_component: bool, pgd_config: PGDConfig, routing: SubsetRoutingType, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.use_delta_component: bool = use_delta_component - self.router = get_subset_router(routing, device) + self.router = get_subset_router(routing, device=get_obj_device(model)) self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -65,8 +62,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -77,7 +74,6 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, router=self.router, pgd_config=self.pgd_config, ) diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index bb3c5f090..8f97ac816 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -1,9 +1,9 @@ -from collections.abc import Callable, Iterator +from collections.abc import Iterator from functools import partial -from typing import Literal +from typing import Protocol import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp @@ -13,16 +13,15 @@ from spd.models.components import RoutingMasks, make_mask_infos from spd.routing import Router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, extract_batch_data +from spd.utils.general_utils import get_obj_device -def pgd_masked_recon_loss_update( - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], +def pgd_masked_recon_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: OutputT, router: Router, pgd_config: PGDConfig, ) -> tuple[Float[Tensor, ""], int]: @@ -45,7 +44,7 @@ def pgd_masked_recon_loss_update( singleton_batch_dims = [1 for _ in batch_dims] shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = _get_pgd_init_tensor( - pgd_config.init, shape, batch.device + pgd_config.init, shape, device=get_obj_device(model) ).requires_grad_(True) fwd_pass = partial( @@ -57,7 +56,6 @@ def pgd_masked_recon_loss_update( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_out, - output_loss_type=output_loss_type, batch_dims=batch_dims, ) @@ -79,22 +77,18 @@ def pgd_masked_recon_loss_update( return fwd_pass() -CreateDataIter = Callable[ - [], - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], -] +class CreateDataIter[BatchT](Protocol): + def __call__(self) -> Iterator[BatchT]: ... -def calc_multibatch_pgd_masked_recon_loss( +def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( pgd_config: PGDMultiBatchConfig, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - create_data_iter: CreateDataIter, - output_loss_type: Literal["mse", "kl"], + create_data_iter: CreateDataIter[BatchT], router: Router, sampling: SamplingType, use_delta_component: bool, - batch_dims: tuple[int, ...], device: str, ) -> Float[Tensor, ""]: """PGD masked reconstruction loss with gradient accumulation over multiple batches. @@ -112,15 +106,21 @@ def calc_multibatch_pgd_masked_recon_loss( router: Router to use for routing masks sampling: Sampling mode for causal importance calculation use_delta_component: Whether to include weight delta component - batch_dims: Dimensions of batch (e.g., (batch_size,) or (batch_size, seq_len)) Returns: Final reconstruction loss after PGD optimization """ - singleton_batch_dims = [1 for _ in batch_dims] + + demo_batch = next(create_data_iter()) + demo_output = model(demo_batch, cache_type="input") + ci_demo = model.calc_causal_importances( + pre_weight_acts=demo_output.cache, sampling=sampling + ).lower_leaky adv_sources: dict[str, Float[Tensor, "*ones mask_c"]] = {} for module_name in model.target_module_paths: - module_c = model.module_to_c[module_name] + demo_ci = ci_demo[module_name] + *batch_dims, module_c = demo_ci.shape + singleton_batch_dims = [1 for _ in batch_dims] mask_c = module_c if not use_delta_component else module_c + 1 shape = torch.Size([*singleton_batch_dims, mask_c]) adv_sources[module_name] = _get_pgd_init_tensor( @@ -134,34 +134,31 @@ def calc_multibatch_pgd_masked_recon_loss( model=model, weight_deltas=weight_deltas, device=device, - output_loss_type=output_loss_type, sampling=sampling, router=router, - batch_dims=batch_dims, ) for _ in range(pgd_config.n_steps): assert all(adv.grad is None for adv in adv_sources.values()) - _, _, adv_sources_grads = fwd_bwd_fn(data_iter=create_data_iter()) + _, _, adv_sources_sum_grads = fwd_bwd_fn(data_iter=create_data_iter()) with torch.no_grad(): for k in adv_sources: - adv_sources[k].add_(pgd_config.step_size * adv_sources_grads[k].sign()) + adv_sources[k].add_(pgd_config.step_size * adv_sources_sum_grads[k].sign()) adv_sources[k].clamp_(0.0, 1.0) - final_loss, final_n_examples, _ = fwd_bwd_fn(data_iter=create_data_iter()) - return final_loss / final_n_examples + final_loss, final_sum_n_examples, _ = fwd_bwd_fn(data_iter=create_data_iter()) + return final_loss / final_sum_n_examples -def _forward_with_adv_sources( - model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], +def _forward_with_adv_sources[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing_masks: RoutingMasks, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: OutputT, batch_dims: tuple[int, ...], ): expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} @@ -183,27 +180,20 @@ def _forward_with_adv_sources( ) out = model(batch, mask_infos=mask_infos) - sum_loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples = ( - target_out.shape.numel() if output_loss_type == "mse" else target_out.shape[:-1].numel() - ) + sum_loss, n_examples = model.reconstruction_loss(out, target_out) return sum_loss, n_examples -def _multibatch_pgd_fwd_bwd( +def _multibatch_pgd_fwd_bwd[BatchT, OutputT]( adv_sources: dict[str, Float[Tensor, "*ones mask_c"]], pgd_config: PGDMultiBatchConfig, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - data_iter: Iterator[Int[Tensor, "..."]] - | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + data_iter: Iterator[BatchT], device: torch.device | str, - output_loss_type: Literal["mse", "kl"], router: Router, sampling: SamplingType, - batch_dims: tuple[int, ...], ) -> tuple[Float[Tensor, ""], int, dict[str, Float[Tensor, "*ones mask_c"]]]: """Perform a forward and backward pass over multiple batches with gradient accumulation. @@ -213,33 +203,34 @@ def _multibatch_pgd_fwd_bwd( - The gradients of the adv_sources (dict keyed by module name) """ pgd_step_accum_sum_loss = torch.tensor(0.0, device=device) - pgd_step_accum_n_examples = 0 - pgd_step_accum_grads = {k: torch.zeros_like(v) for k, v in adv_sources.items()} + pgd_step_accum_sum_n_examples = 0 + pgd_step_accum_sum_grads = {k: torch.zeros_like(v) for k, v in adv_sources.items()} for microbatch_idx in range(pgd_config.gradient_accumulation_steps): try: - microbatch_item = next(data_iter) + microbatch = next(data_iter) except StopIteration: logger.warning(f"Dataloader exhausted after {microbatch_idx} batches, ending PGD step.") break - microbatch = extract_batch_data(microbatch_item).to(device) # NOTE: technically this is duplicated work across PGD steps, but that's the price we pay to # enable accumulating gradients over more microbatches than we'd be able to fit CI values in # memory for. In other words, you can't fit 100,000 microbatches worth of CI values in memory. - target_model_output: OutputWithCache = model(microbatch, cache_type="input") + target_model_output: OutputWithCache[OutputT] = model(microbatch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=target_model_output.cache, sampling=sampling, ).lower_leaky + batch_dims = next(iter(ci.values())).shape[:-1] + # It's important that we call this every microbatch to ensure stochastic routing masks are # sampled independently for each example. routing_masks = router.get_masks( module_names=model.target_module_paths, mask_shape=batch_dims ) - batch_sum_loss, batch_n_examples = _forward_with_adv_sources( + batch_sum_loss, batch_sum_n_examples = _forward_with_adv_sources( model=model, batch=microbatch, adv_sources=adv_sources, @@ -247,21 +238,20 @@ def _multibatch_pgd_fwd_bwd( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_model_output.output, - output_loss_type=output_loss_type, batch_dims=batch_dims, ) pgd_step_accum_sum_loss += batch_sum_loss - pgd_step_accum_n_examples += batch_n_examples + pgd_step_accum_sum_n_examples += batch_sum_n_examples # important: take gradient wrt the UNEXPANDED adv_sources, not the expanded ones grads = torch.autograd.grad(batch_sum_loss, list(adv_sources.values())) for k, g in zip(adv_sources.keys(), grads, strict=True): - pgd_step_accum_grads[k] += all_reduce(g, op=ReduceOp.SUM).detach() + pgd_step_accum_sum_grads[k] += all_reduce(g, op=ReduceOp.SUM).detach() del target_model_output, ci - return pgd_step_accum_sum_loss, pgd_step_accum_n_examples, pgd_step_accum_grads + return pgd_step_accum_sum_loss, pgd_step_accum_sum_n_examples, pgd_step_accum_sum_grads def _get_pgd_init_tensor( diff --git a/spd/metrics/stochastic_hidden_acts_recon_loss.py b/spd/metrics/stochastic_hidden_acts_recon_loss.py index 814e6e18c..7b889da93 100644 --- a/spd/metrics/stochastic_hidden_acts_recon_loss.py +++ b/spd/metrics/stochastic_hidden_acts_recon_loss.py @@ -14,11 +14,11 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_hidden_acts_recon_loss_update( - model: ComponentModel, +def _stochastic_hidden_acts_recon_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: BatchT, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -59,11 +59,11 @@ def _stochastic_hidden_acts_recon_loss_compute( return sum_mse / n_examples -def stochastic_hidden_acts_recon_loss( - model: ComponentModel, +def stochastic_hidden_acts_recon_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: BatchT, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -80,14 +80,14 @@ def stochastic_hidden_acts_recon_loss( return _stochastic_hidden_acts_recon_loss_compute(sum_mse, n_examples) -class StochasticHiddenActsReconLoss(Metric): +class StochasticHiddenActsReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Reconstruction loss between target and stochastic hidden activations when sampling with stochastic masks.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, sampling: SamplingType, use_delta_component: bool, @@ -104,7 +104,7 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: BatchT, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index b14d57fe3..da8df9f0d 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -11,23 +11,22 @@ from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device -def _stochastic_recon_layerwise_loss_update( - model: ComponentModel, +def _stochastic_recon_layerwise_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) sum_loss = torch.tensor(0.0, device=device) - n_examples = 0 + sum_n_examples = 0 stochastic_mask_infos_list = [ calc_stochastic_component_mask_info( @@ -42,89 +41,83 @@ def _stochastic_recon_layerwise_loss_update( for stochastic_mask_infos in stochastic_mask_infos_list: for module_name, mask_info in stochastic_mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, batch_n_examples = model.reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += batch_n_examples + return sum_loss, sum_n_examples def _stochastic_recon_layerwise_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples -def stochastic_recon_layerwise_loss( - model: ComponentModel, +def stochastic_recon_layerwise_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, ) -> Float[Tensor, ""]: - sum_loss, n_examples = _stochastic_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, ) - return _stochastic_recon_layerwise_loss_compute(sum_loss, n_examples) + return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLayerwiseLoss(Metric): +class StochasticReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when sampling with stochastic masks one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, ) -> None: - sum_loss, n_examples = _stochastic_recon_layerwise_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _stochastic_recon_layerwise_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 46cb0ad61..42601b514 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -11,118 +11,109 @@ from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device -def _stochastic_recon_loss_update( - model: ComponentModel, +def _stochastic_recon_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) sum_loss = torch.tensor(0.0, device=device) - n_examples = 0 + sum_n_examples = 0 - stoch_mask_infos_list = [ - calc_stochastic_component_mask_info( + for _ in range(n_mask_samples): + stoch_mask_infos = calc_stochastic_component_mask_info( causal_importances=ci, component_mask_sampling=sampling, weight_deltas=weight_deltas, router=AllLayersRouter(), ) - for _ in range(n_mask_samples) - ] - for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, n_examples = model.reconstruction_loss(out, target_out) sum_loss += loss - return sum_loss, n_examples + sum_n_examples += n_examples + + return sum_loss, sum_n_examples def _stochastic_recon_loss_compute( - sum_loss: Float[Tensor, ""], n_examples: Int[Tensor, ""] | int + sum_loss: Float[Tensor, ""], sum_n_examples: Int[Tensor, ""] | int ) -> Float[Tensor, ""]: - return sum_loss / n_examples + return sum_loss / sum_n_examples -def stochastic_recon_loss( - model: ComponentModel, +def stochastic_recon_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, ) -> Float[Tensor, ""]: - sum_loss, n_examples = _stochastic_recon_loss_update( - model, - sampling, - n_mask_samples, - output_loss_type, - batch, - target_out, - ci, - weight_deltas, + sum_loss, sum_n_examples = _stochastic_recon_loss_update( + model=model, + sampling=sampling, + n_mask_samples=n_mask_samples, + batch=batch, + target_out=target_out, + ci=ci, + weight_deltas=weight_deltas, ) - return _stochastic_recon_loss_compute(sum_loss, n_examples) + return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLoss(Metric): +class StochasticReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when sampling with stochastic masks on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.sum_loss = torch.tensor(0.0, device=device) - self.n_examples = torch.tensor(0, device=device) + self.sum_n_examples = torch.tensor(0, device=device) @override def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, ) -> None: - sum_loss, n_examples = _stochastic_recon_loss_update( + sum_loss, sum_n_examples = _stochastic_recon_loss_update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, ) self.sum_loss += sum_loss - self.n_examples += n_examples + self.sum_n_examples += sum_n_examples @override def compute(self) -> Float[Tensor, ""]: sum_loss = all_reduce(self.sum_loss, op=ReduceOp.SUM) - n_examples = all_reduce(self.n_examples, op=ReduceOp.SUM) - return _stochastic_recon_loss_compute(sum_loss, n_examples) + sum_n_examples = all_reduce(self.sum_n_examples, op=ReduceOp.SUM) + return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) diff --git a/spd/metrics/stochastic_recon_subset_ce_and_kl.py b/spd/metrics/stochastic_recon_subset_ce_and_kl.py index b1a98b2f2..b1d67f86f 100644 --- a/spd/metrics/stochastic_recon_subset_ce_and_kl.py +++ b/spd/metrics/stochastic_recon_subset_ce_and_kl.py @@ -19,7 +19,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class StochasticReconSubsetCEAndKL(Metric): +class StochasticReconSubsetCEAndKL(Metric[Tensor, Tensor]): """Compute reconstruction loss for specific subsets of components. NOTE: Assumes all batches and sequences are the same size. @@ -29,7 +29,7 @@ class StochasticReconSubsetCEAndKL(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Tensor, Tensor], device: str, sampling: SamplingType, use_delta_component: bool, diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index 62573a889..ae8147126 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -11,16 +11,15 @@ from spd.routing import Router, get_subset_router from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device -def _stochastic_recon_subset_loss_update( - model: ComponentModel, +def _stochastic_recon_subset_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, router: Router, @@ -42,11 +41,9 @@ def _stochastic_recon_subset_loss_update( for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, batch_n_examples = model.reconstruction_loss(out, target_out) sum_loss += loss - + n_examples += batch_n_examples return sum_loss, n_examples @@ -56,13 +53,12 @@ def _stochastic_recon_subset_loss_compute( return sum_loss / n_examples -def stochastic_recon_subset_loss( - model: ComponentModel, +def stochastic_recon_subset_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing: SubsetRoutingType, @@ -71,36 +67,33 @@ def stochastic_recon_subset_loss( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, - router=get_subset_router(routing, batch.device), + router=get_subset_router(routing, device=get_obj_device(model)), ) return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) -class StochasticReconSubsetLoss(Metric): +class StochasticReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss when sampling with stochastic masks and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], routing: SubsetRoutingType, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.router = get_subset_router(routing, device) self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -109,8 +102,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -119,7 +112,6 @@ def update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index 01cf67fe0..1c6681a20 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -9,26 +9,23 @@ from spd.models.component_model import ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm +from spd.utils.general_utils import get_obj_device -def _unmasked_recon_loss_update( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def _unmasked_recon_loss_update[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ) -> tuple[Float[Tensor, ""], int]: all_ones_mask_infos = make_mask_infos( # (C,) will broadcast to (B, S, C) { - module_path: torch.ones(model.module_to_c[module_path], device=batch.device) + module_path: torch.ones(model.module_to_c[module_path], device=get_obj_device(model)) for module_path in model.target_module_paths } ) out = model(batch, mask_infos=all_ones_mask_infos) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples = out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() - return loss, n_examples + return model.reconstruction_loss(out, target_out) def _unmasked_recon_loss_compute( @@ -37,34 +34,30 @@ def _unmasked_recon_loss_compute( return sum_loss / n_examples -def unmasked_recon_loss( - model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], +def unmasked_recon_loss[BatchT, OutputT]( + model: ComponentModel[BatchT, OutputT], + batch: BatchT, + target_out: OutputT, ) -> Float[Tensor, ""]: sum_loss, n_examples = _unmasked_recon_loss_update( model, - output_loss_type, batch, target_out, ) return _unmasked_recon_loss_compute(sum_loss, n_examples) -class UnmaskedReconLoss(Metric): +class UnmaskedReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): """Recon loss using the unmasked components and without the delta component.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel, + model: ComponentModel[BatchT, OutputT], device: str, - output_loss_type: Literal["mse", "kl"], ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -72,13 +65,12 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: BatchT, + target_out: OutputT, **_: Any, ) -> None: sum_loss, n_examples = _unmasked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ) diff --git a/spd/metrics/uv_plots.py b/spd/metrics/uv_plots.py index 26880f19e..2c1071ec3 100644 --- a/spd/metrics/uv_plots.py +++ b/spd/metrics/uv_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals, plot_UV_matrices -class UVPlots(Metric): +class UVPlots(Metric[Any, Any]): metric_section: ClassVar[str] = "figures" slow: ClassVar[bool] = True @@ -17,7 +17,7 @@ class UVPlots(Metric): def __init__( self, - model: ComponentModel, + model: ComponentModel[Any, Any], sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/models/component_model.py b/spd/models/component_model.py index aa923179b..93f194265 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,16 +1,26 @@ -from collections.abc import Callable, Generator, Sequence +from abc import ABC +from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Literal, NamedTuple, overload, override +from typing import Any, Literal, NamedTuple, Protocol, Self, overload, override import torch +import torch.nn.functional as F from jaxtyping import Float, Int from torch import Tensor, nn from torch.utils.hooks import RemovableHandle from transformers.pytorch_utils import Conv1D as RadfordConv1D -from spd.configs import Config, SamplingType +from spd.configs import ( + Config, + IHTaskConfig, + LMTaskConfig, + ResidMLPTaskConfig, + SamplingType, + TaskConfig, + TMSTaskConfig, +) from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo from spd.models.components import ( @@ -38,10 +48,10 @@ class SPDRunInfo(RunInfo[Config]): checkpoint_prefix = "model" -class OutputWithCache(NamedTuple): +class OutputWithCache[OutputT](NamedTuple): """Output tensor and cached activations.""" - output: Tensor + output: OutputT cache: dict[str, Tensor] @@ -52,7 +62,15 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class ComponentModel(LoadableModule): +class RunBatch[BatchT, OutputT](Protocol): + def __call__(self, target_model: nn.Module, batch: BatchT) -> OutputT: ... + + +class ReconstructionLoss[OutputT](Protocol): + def __call__(self, pred: OutputT, target: OutputT) -> tuple[Float[Tensor, ""], int]: ... + + +class ComponentModel[BatchT, OutputT](LoadableModule, ABC): """Wrapper around an arbitrary pytorch model for running SPD. The underlying *base model* can be any subclass of `nn.Module` (e.g. @@ -78,7 +96,8 @@ def __init__( ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, - pretrained_model_output_attr: str | None, + run_batch: RunBatch[BatchT, OutputT], + reconstruction_loss: ReconstructionLoss[OutputT], ): super().__init__() @@ -89,7 +108,6 @@ def __init__( ) self.target_model = target_model - self.pretrained_model_output_attr = pretrained_model_output_attr self.module_to_c = {info.module_path: info.C for info in module_path_info} self.target_module_paths = list(self.module_to_c.keys()) @@ -119,6 +137,9 @@ def __init__( self.lower_leaky_fn = SIGMOID_TYPES[sigmoid_type] self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] + self._run_batch = run_batch + self.reconstruction_loss = reconstruction_loss + def target_weight(self, module_name: str) -> Float[Tensor, "rows cols"]: target_module = self.target_model.get_submodule(module_name) @@ -234,67 +255,41 @@ def _create_ci_fns( ) return ci_fns - def _extract_output(self, raw_output: Any) -> Tensor: - """Extract the desired output from the model's raw output. - - If pretrained_model_output_attr is None, returns the raw output directly. - If pretrained_model_output_attr starts with "idx_", returns the index specified by the - second part of the string. E.g. "idx_0" returns the first element of the raw output. - Otherwise, returns the specified attribute from the raw output. - - Args: - raw_output: The raw output from the model. - - Returns: - The extracted output. - """ - if self.pretrained_model_output_attr is None: - out = raw_output - elif self.pretrained_model_output_attr.startswith("idx_"): - idx_val = int(self.pretrained_model_output_attr.split("_")[1]) - assert isinstance(raw_output, Sequence), ( - f"raw_output must be a sequence, not {type(raw_output)}" - ) - assert idx_val < len(raw_output), ( - f"Index {idx_val} out of range for raw_output of length {len(raw_output)}" - ) - out = raw_output[idx_val] - else: - out = getattr(raw_output, self.pretrained_model_output_attr) - - assert isinstance(out, Tensor), f"Expected tensor output, got {type(out)}" - return out + @overload + def __call__( + self, + batch: BatchT, + cache_type: Literal["component_acts"], + mask_infos: dict[str, ComponentsMaskInfo] | None = None, + ) -> OutputWithCache[OutputT]: ... @overload def __call__( self, - *args: Any, + batch: BatchT, + cache_type: Literal["input"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, - cache_type: Literal["component_acts", "input"], - **kwargs: Any, - ) -> OutputWithCache: ... + ) -> OutputWithCache[OutputT]: ... @overload def __call__( self, - *args: Any, + batch: BatchT, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["none"] = "none", - **kwargs: Any, - ) -> Tensor: ... + ) -> OutputT: ... @override - def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: + def __call__(self, *args: Any, **kwargs: Any) -> OutputT | OutputWithCache[OutputT]: return super().__call__(*args, **kwargs) @override def forward( self, - *args: Any, + batch: BatchT, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["component_acts", "input", "none"] = "none", - **kwargs: Any, - ) -> Tensor | OutputWithCache: + ) -> OutputT | OutputWithCache[OutputT]: """Forward pass with optional component replacement and/or input caching. This method handles the following 4 cases: @@ -319,7 +314,7 @@ def forward( """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. - return self._extract_output(self.target_model(*args, **kwargs)) + return self._run_batch(self.target_model, batch) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -340,9 +335,8 @@ def forward( ) with self._attach_forward_hooks(hooks): - raw_out = self.target_model(*args, **kwargs) + out: OutputT = self._run_batch(self.target_model, batch) - out = self._extract_output(raw_out) match cache_type: case "input" | "component_acts": return OutputWithCache(output=out, cache=cache) @@ -426,7 +420,7 @@ def _attach_forward_hooks(self, hooks: dict[str, Callable[..., Any]]) -> Generat @classmethod @override - def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": + def from_run_info(cls, run_info: RunInfo[Config]) -> Self: """Load a trained ComponentModel checkpoint from a run info object.""" config = run_info.config @@ -456,12 +450,16 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": module_path_info = expand_module_patterns(target_model, config.all_module_info) - comp_model = ComponentModel( + run_batch = get_run_batch(config.task_config, config.pretrained_model_output_attr) + reconstruction_loss = get_reconstruction_loss(config.task_config) + + comp_model = cls( target_model=target_model, module_path_info=module_path_info, + run_batch=run_batch, + reconstruction_loss=reconstruction_loss, ci_fn_hidden_dims=config.ci_fn_hidden_dims, ci_fn_type=config.ci_fn_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, ) @@ -476,7 +474,7 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": @classmethod @override - def from_pretrained(cls, path: ModelPath) -> "ComponentModel": + def from_pretrained(cls, path: ModelPath) -> Self: """Load a trained ComponentModel checkpoint from a local or wandb path.""" run_info = SPDRunInfo.from_path(path) return cls.from_run_info(run_info) @@ -593,3 +591,78 @@ def handle_deprecated_state_dict_keys_(state_dict: dict[str, Tensor]) -> None: # replace if modified if new_key != key: state_dict[new_key] = state_dict.pop(key) + + +def pass_first_tuple_element_to_model[BatchT: tuple[Any, ...], OutputT]( + target_model: nn.Module, + batch: BatchT, # pyright: ignore[reportInvalidTypeVarUse] +) -> OutputT: # pyright: ignore[reportInvalidTypeVarUse] + return target_model(batch[0]) + + +def pass_batch_directly_to_model[BatchT, OutputT]( + target_model: nn.Module, + batch: BatchT, # pyright: ignore[reportInvalidTypeVarUse] +) -> OutputT: # pyright: ignore[reportInvalidTypeVarUse] + return target_model(batch) + + +def run_batch_extract_idx(idx: int, target_model: nn.Module, batch: Any) -> Any: + return target_model(batch)[idx] + + +def run_batch_extract_attr(attr: str, target_model: nn.Module, batch: Any) -> Any: + return getattr(target_model(batch), attr) + + +def make_run_batch_lm(output_attr: str | None) -> RunBatch[Any, Any]: + if output_attr is None: + return pass_batch_directly_to_model + if output_attr.startswith("idx_"): + idx = int(output_attr.removeprefix("idx_")) + return partial(run_batch_extract_idx, idx) + return partial(run_batch_extract_attr, output_attr) + + +def get_run_batch(task_config: TaskConfig, output_attr: str | None = None) -> RunBatch[Any, Any]: + match task_config: + case TMSTaskConfig() | ResidMLPTaskConfig(): + assert output_attr is None, ( + "output_attr not supported for TMSTaskConfig and ResidMLPTaskConfig" + ) + return pass_first_tuple_element_to_model + case LMTaskConfig() | IHTaskConfig(): + return make_run_batch_lm(output_attr) + + +# the following recon loss functions should return pre-mean values + + +def recon_loss_mse( + pred: Float[Tensor, "... d"], + target: Float[Tensor, "... d"], +) -> tuple[Float[Tensor, ""], int]: + assert pred.shape == target.shape + squared_errors = (pred - target) ** 2 + return squared_errors.sum(), pred.numel() + + +def recon_loss_kl( + pred: Float[Tensor, "... vocab"], + target: Float[Tensor, "... vocab"], +) -> tuple[Float[Tensor, ""], int]: + assert pred.shape == target.shape + log_q = torch.log_softmax(pred, dim=-1) # log Q + p = torch.softmax(target, dim=-1) # P + kl_per_position = F.kl_div(log_q, p, reduction="none").sum(dim=-1) # P · (log P − log Q) + return kl_per_position.sum(), pred.numel() + + +def get_reconstruction_loss( + task_config: TaskConfig, +) -> ReconstructionLoss[Any]: + match task_config: + case TMSTaskConfig() | ResidMLPTaskConfig(): + return recon_loss_mse + case LMTaskConfig() | IHTaskConfig(): + return recon_loss_kl diff --git a/spd/plotting.py b/spd/plotting.py index 81c9c1d5d..e66c391d4 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -1,6 +1,7 @@ import fnmatch import io from collections.abc import Callable +from typing import Any import numpy as np import torch @@ -182,7 +183,7 @@ def plot_mean_component_cis_both_scales( def get_single_feature_causal_importances( - model: ComponentModel, + model: ComponentModel[Any, Any], batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, @@ -216,7 +217,7 @@ def get_single_feature_causal_importances( def plot_causal_importance_vals( - model: ComponentModel, + model: ComponentModel[Any, Any], batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, diff --git a/spd/run_spd.py b/spd/run_spd.py index 82245a126..d045e3004 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -4,16 +4,14 @@ from collections import defaultdict from collections.abc import Iterator from pathlib import Path -from typing import cast +from typing import Any, cast import torch import torch.nn as nn import torch.nn.parallel -import torch.optim as optim import wandb -from jaxtyping import Float, Int from PIL import Image -from torch import Tensor +from torch import optim from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from tqdm import tqdm @@ -32,8 +30,7 @@ from spd.log import logger from spd.losses import compute_total_loss from spd.metrics import faithfulness_loss -from spd.metrics.alive_components import AliveComponentsTracker -from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.component_model import ComponentModel, OutputWithCache, ReconstructionLoss, RunBatch from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( avg_metrics_across_ranks, @@ -41,11 +38,7 @@ is_main_process, sync_across_processes, ) -from spd.utils.general_utils import ( - dict_safe_update_, - extract_batch_data, - get_scheduled_value, -) +from spd.utils.general_utils import dict_safe_update_, get_scheduled_value from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns, replace_std_values_in_layernorm from spd.utils.run_utils import save_file @@ -53,7 +46,7 @@ def run_faithfulness_warmup( - component_model: ComponentModel, + component_model: ComponentModel[Any, Any], component_params: list[torch.nn.Parameter], config: Config, ) -> None: @@ -111,15 +104,15 @@ def get_unique_metric_configs( return eval_metric_configs -def optimize( +def optimize[BatchT, OutputT]( target_model: nn.Module, config: Config, device: str, - train_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], - eval_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + train_loader: DataLoader[BatchT], + eval_loader: DataLoader[BatchT], n_eval_steps: int, + run_batch: RunBatch[BatchT, OutputT], + reconstruction_loss: ReconstructionLoss[OutputT], out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, ln_stds: dict[str, float] | None = None, @@ -129,9 +122,7 @@ def optimize( train_iterator = loop_dataloader(train_loader) eval_iterator = loop_dataloader(eval_loader) - def create_pgd_data_iter() -> ( - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]] - ): + def create_pgd_data_iter() -> Iterator[BatchT]: assert hasattr(train_loader, "generator") and train_loader.generator is not None train_loader.generator.manual_seed(config.seed) return iter(train_loader) @@ -154,8 +145,9 @@ def create_pgd_data_iter() -> ( module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, + run_batch=run_batch, + reconstruction_loss=reconstruction_loss, ) if ln_stds is not None: @@ -166,6 +158,8 @@ def create_pgd_data_iter() -> ( # Wrap model with DDP if distributed dist_state = get_distributed_state() wrapped_model: nn.Module = model + + component_model: ComponentModel[BatchT, OutputT] if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank @@ -222,22 +216,6 @@ def create_pgd_data_iter() -> ( eval_metric_configs = [ cfg for cfg in eval_metric_configs if cfg not in multibatch_pgd_eval_configs ] - batch_dims: tuple[int, ...] | None = None - - # Track which components are alive based on firing frequency - sample_batch = extract_batch_data(next(train_iterator)) - batch_dims = ( - sample_batch.shape[:-1] - if config.output_loss_type == "mse" # if mse then input is a vector - else sample_batch.shape # else it's a batch of token ids - ) - alive_tracker = AliveComponentsTracker( - module_to_c=model.module_to_c, - device=device, - n_examples_until_dead=config.n_examples_until_dead, - ci_alive_threshold=config.ci_alive_threshold, - global_n_examples_per_batch=batch_dims.numel(), - ) for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): optimizer.zero_grad() @@ -253,12 +231,14 @@ def create_pgd_data_iter() -> ( microbatch_log_data: defaultdict[str, float] = defaultdict(float) for _ in range(config.gradient_accumulation_steps): - microbatch = extract_batch_data(next(train_iterator)).to(device) + microbatch = next(train_iterator) # NOTE: we need to call the wrapped_model at least once each step in order to setup # the DDP gradient syncing for all parameters in the component model. Gradients will # sync regardless of whether the parameters are used in this call to wrapped_model. - target_model_output: OutputWithCache = wrapped_model(microbatch, cache_type="input") + target_model_output: OutputWithCache[OutputT] = wrapped_model( + microbatch, cache_type="input" + ) ci = component_model.calc_causal_importances( pre_weight_acts=target_model_output.cache, @@ -266,8 +246,6 @@ def create_pgd_data_iter() -> ( sampling=config.sampling, ) - alive_tracker.update(ci=ci.lower_leaky) - microbatch_total_loss, microbatch_loss_terms = compute_total_loss( loss_metric_configs=config.loss_metric_configs, model=component_model, @@ -280,7 +258,6 @@ def create_pgd_data_iter() -> ( sampling=config.sampling, use_delta_component=config.use_delta_component, n_mask_samples=config.n_mask_samples, - output_loss_type=config.output_loss_type, ) microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() @@ -300,13 +277,6 @@ def create_pgd_data_iter() -> ( avg_metrics = avg_metrics_across_ranks(microbatch_log_data, device=device) microbatch_log_data = cast(defaultdict[str, float], avg_metrics) - alive_counts = alive_tracker.compute() - for target_module_path, n_alive_count in alive_counts.items(): - n_alive_key = ( - f"train/n_alive/t{alive_tracker.ci_alive_threshold}_{target_module_path}" - ) - microbatch_log_data[n_alive_key] = n_alive_count - grad_norms = get_grad_norms_dict(component_model, device) dict_safe_update_( microbatch_log_data, {f"train/grad_norms/{k}": v for k, v in grad_norms.items()} @@ -333,13 +303,11 @@ def create_pgd_data_iter() -> ( else step % config.slow_eval_freq == 0 ) - assert batch_dims is not None, "batch_dims is not set" multibatch_pgd_metrics = evaluate_multibatch_pgd( multibatch_pgd_eval_configs=multibatch_pgd_eval_configs, model=component_model, create_data_iter=create_pgd_data_iter, config=config, - batch_dims=batch_dims, device=device, ) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 4bfee9c9a..f41a56783 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -25,7 +25,7 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data, get_obj_device +from spd.utils.general_utils import get_obj_device from spd.utils.run_utils import save_file @@ -79,7 +79,7 @@ def __init__(self, config: CompareModelsConfig): config.reference_model_path ) - def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Config]: + def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any, Any], Config]: """Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) model = ComponentModel.from_run_info(run_info) @@ -233,7 +233,7 @@ def _create_ih_data_loader(self) -> Iterator[Any]: ) def compute_activation_densities( - self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int + self, model: ComponentModel[Any, Any], eval_iterator: Iterator[Any], n_steps: int ) -> dict[str, Float[Tensor, " C"]]: """Compute activation densities using same logic as ComponentActivationDensity.""" @@ -250,8 +250,7 @@ def compute_activation_densities( model.eval() with torch.no_grad(): for _step in range(n_steps): - batch = extract_batch_data(next(eval_iterator)) - batch = batch.to(self.device) + batch = next(eval_iterator)["input_ids"].to(self.device) pre_weight_acts = model(batch, cache_type="input").cache ci = model.calc_causal_importances( diff --git a/spd/utils/general_utils.py b/spd/utils/general_utils.py index e9a3fc08d..358def97e 100644 --- a/spd/utils/general_utils.py +++ b/spd/utils/general_utils.py @@ -2,7 +2,7 @@ import random from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, Protocol +from typing import Any, Protocol import einops import numpy as np @@ -162,49 +162,9 @@ def resolve_class(path: str) -> type[nn.Module]: return getattr(module, class_name) -def extract_batch_data( - batch_item: dict[str, Any] | tuple[Tensor, "..."] | Tensor, - input_key: str = "input_ids", -) -> Tensor: - """Extract input data from various batch formats. - - This utility function handles different batch formats commonly used across the codebase: - 1. Dictionary format: {"input_ids": Tensor, "..."} - common in LM tasks - 2. Tuple format: (input_tensor, labels) - common in SPD optimization - 3. Direct tensor: when batch is already the input tensor - - Args: - batch_item: The batch item from a data loader - input_key: Key to use for dictionary format (default: "input_ids") - - Returns: - The input tensor extracted from the batch - """ - assert isinstance(batch_item, dict | tuple | Tensor), ( - f"Unsupported batch format: {type(batch_item)}. Must be a dictionary, tuple, or tensor." - ) - if isinstance(batch_item, dict): - # Dictionary format: extract the specified key - if input_key not in batch_item: - available_keys = list(batch_item.keys()) - raise KeyError( - f"Key '{input_key}' not found in batch. Available keys: {available_keys}" - ) - tensor = batch_item[input_key] - elif isinstance(batch_item, tuple): - # Assume input is the first element - tensor = batch_item[0] - else: - # Direct tensor format - tensor = batch_item - - return tensor - - def calc_kl_divergence_lm( pred: Float[Tensor, "... vocab"], target: Float[Tensor, "... vocab"], - reduce: bool = True, ) -> Float[Tensor, ""] | Float[Tensor, "..."]: """Calculate the KL divergence between two logits. @@ -221,24 +181,7 @@ def calc_kl_divergence_lm( p = torch.softmax(target, dim=-1) # P kl_raw = F.kl_div(log_q, p, reduction="none") # P · (log P − log Q) kl = kl_raw.sum(dim=-1) - if reduce: - return kl.mean() # Σ_vocab / (batch·seq) - else: - return kl - - -def calc_sum_recon_loss_lm( - pred: Float[Tensor, "... vocab"], - target: Float[Tensor, "... vocab"], - loss_type: Literal["mse", "kl"], -) -> Float[Tensor, ""]: - """Calculate the reconstruction loss for a language model without reduction.""" - match loss_type: - case "mse": - loss = ((pred - target) ** 2).sum() - case "kl": - loss = calc_kl_divergence_lm(pred=pred, target=target, reduce=False).sum() - return loss + return kl.mean() # Σ_vocab / (batch·seq) def runtime_cast[T](type_: type[T], obj: Any) -> T: diff --git a/spd/utils/logging_utils.py b/spd/utils/logging_utils.py index fd39afeeb..d006f082d 100644 --- a/spd/utils/logging_utils.py +++ b/spd/utils/logging_utils.py @@ -40,7 +40,7 @@ def local_log(data: dict[str, Any], step: int, out_dir: Path) -> None: def get_grad_norms_dict( - component_model: ComponentModel, device: torch.device | str + component_model: ComponentModel[Any, Any], device: torch.device | str ) -> dict[str, float]: """Create a dictionary of gradient norms for the parameters of a component model.""" diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 4647f6a14..7a402ea54 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -22,7 +22,7 @@ from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, RunState, StateManager from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig -from spd.models.component_model import ComponentModel +from spd.models.component_model import ComponentModel, make_run_batch_lm, recon_loss_kl from spd.utils.module_utils import expand_module_patterns DEVICE = "cpu" @@ -117,8 +117,9 @@ def app_with_state(): module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_kl, ) model.eval() sources_by_target = get_sources_by_target( diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index fa32cc1e3..b10e2a405 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -1,16 +1,20 @@ """Shared test fixtures for loss function tests.""" -from typing import override +from typing import Any, override import torch import torch.nn as nn from jaxtyping import Float from torch import Tensor -from spd.models.component_model import ComponentModel +from spd.models.component_model import ComponentModel, recon_loss_mse from spd.utils.module_utils import ModulePathInfo +def _test_run_batch(target_model: nn.Module, batch: Tensor) -> Tensor: + return target_model(batch) + + class OneLayerLinearModel(nn.Module): """One-layer linear model for testing.""" @@ -38,7 +42,7 @@ def forward(self, x: Tensor) -> Tensor: return x -def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: +def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any, Any]: """Create a ComponentModel with a single linear layer for testing. Args: @@ -58,8 +62,9 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=_test_run_batch, + reconstruction_loss=recon_loss_mse, ) return comp_model @@ -67,7 +72,7 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo def make_two_layer_component_model( weight1: Float[Tensor, " d_hidden d_in"], weight2: Float[Tensor, " d_out d_hidden"] -) -> ComponentModel: +) -> ComponentModel[Any, Any]: """Create a ComponentModel with two linear layers for testing. Args: @@ -95,8 +100,9 @@ def make_two_layer_component_model( ], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=_test_run_batch, + reconstruction_loss=recon_loss_mse, ) return comp_model diff --git a/tests/metrics/test_alive_components.py b/tests/metrics/test_alive_components.py deleted file mode 100644 index 009334063..000000000 --- a/tests/metrics/test_alive_components.py +++ /dev/null @@ -1,174 +0,0 @@ -"""Tests for AliveComponentsTracker metric (single-rank).""" - -import torch - -from spd.metrics.alive_components import AliveComponentsTracker - - -def test_initialization(): - """Test that AliveComponentsTracker initializes correctly.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 5, "layer2": 5}, - device="cpu", - n_examples_until_dead=100, - ci_alive_threshold=0.1, - global_n_examples_per_batch=2, - ) - - assert metric.n_examples_until_dead == 100 - assert metric.ci_alive_threshold == 0.1 - assert metric.n_batches_until_dead == 50 - assert "layer1" in metric.n_batches_since_fired - assert "layer2" in metric.n_batches_since_fired - assert metric.n_batches_since_fired["layer1"].shape == (5,) - assert metric.n_batches_since_fired["layer2"].shape == (5,) - - -def test_update_counter_mechanics(): - """Test that firing resets counter to 0 and non-firing increments by 1.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=1, - ) - - # Component 0 fires, components 1 and 2 don't - ci = {"layer1": torch.tensor([0.2, 0.05, 0.08])} - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 0 # fired - assert metric.n_batches_since_fired["layer1"][1] == 1 # didn't fire - assert metric.n_batches_since_fired["layer1"][2] == 1 # didn't fire - - # No components fire - all should increment - ci = {"layer1": torch.tensor([0.05, 0.08, 0.09])} - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 1 - assert metric.n_batches_since_fired["layer1"][1] == 2 - assert metric.n_batches_since_fired["layer1"][2] == 2 - - # Component 1 fires - ci = {"layer1": torch.tensor([0.05, 0.15, 0.09])} - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 2 - assert metric.n_batches_since_fired["layer1"][1] == 0 # reset - assert metric.n_batches_since_fired["layer1"][2] == 3 - - -def test_update_with_multidimensional_input(): - """Test that firing detection works with batch dimensions.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=6, - ) - - # Shape: (batch=2, seq=3, C=3) - # Component 0: fires in batch 0, token 0 - # Component 1: fires in batch 1, token 2 - # Component 2: never fires - ci = { - "layer1": torch.tensor( - [ - [[0.2, 0.05, 0.08], [0.05, 0.08, 0.09], [0.05, 0.08, 0.09]], # batch 0 - [[0.05, 0.08, 0.09], [0.05, 0.08, 0.09], [0.05, 0.12, 0.09]], # batch 1 - ] - ) - } - - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 0 # fired - assert metric.n_batches_since_fired["layer1"][1] == 0 # fired - assert metric.n_batches_since_fired["layer1"][2] == 1 # didn't fire - - -def test_compute_alive_counts(): - """Test that compute() correctly counts alive components.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 4, "layer2": 4}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=10, - ) - - # n_batches_until_dead = 50 // 10 = 5 - # Manually set counter values - metric.n_batches_since_fired["layer1"] = torch.tensor([0, 3, 5, 10]) - metric.n_batches_since_fired["layer2"] = torch.tensor([4, 4, 6, 0]) - - result = metric.compute() - - # layer1: components 0, 1 are alive (< 5) - assert result["layer1"] == 2 - # layer2: components 0, 1, 3 are alive (< 5) - assert result["layer2"] == 3 - - -def test_multiple_modules(): - """Test tracking across multiple modules.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 3, "layer2": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=1, - ) - - ci = { - "layer1": torch.tensor([0.2, 0.05, 0.08]), # component 0 fires - "layer2": torch.tensor([0.05, 0.12, 0.15]), # components 1, 2 fire - } - - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 0 - assert metric.n_batches_since_fired["layer1"][1] == 1 - assert metric.n_batches_since_fired["layer1"][2] == 1 - - assert metric.n_batches_since_fired["layer2"][0] == 1 - assert metric.n_batches_since_fired["layer2"][1] == 0 - assert metric.n_batches_since_fired["layer2"][2] == 0 - - -def test_boundary_conditions(): - """Test boundary conditions for alive/dead determination.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=10, - ) - # n_batches_until_dead = 50 // 10 = 5 - # Test boundary: 4 < 5 (alive), 5 >= 5 (dead) - metric.n_batches_since_fired["layer1"] = torch.tensor([4, 5, 6]) - - result = metric.compute() - assert result["layer1"] == 1 # only component 0 - - -def test_threshold_boundary(): - """Test that the CI threshold is applied correctly.""" - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=10, - ) - - # Test boundary: 0.1 > 0.1 is False, so exactly 0.1 doesn't count as firing - ci = {"layer1": torch.tensor([0.09, 0.1, 0.11])} - metric.update(ci=ci) - - assert metric.n_batches_since_fired["layer1"][0] == 1 # 0.09 not > 0.1 - assert metric.n_batches_since_fired["layer1"][1] == 1 # 0.1 not > 0.1 - assert metric.n_batches_since_fired["layer1"][2] == 0 # 0.11 > 0.1 diff --git a/tests/metrics/test_alive_components_distributed.py b/tests/metrics/test_alive_components_distributed.py deleted file mode 100644 index 29c1c67d2..000000000 --- a/tests/metrics/test_alive_components_distributed.py +++ /dev/null @@ -1,257 +0,0 @@ -"""Distributed tests for AliveComponentsTracker metric. - -Run with: torchrun --standalone --nproc_per_node=2 --master_port=29504 tests/metrics/test_alive_components_distributed.py -Or via pytest (slower): pytest tests/metrics/test_alive_components_distributed.py --runslow -""" - -import os -import subprocess -import sys -from pathlib import Path - -import pytest -import torch - -from spd.metrics.alive_components import AliveComponentsTracker -from spd.utils.distributed_utils import ( - cleanup_distributed, - get_distributed_state, - init_distributed, - sync_across_processes, - with_distributed_cleanup, -) - - -def _test_min_reduction(): - """Test that compute() uses min reduction correctly.""" - dist_state = get_distributed_state() - assert dist_state is not None - - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=100, - ci_alive_threshold=0.1, - global_n_examples_per_batch=2, - ) - - # Initialize n_batches_until_dead by calling update once - # CI shape (3,) = 1 example per rank * 2 ranks = 2 global examples - # n_batches_until_dead = 100 // 2 = 50 - metric.update(ci={"layer1": torch.tensor([0.0, 0.0, 0.0])}) - - # Set different counter values on each rank - if dist_state.rank == 0: - metric.n_batches_since_fired["layer1"] = torch.tensor([5, 2, 8]) - else: - metric.n_batches_since_fired["layer1"] = torch.tensor([3, 4, 1]) - - # compute() will sync and apply min reduction - # After min reduction: min(5,3)=3 < 50, min(2,4)=2 < 50, min(8,1)=1 < 50 - # All components should be alive - result = metric.compute() - assert result["layer1"] == 3 - - if dist_state.rank == 0: - print("✓ Min reduction test passed") - - -def _test_different_firing_patterns(): - """Test that components firing on any rank are considered alive globally.""" - dist_state = get_distributed_state() - assert dist_state is not None - - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=2, - ) - - # Run 3 batches with different firing on each rank - for _ in range(3): - if dist_state.rank == 0: - # Rank 0: only component 0 fires - ci = {"layer1": torch.tensor([0.2, 0.0, 0.0])} - else: - # Rank 1: only component 1 fires - ci = {"layer1": torch.tensor([0.0, 0.2, 0.0])} - metric.update(ci=ci) - - # Before compute: each rank has different local state - if dist_state.rank == 0: - assert metric.n_batches_since_fired["layer1"][0] == 0 # fired locally - assert metric.n_batches_since_fired["layer1"][1] == 3 # didn't fire locally - assert metric.n_batches_since_fired["layer1"][2] == 3 # didn't fire locally - else: - assert metric.n_batches_since_fired["layer1"][0] == 3 # didn't fire locally - assert metric.n_batches_since_fired["layer1"][1] == 0 # fired locally - assert metric.n_batches_since_fired["layer1"][2] == 3 # didn't fire locally - - # compute() will sync with min reduction: - # Component 0: min(0, 3) = 0 (fired on rank 0) - # Component 1: min(3, 0) = 0 (fired on rank 1) - # Component 2: min(3, 3) = 3 (didn't fire on either) - # n_batches_until_dead = 50 // (1 * 2) = 25 - # All < 25, so all alive - result = metric.compute() - assert result["layer1"] == 3 # all components alive - - if dist_state.rank == 0: - print(f"✓ Different firing patterns test passed (n_alive={result['layer1']})") - - -def _test_dead_components(): - """Test that components are correctly marked as dead after threshold.""" - dist_state = get_distributed_state() - assert dist_state is not None - - metric = AliveComponentsTracker( - module_to_c={"layer1": 3}, - device="cpu", - n_examples_until_dead=5, - ci_alive_threshold=0.1, - global_n_examples_per_batch=2, - ) - - # Run batches where component 2 never fires - # CI shape is (3,) which is [C], so 1 local example * 2 ranks = 2 global examples per batch - # n_batches_until_dead = 5 // 2 = 2 - for _ in range(3): # Need 3 batches to exceed threshold (3*2=6 > 5) - if dist_state.rank == 0: - ci = {"layer1": torch.tensor([0.2, 0.0, 0.0])} - else: - ci = {"layer1": torch.tensor([0.0, 0.2, 0.0])} - metric.update(ci=ci) - - # compute() will sync with min reduction: - # Component 0: min(0, 3) = 0 (alive) - # Component 1: min(3, 0) = 0 (alive) - # Component 2: min(3, 3) = 3 (dead, >= 2) - print(f"Rank {dist_state.rank} n_batches_since_fired: {metric.n_batches_since_fired['layer1']}") - result = metric.compute() - # only components 0 and 1 alive - assert result["layer1"] == 2, f"Expected 2 alive components, got {result['layer1']}" - - if dist_state.rank == 0: - print(f"✓ Dead components test passed (n_alive={result['layer1']})") - - -def _test_multiple_modules(): - """Test tracking across multiple modules in distributed setting.""" - dist_state = get_distributed_state() - assert dist_state is not None - - metric = AliveComponentsTracker( - module_to_c={"layer1": 2, "layer2": 2}, - device="cpu", - n_examples_until_dead=50, - ci_alive_threshold=0.1, - global_n_examples_per_batch=2, - ) - - # Each rank fires different components in different modules - if dist_state.rank == 0: - ci = { - "layer1": torch.tensor([0.2, 0.0]), - "layer2": torch.tensor([0.0, 0.0]), - } - else: - ci = { - "layer1": torch.tensor([0.0, 0.0]), - "layer2": torch.tensor([0.0, 0.2]), - } - - metric.update(ci=ci) - - # compute() will sync with min reduction: - # layer1: min(0, 1) = 0, min(1, 1) = 1 - # layer2: min(1, 1) = 1, min(1, 0) = 0 - # n_batches_until_dead = 50 // (1 * 2) = 25 - # All < 25, so all alive - result = metric.compute() - assert result["layer1"] == 2 - assert result["layer2"] == 2 - - if dist_state.rank == 0: - print( - f"✓ Multiple modules test passed (layer1={result['layer1']}, layer2={result['layer2']})" - ) - - -@with_distributed_cleanup -def run_all_tests(): - """Run all distributed tests when called directly with torchrun.""" - init_distributed() - dist_state = get_distributed_state() - assert dist_state is not None - rank = dist_state.rank - world_size = dist_state.world_size - - if world_size != 2: - if rank == 0: - print(f"✗ Tests require exactly 2 ranks, got {world_size}") - cleanup_distributed() - sys.exit(1) - - tests = [ - ("Min reduction", _test_min_reduction), - ("Different firing patterns", _test_different_firing_patterns), - ("Dead components", _test_dead_components), - ("Multiple modules", _test_multiple_modules), - ] - - if rank == 0: - print(f"\nRunning {len(tests)} distributed AliveComponentsTracker tests...\n") - - for test_name, test_func in tests: - try: - test_func() - except Exception as e: - if rank == 0: - print(f"✗ {test_name} failed: {e}") - raise - # Barrier to ensure clean test separation - sync_across_processes() - - if rank == 0: - print(f"\n✓ All {len(tests)} distributed tests passed!\n") - - -# ===== Pytest wrapper ===== -@pytest.mark.slow -class TestDistributedAliveComponentsTracker: - """Pytest wrapper for distributed AliveComponentsTracker tests.""" - - def test_distributed_alive_components(self): - """Run distributed tests via torchrun in subprocess.""" - script_path = Path(__file__).resolve() - - # ports should be globally unique in tests to allow test parallelization - # see discussion at: https://github.com/goodfire-ai/spd/pull/186 - cmd = [ - "torchrun", - "--standalone", - "--nproc_per_node=2", - "--master_port", - "29504", - str(script_path), - ] - - # disable cuda so we run on cpu: - new_env = os.environ.copy() - new_env["CUDA_VISIBLE_DEVICES"] = "" - - result = subprocess.run(cmd, env=new_env, capture_output=True, text=True, timeout=120) - - if result.returncode != 0: - print(f"STDOUT:\n{result.stdout}") - print(f"STDERR:\n{result.stderr}") - raise RuntimeError(f"Distributed test failed with code {result.returncode}") - - print(result.stdout) - - -if __name__ == "__main__": - run_all_tests() diff --git a/tests/metrics/test_ci_masked_recon_layerwise_loss.py b/tests/metrics/test_ci_masked_recon_layerwise_loss.py index 00b8092b2..045b8aca0 100644 --- a/tests/metrics/test_ci_masked_recon_layerwise_loss.py +++ b/tests/metrics/test_ci_masked_recon_layerwise_loss.py @@ -43,7 +43,7 @@ def test_two_layer_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, batch=batch, target_out=target_out, ci=ci ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -59,11 +59,9 @@ def test_layerwise_vs_all_layer(self: object) -> None: target_out = torch.randn(1, 2, dtype=torch.float32) ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} - loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci - ) + loss_all = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, batch=batch, target_out=target_out, ci=ci ) # For single layer, results should be the same diff --git a/tests/metrics/test_ci_masked_recon_loss.py b/tests/metrics/test_ci_masked_recon_loss.py index 3f1202425..63a94d50e 100644 --- a/tests/metrics/test_ci_masked_recon_loss.py +++ b/tests/metrics/test_ci_masked_recon_loss.py @@ -25,9 +25,7 @@ def test_manual_calculation(self: object) -> None: expected_loss = torch.nn.functional.mse_loss(out, target_out) # Calculate actual loss - actual_loss = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci - ) + actual_loss = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( f"Expected {expected_loss}, got {actual_loss}" @@ -45,10 +43,10 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, batch=batch, target_out=target_out, ci=ci_full ) loss_half = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, batch=batch, target_out=target_out, ci=ci_half ) # Different CI values should produce different losses diff --git a/tests/metrics/test_ci_masked_recon_subset_loss.py b/tests/metrics/test_ci_masked_recon_subset_loss.py index 4a9f870b7..6adb7e836 100644 --- a/tests/metrics/test_ci_masked_recon_subset_loss.py +++ b/tests/metrics/test_ci_masked_recon_subset_loss.py @@ -77,7 +77,6 @@ def mock_sample_uniform_k_subset_routing_masks( for _ in range(2): actual_loss = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_faithfulness_loss.py b/tests/metrics/test_faithfulness_loss.py index b6036f72d..a2635c22d 100644 --- a/tests/metrics/test_faithfulness_loss.py +++ b/tests/metrics/test_faithfulness_loss.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from spd.metrics import faithfulness_loss @@ -5,7 +7,7 @@ from tests.metrics.fixtures import make_one_layer_component_model -def zero_out_components(model: ComponentModel) -> None: +def zero_out_components(model: ComponentModel[Any, Any]) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() diff --git a/tests/metrics/test_stochastic_recon_layerwise_loss.py b/tests/metrics/test_stochastic_recon_layerwise_loss.py index 3862d85f8..e22a61c5f 100644 --- a/tests/metrics/test_stochastic_recon_layerwise_loss.py +++ b/tests/metrics/test_stochastic_recon_layerwise_loss.py @@ -105,7 +105,6 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -130,7 +129,6 @@ def test_layerwise_vs_full_loss(self: object) -> None: model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -140,7 +138,6 @@ def test_layerwise_vs_full_loss(self: object) -> None: model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_stochastic_recon_loss.py b/tests/metrics/test_stochastic_recon_loss.py index 594b55a7f..1ee7b723c 100644 --- a/tests/metrics/test_stochastic_recon_loss.py +++ b/tests/metrics/test_stochastic_recon_loss.py @@ -78,7 +78,6 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_stochastic_recon_subset_loss.py b/tests/metrics/test_stochastic_recon_subset_loss.py index 484e3d49f..0d889448c 100644 --- a/tests/metrics/test_stochastic_recon_subset_loss.py +++ b/tests/metrics/test_stochastic_recon_subset_loss.py @@ -92,7 +92,6 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 7f89a7cea..2ee5e79f4 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -20,6 +20,8 @@ from spd.models.component_model import ( ComponentModel, SPDRunInfo, + pass_batch_directly_to_model, + recon_loss_mse, ) from spd.models.components import ( ComponentsMaskInfo, @@ -91,8 +93,9 @@ def test_correct_parameters_require_grad(): ], ci_fn_type="mlp", ci_fn_hidden_dims=[4], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) for module_path, components in component_model.components.items(): @@ -175,8 +178,9 @@ def test_from_run_info(): module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, - pretrained_model_output_attr=config.pretrained_model_output_attr, sigmoid_type=config.sigmoid_type, + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) save_file(cm.state_dict(), comp_model_dir / "model.pth") @@ -282,8 +286,9 @@ def test_full_weight_delta_matches_target_behaviour(): module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[4], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) token_ids = torch.randint( @@ -314,8 +319,9 @@ def test_input_cache_captures_pre_weight_input(): module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) # WHEN we forward the component model with input caching @@ -349,8 +355,9 @@ def test_weight_deltas(): module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) # THEN the weight deltas match the target weight @@ -384,8 +391,9 @@ def forward(self, x: Tensor) -> Tensor: module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) # WHEN we set the target model weights to be UV @@ -440,8 +448,9 @@ def forward(self, x: Tensor) -> Tensor: module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) # and a random input @@ -490,8 +499,9 @@ def forward(self, x: Tensor) -> Tensor: module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=pass_batch_directly_to_model, + reconstruction_loss=recon_loss_mse, ) # and a random input diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index dedf0ba58..86e772bad 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -16,6 +16,7 @@ ) from spd.data import DatasetConfig, create_data_loader from spd.identity_insertion import insert_identity_operations_ +from spd.models.component_model import make_run_batch_lm, recon_loss_kl from spd.run_spd import optimize from spd.utils.general_utils import resolve_class, set_seed @@ -151,6 +152,8 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 2e926eae1..f489cade8 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -17,6 +17,7 @@ from spd.experiments.ih.configs import InductionModelConfig from spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ +from spd.models.component_model import make_run_batch_lm, recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -135,6 +136,8 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index dfe771e00..935312519 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,6 +13,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ +from spd.models.component_model import make_run_batch_lm, recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -131,6 +132,8 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 69a546f6e..6631a56ca 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -1,4 +1,4 @@ -from typing import override +from typing import Any, override import torch import torch.nn as nn @@ -16,10 +16,14 @@ stochastic_recon_loss, stochastic_recon_subset_loss, ) -from spd.models.component_model import ComponentModel +from spd.models.component_model import ComponentModel, recon_loss_mse from spd.utils.module_utils import ModulePathInfo +def _test_run_batch(target_model: nn.Module, batch: Tensor) -> Tensor: + return target_model(batch) + + class TinyLinearModel(nn.Module): def __init__(self, d_in: int, d_out: int) -> None: super().__init__() @@ -30,7 +34,7 @@ def forward(self, x: Tensor) -> Tensor: return self.fc(x) -def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: +def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any, Any]: d_out, d_in = weight.shape target = TinyLinearModel(d_in=d_in, d_out=d_out) with torch.no_grad(): @@ -42,14 +46,15 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", - pretrained_model_output_attr=None, sigmoid_type="leaky_hard", + run_batch=_test_run_batch, + reconstruction_loss=recon_loss_mse, ) return comp_model -def _zero_components_for_test(model: ComponentModel) -> None: +def _zero_components_for_test(model: ComponentModel[Any, Any]) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() @@ -279,7 +284,6 @@ def test_mse_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -304,7 +308,6 @@ def test_kl_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="kl", batch=batch, target_out=target_out, ci=ci, @@ -324,10 +327,10 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, batch=batch, target_out=target_out, ci=ci_full ) loss_half = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, batch=batch, target_out=target_out, ci=ci_half ) # Different CI values should produce different losses @@ -346,7 +349,6 @@ def test_layerwise_basic(self: object) -> None: result = ci_masked_recon_layerwise_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -365,11 +367,9 @@ def test_layerwise_vs_all_layer(self: object) -> None: target_out = torch.tensor([[1.0, 2.0]], dtype=torch.float32) ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} - loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci - ) + loss_all = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, batch=batch, target_out=target_out, ci=ci ) # For single layer, results should be the same @@ -388,7 +388,6 @@ def test_subset_basic(self: object) -> None: result = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -411,7 +410,6 @@ def test_subset_stochastic_behavior(self: object) -> None: losses = [ ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -439,7 +437,6 @@ def test_continuous_sampling_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -462,7 +459,6 @@ def test_binomial_sampling_basic(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -487,7 +483,6 @@ def test_multiple_mask_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -509,7 +504,6 @@ def test_with_and_without_delta_component(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -520,7 +514,6 @@ def test_with_and_without_delta_component(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -547,7 +540,6 @@ def test_layerwise_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -571,7 +563,6 @@ def test_layerwise_multiple_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -595,7 +586,6 @@ def test_subset_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -619,7 +609,6 @@ def test_subset_with_binomial_sampling(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, @@ -644,7 +633,6 @@ def test_subset_stochastic_variability(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", batch=batch, target_out=target_out, ci=ci, diff --git a/tests/test_tms.py b/tests/test_tms.py index bbbcec4cb..f695bf7d4 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -18,6 +18,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ +from spd.models.component_model import make_run_batch_lm, recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -139,6 +140,8 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: train_loader=train_loader, eval_loader=eval_loader, n_eval_steps=config.n_eval_steps, + run_batch=make_run_batch_lm(config.pretrained_model_output_attr), + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights, ) diff --git a/tests/test_wandb_run_loading.py b/tests/test_wandb_run_loading.py index 55653b7c8..463ed1247 100644 --- a/tests/test_wandb_run_loading.py +++ b/tests/test_wandb_run_loading.py @@ -5,6 +5,8 @@ the canonical configs, and update the registry with your new run(s). """ +from typing import Any + import pytest from spd.models.component_model import ComponentModel, SPDRunInfo @@ -12,12 +14,12 @@ from spd.utils.wandb_utils import parse_wandb_run_path -def from_run_info(canonical_run: str) -> ComponentModel: +def from_run_info(canonical_run: str) -> ComponentModel[Any, Any]: run_info = SPDRunInfo.from_path(canonical_run) return ComponentModel.from_run_info(run_info) -def from_pretrained(canonical_run: str) -> ComponentModel: +def from_pretrained(canonical_run: str) -> ComponentModel[Any, Any]: return ComponentModel.from_pretrained(canonical_run) From bd03f577015e0c7a7c93e1b29f5421d17021d17f Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Thu, 29 Jan 2026 23:18:10 +0000 Subject: [PATCH 02/16] wip: Refactor ComponentModel loading to use task-specific factories --- .claude/.nfs582d9ab79662f72a00003620 | 23 ++ spd/app/backend/compute.py | 32 ++- spd/app/backend/optim_cis.py | 8 +- .../backend/routers/dataset_attributions.py | 2 +- spd/app/backend/routers/runs.py | 5 +- spd/app/backend/state.py | 3 +- spd/autointerp/interpret.py | 9 +- spd/clustering/activations.py | 2 +- spd/clustering/dataset.py | 6 +- spd/clustering/scripts/run_clustering.py | 5 +- spd/dataset_attributions/harvest.py | 8 +- spd/dataset_attributions/harvester.py | 6 +- spd/eval.py | 23 +- spd/experiments/ih/ih_decomposition.py | 4 +- spd/experiments/lm/lm_decomposition.py | 4 +- spd/experiments/lm/loaders.py | 114 ++++++++ spd/experiments/resid_mlp/models.py | 57 +++- .../resid_mlp/resid_mlp_decomposition.py | 4 +- spd/experiments/resid_mlp/resid_mlp_interp.py | 12 +- spd/experiments/tms/models.py | 54 +++- spd/experiments/tms/plotting.py | 5 +- spd/experiments/tms/tms_decomposition.py | 4 +- spd/harvest/harvest.py | 3 +- spd/losses.py | 12 + spd/metrics/ci_masked_recon_layerwise_loss.py | 9 +- spd/metrics/ci_masked_recon_loss.py | 9 +- spd/metrics/ci_masked_recon_subset_loss.py | 10 +- .../pgd_masked_recon_layerwise_loss.py | 8 + spd/metrics/pgd_masked_recon_loss.py | 6 + spd/metrics/pgd_masked_recon_subset_loss.py | 6 + spd/metrics/pgd_utils.py | 12 +- .../stochastic_recon_layerwise_loss.py | 9 +- spd/metrics/stochastic_recon_loss.py | 9 +- spd/metrics/stochastic_recon_subset_loss.py | 9 +- spd/metrics/unmasked_recon_loss.py | 9 +- spd/models/batch_and_loss_fns.py | 39 +++ spd/models/component_model.py | 190 ++----------- spd/run_spd.py | 24 +- spd/scripts/compare_models/compare_models.py | 4 +- spd/simple_trainer.py | 251 ++++++++++++++++++ spd/utils/wandb_utils.py | 1 + tests/app/test_server_api.py | 9 +- tests/metrics/fixtures.py | 10 +- .../test_ci_masked_recon_layerwise_loss.py | 21 +- tests/metrics/test_ci_masked_recon_loss.py | 21 +- .../test_ci_masked_recon_subset_loss.py | 2 + .../test_stochastic_recon_layerwise_loss.py | 4 + tests/metrics/test_stochastic_recon_loss.py | 2 + .../test_stochastic_recon_subset_loss.py | 2 + tests/test_component_model.py | 42 +-- tests/test_gpt2.py | 4 +- tests/test_ih_transformer.py | 4 +- tests/test_resid_mlp.py | 4 +- tests/test_spd_losses.py | 50 +++- tests/test_tms.py | 4 +- tests/test_wandb_run_loading.py | 34 +-- 56 files changed, 895 insertions(+), 328 deletions(-) create mode 100644 .claude/.nfs582d9ab79662f72a00003620 create mode 100644 spd/experiments/lm/loaders.py create mode 100644 spd/models/batch_and_loss_fns.py create mode 100644 spd/simple_trainer.py diff --git a/.claude/.nfs582d9ab79662f72a00003620 b/.claude/.nfs582d9ab79662f72a00003620 new file mode 100644 index 000000000..d67075d64 --- /dev/null +++ b/.claude/.nfs582d9ab79662f72a00003620 @@ -0,0 +1,23 @@ +{ + "permissions": { + "allow": [ + "Bash(source:*)", + "Bash(npm run check:*)", + "Bash(make check-app:*)", + "Bash(npm run lint:*)", + "Bash(git stash push:*)", + "Bash(grep:*)", + "Bash(npm run format:*)", + "Bash(npm run build:*)", + "Bash(npx eslint:*)", + "Bash(npx prettier:*)", + "Bash(git add:*)", + "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nAdd \"Use as Prompt\" popup for selected text in dataset explorer\n\n- Select text within story content to show floating popup\n- \"Use as Prompt\" button creates a custom prompt from selection\n- Text is cleaned: newlines → spaces, whitespace collapsed, trimmed\n- Shows hint when no run is loaded\n- Only triggers on .story-text elements \\(not headers/tags\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", + "Bash(git revert:*)", + "Bash(python:*)", + "Bash(make:*)", + "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nOptimize random sampling and hide zero occurrence badges\n\n- Use random indices instead of shuffling entire dataset \\(~100x faster\\)\n- Hide occurrence badge when count is 0 \\(for random samples\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", + "Bash(git commit:*)" + ] + } +} diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index bdea2163b..3b47d6203 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -7,7 +7,7 @@ from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass -from typing import Any, override +from typing import Any, cast, override import torch from jaxtyping import Bool, Float @@ -126,7 +126,7 @@ def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: def get_sources_by_target( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], device: str, sampling: SamplingType, ) -> dict[str, list[str]]: @@ -166,8 +166,9 @@ def wte_hook( wte_cache["wte_post_detach"] = output return output - assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" - wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + wte = getattr(model.target_model, "wte") + assert isinstance(wte, nn.Module), "wte is not a module" + wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) with torch.enable_grad(): comp_output_with_cache: OutputWithCache[Any] = model( @@ -192,7 +193,7 @@ def wte_hook( "mlp.c_fc", "mlp.down_proj", ] - n_blocks = get_model_n_blocks(model.target_model) + n_blocks = get_model_n_blocks(cast(nn.Module, model.target_model)) for i in range(n_blocks): layers.extend([f"h.{i}.{layer_name}" for layer_name in component_layer_names]) @@ -305,7 +306,7 @@ def _compute_edges_for_target( def compute_edges_from_ci( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], pre_weight_acts: dict[str, Float[Tensor, "1 seq d_in"]], @@ -342,8 +343,9 @@ def compute_edges_from_ci( # Setup wte hook and run forward pass for gradient computation wte_hook, wte_cache = _setup_wte_hook() - assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" - wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) + wte = getattr(model.target_model, "wte") + assert isinstance(wte, nn.Module), "wte is not a module" + wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) weight_deltas = model.calc_weight_deltas() weight_deltas_and_masks = { @@ -490,7 +492,7 @@ def filter_ci_to_included_nodes( def compute_prompt_attributions( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], output_prob_threshold: float, @@ -540,7 +542,7 @@ def compute_prompt_attributions( def compute_prompt_attributions_optimized( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, @@ -624,7 +626,7 @@ class CIOnlyResult: def compute_ci_only( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Float[Tensor, "1 seq"], sampling: SamplingType, ) -> CIOnlyResult: @@ -768,6 +770,12 @@ def get_model_n_blocks(model: nn.Module) -> int: from simple_stories_train.models.llama_simple_mlp import LlamaSimpleMLP from transformers.models.gpt2 import GPT2LMHeadModel + from spd.experiments.lm.loaders import LogitsOnlyWrapper + + # Unwrap LogitsOnlyWrapper if present + if isinstance(model, LogitsOnlyWrapper): + model = model.model + match model: case GPT2LMHeadModel(): return len(model.transformer.h) @@ -788,7 +796,7 @@ class InterventionResult: def compute_intervention_forward( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] top_k: int, diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 3bacf0aef..77616a282 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -72,7 +72,7 @@ class OptimizableCIParams: ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values alive_info: AliveComponentInfo - def create_ci_outputs(self, model: ComponentModel[Any, Any], device: str) -> CIOutputs: + def create_ci_outputs(self, model: ComponentModel[Tensor, Tensor], device: str) -> CIOutputs: """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" pre_sigmoid: dict[str, Tensor] = {} @@ -139,7 +139,7 @@ def create_optimizable_ci_params( def compute_label_prob( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Tensor, ci_lower_leaky: dict[str, Tensor], label_token: int, @@ -165,7 +165,7 @@ def compute_l0_stats( def compute_final_token_ce_kl( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], batch: Tensor, target_out: Tensor, ci: dict[str, Tensor], @@ -267,7 +267,7 @@ class OptimCIConfig: def optimize_ci_values( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], tokens: Tensor, config: OptimCIConfig, device: str, diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index fa38f5146..06bc55725 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -85,7 +85,7 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: """Get the unembedding matrix from the loaded model.""" - lm_head = loaded.model.target_model.lm_head + lm_head = getattr(loaded.model.target_model, "lm_head") assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" return lm_head.weight.T.detach() diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index c62df6688..a7ce9a8b3 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -14,8 +14,9 @@ from spd.app.backend.dependencies import DepStateManager from spd.app.backend.state import HarvestCache, RunState from spd.app.backend.utils import build_token_lookup, log_errors +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo from spd.utils.distributed_utils import get_device from spd.utils.wandb_utils import parse_wandb_run_path @@ -92,7 +93,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): # Load the model logger.info(f"[API] Loading model for run {run.id}: {run.wandb_path}") - model = ComponentModel.from_run_info(run_info) + model = load_lm_component_model_from_run_info(run_info) model = model.to(DEVICE) model.eval() diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index 1a86a06e0..ca178b688 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -8,6 +8,7 @@ from dataclasses import dataclass, field from typing import Any +from torch import Tensor from transformers.tokenization_utils_base import PreTrainedTokenizerBase from spd.app.backend.database import PromptAttrDB, Run @@ -110,7 +111,7 @@ class RunState: """Runtime state for a loaded run (model, tokenizer, etc.)""" run: Run - model: ComponentModel[Any, Any] + model: ComponentModel[Tensor, Tensor] tokenizer: PreTrainedTokenizerBase sources_by_target: dict[str, list[str]] config: Config diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index 5e38e4985..571ba3200 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -5,8 +5,10 @@ from dataclasses import asdict, dataclass from enum import StrEnum from pathlib import Path +from typing import cast import httpx +import torch.nn as nn from openrouter import OpenRouter from openrouter.components import JSONSchemaConfig, MessageTypedDict, ResponseFormatJSONSchema from openrouter.errors import ( @@ -26,12 +28,13 @@ from spd.autointerp.prompt_template import INTERPRETATION_SCHEMA, format_prompt_template from spd.autointerp.schemas import ArchitectureInfo, InterpretationResult from spd.configs import LMTaskConfig +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats from spd.harvest.harvest import HarvestResult from spd.harvest.schemas import ComponentData from spd.harvest.storage import TokenStatsStorage from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo # Retry config MAX_RETRIES = 8 @@ -334,8 +337,8 @@ async def process_one( def get_architecture_info(wandb_path: str) -> ArchitectureInfo: run_info = SPDRunInfo.from_path(wandb_path) - model = ComponentModel.from_run_info(run_info) - n_blocks = get_model_n_blocks(model.target_model) + model = load_lm_component_model_from_run_info(run_info) + n_blocks = get_model_n_blocks(cast(nn.Module, model.target_model)) config = run_info.config task_config = config.task_config assert isinstance(task_config, LMTaskConfig) diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 2999efe21..6c427d6b1 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -17,7 +17,7 @@ def component_activations( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], device: torch.device | str, batch: Int[Tensor, "batch_size n_ctx"], ) -> dict[str, ActivationsTensor]: diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c8e86f0fc..85da48838 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -8,8 +8,8 @@ from spd.clustering.consts import BatchTensor from spd.configs import LMTaskConfig, ResidMLPTaskConfig from spd.data import DatasetConfig, create_data_loader -from spd.experiments.resid_mlp.models import ResidMLP -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.experiments.resid_mlp.models import ResidMLP, load_resid_mlp_component_model_from_run_info +from spd.models.component_model import SPDRunInfo from spd.spd_types import TaskName @@ -106,7 +106,7 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config - component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path) + component_model = load_resid_mlp_component_model_from_run_info(spd_run) assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index f76e64d5a..e24c50827 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -49,8 +49,9 @@ from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase from spd.clustering.wandb_tensor_info import wandb_log_tensor +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo from spd.spd_types import TaskName from spd.utils.distributed_utils import get_device from spd.utils.general_utils import replace_pydantic_model @@ -298,7 +299,7 @@ def main(run_config: ClusteringRunConfig) -> Path: # 3. Load model logger.info("Loading model") - model = ComponentModel.from_run_info(spd_run).to(device) + model = load_lm_component_model_from_run_info(spd_run).to(device) # 4. Compute activations logger.info("Computing activations") diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 3318b71d7..a16be5897 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -15,7 +15,6 @@ import itertools from dataclasses import dataclass from pathlib import Path -from typing import Any import torch import tqdm @@ -27,6 +26,7 @@ from spd.dataset_attributions.harvester import AttributionHarvester from spd.dataset_attributions.loaders import get_attributions_dir from spd.dataset_attributions.storage import DatasetAttributionStorage +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.loaders import load_activation_contexts_summary from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -42,7 +42,7 @@ class DatasetAttributionConfig: ci_threshold: float -def _build_component_layer_keys(model: ComponentModel[Any, Any]) -> list[str]: +def _build_component_layer_keys(model: ComponentModel[Tensor, Tensor]) -> list[str]: """Build list of component layer keys in canonical order. Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. @@ -57,7 +57,7 @@ def _build_component_layer_keys(model: ComponentModel[Any, Any]) -> list[str]: def _build_alive_masks( - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], run_id: str, ci_threshold: float, n_components: int, @@ -140,7 +140,7 @@ def harvest_attributions( _, _, run_id = parse_wandb_run_path(config.wandb_path) run_info = SPDRunInfo.from_path(config.wandb_path) - model = ComponentModel.from_run_info(run_info).to(device) + model = load_lm_component_model_from_run_info(run_info).to(device) model.eval() spd_config = run_info.config diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 3c1d18a64..1b86cc300 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -45,7 +45,7 @@ class AttributionHarvester: def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Tensor, Tensor], sources_by_target: dict[str, list[str]], n_components: int, vocab_size: int, @@ -74,7 +74,7 @@ def __init__( # For output targets: store attributions to output residual dimensions assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = model.target_model.lm_head + lm_head = getattr(model.target_model, "lm_head") assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" self.d_model = lm_head.in_features self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) @@ -143,7 +143,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = self.model.target_model.wte + wte = getattr(self.model.target_model, "wte") assert isinstance(wte, nn.Module) h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) diff --git a/spd/eval.py b/spd/eval.py index 9184eff1e..daf281ef8 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -58,6 +58,7 @@ from spd.metrics.stochastic_recon_loss import StochasticReconLoss from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss from spd.metrics.uv_plots import UVPlots +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.routing import AllLayersRouter, get_subset_router from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed @@ -118,6 +119,7 @@ def init_metric[BatchT, OutputT]( model: ComponentModel[BatchT, OutputT], run_config: Config, device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Metric[BatchT, OutputT]: match cfg: case ImportanceMinimalityLossConfig(): @@ -157,16 +159,19 @@ def init_metric[BatchT, OutputT]( model=model, device=device, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLayerwiseLossConfig(): metric = CIMaskedReconLayerwiseLoss( model=model, device=device, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLossConfig(): metric = CIMaskedReconLoss( model=model, device=device, + reconstruction_loss=reconstruction_loss, ) case CIMeanPerComponentConfig(): metric = CIMeanPerComponent(model=model, device=device) @@ -195,6 +200,7 @@ def init_metric[BatchT, OutputT]( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLossConfig(): metric = StochasticReconLoss( @@ -203,6 +209,7 @@ def init_metric[BatchT, OutputT]( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetLossConfig(): metric = StochasticReconSubsetLoss( @@ -212,6 +219,7 @@ def init_metric[BatchT, OutputT]( use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLossConfig(): metric = PGDReconLoss( @@ -219,6 +227,7 @@ def init_metric[BatchT, OutputT]( device=device, use_delta_component=run_config.use_delta_component, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case PGDReconSubsetLossConfig(): metric = PGDReconSubsetLoss( @@ -227,6 +236,7 @@ def init_metric[BatchT, OutputT]( use_delta_component=run_config.use_delta_component, pgd_config=cfg, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLayerwiseLossConfig(): metric = PGDReconLayerwiseLoss( @@ -234,6 +244,7 @@ def init_metric[BatchT, OutputT]( device=device, use_delta_component=run_config.use_delta_component, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetCEAndKLConfig(): raise ValueError("fix this typing!") @@ -265,6 +276,7 @@ def init_metric[BatchT, OutputT]( metric = UnmaskedReconLoss( model=model, device=device, + reconstruction_loss=reconstruction_loss, ) case _: @@ -283,12 +295,19 @@ def evaluate[BatchT, OutputT]( slow_step: bool, n_eval_steps: int, current_frac_of_training: float, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" metrics: list[Metric[BatchT, OutputT]] = [] for cfg in eval_metric_configs: - metric = init_metric(cfg=cfg, model=model, run_config=run_config, device=device) + metric = init_metric( + cfg=cfg, + model=model, + run_config=run_config, + device=device, + reconstruction_loss=reconstruction_loss, + ) if metric.slow and not slow_step: continue metrics.append(metric) @@ -337,6 +356,7 @@ def evaluate_multibatch_pgd[BatchT, OutputT]( create_data_iter: CreateDataIter[BatchT], config: Config, device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> dict[str, float]: """Calculate multibatch PGD metrics.""" weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None @@ -362,5 +382,6 @@ def evaluate_multibatch_pgd[BatchT, OutputT]( sampling=config.sampling, use_delta_component=config.use_delta_component, device=device, + reconstruction_loss=reconstruction_loss, ).item() return metrics diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 1bbc45e69..0fef3ff71 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,7 +7,7 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger -from spd.models.component_model import make_run_batch_lm, recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device @@ -97,8 +97,6 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index fcb6c596a..fc2d98886 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -10,7 +10,7 @@ from spd.configs import Config, LMTaskConfig from spd.data import DatasetConfig, create_data_loader from spd.log import logger -from spd.models.component_model import make_run_batch_lm, recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.distributed_utils import ( DistributedState, @@ -181,8 +181,6 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_kl, out_dir=out_dir, ln_stds=ln_stds, diff --git a/spd/experiments/lm/loaders.py b/spd/experiments/lm/loaders.py new file mode 100644 index 000000000..242614a8f --- /dev/null +++ b/spd/experiments/lm/loaders.py @@ -0,0 +1,114 @@ +"""Loaders for LM ComponentModels.""" + +from collections.abc import Generator, Iterator +from typing import Any, override + +import torch +from torch import Tensor, nn +from torch.nn import Parameter + +from spd.configs import Config +from spd.identity_insertion import insert_identity_operations_ +from spd.interfaces import LoadableModule, RunInfo +from spd.models.component_model import ( + ComponentModel, + SPDRunInfo, + handle_deprecated_state_dict_keys_, +) +from spd.spd_types import ModelPath +from spd.utils.general_utils import resolve_class +from spd.utils.module_utils import expand_module_patterns + + +class LogitsOnlyWrapper(nn.Module): + """Wrapper that extracts logits from models that return (logits, loss) tuples.""" + + def __init__(self, model: nn.Module): + super().__init__() + self.model = model + + @override + def forward(self, *args: Any, **kwargs: Any) -> Tensor: + out = self.model(*args, **kwargs) + if isinstance(out, tuple): + return out[0] + return out + + @override + def get_submodule(self, target: str) -> nn.Module: + # Delegate to wrapped model so paths don't need "model." prefix + return self.model.get_submodule(target) + + @override + def named_parameters( + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Parameter]]: + # Delegate to wrapped model so parameter names don't have "model." prefix + return self.model.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + + @override + def named_modules( + self, memo: set[nn.Module] | None = None, prefix: str = "", remove_duplicate: bool = True + ) -> Generator[tuple[str, nn.Module], None, None]: + # Delegate to wrapped model so module names don't have "model." prefix + yield from self.model.named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate) + + @override + def __getattr__(self, name: str) -> Any: + # Delegate attribute access to the wrapped model for things like wte, lm_head, etc. + if name == "model": + return super().__getattr__(name) + try: + return super().__getattr__(name) + except AttributeError: + return getattr(self.model, name) + + +def load_lm_component_model_from_run_info( + run_info: RunInfo[Config], +) -> ComponentModel[Tensor, Tensor]: + """Load a trained LM ComponentModel from a run info object.""" + config = run_info.config + + model_class = resolve_class(config.pretrained_model_class) + if config.pretrained_model_name is not None: + assert hasattr(model_class, "from_pretrained") + target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] + else: + assert issubclass(model_class, LoadableModule) + assert config.pretrained_model_path is not None + target_model = model_class.from_pretrained(config.pretrained_model_path) + + target_model.eval() + target_model.requires_grad_(False) + + if config.identity_module_info is not None: + insert_identity_operations_( + target_model, + identity_module_info=config.identity_module_info, + ) + + # Wrap the model to extract logits from (logits, loss) tuple outputs + wrapped_model = LogitsOnlyWrapper(target_model) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + comp_model: ComponentModel[Tensor, Tensor] = ComponentModel( + target_model=wrapped_model, + module_path_info=module_path_info, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_fn_type=config.ci_fn_type, + sigmoid_type=config.sigmoid_type, + ) + + comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) + handle_deprecated_state_dict_keys_(comp_model_weights) + comp_model.load_state_dict(comp_model_weights) + + return comp_model + + +def load_lm_component_model(path: ModelPath) -> ComponentModel[Tensor, Tensor]: + """Load a trained LM ComponentModel from a wandb or local path.""" + run_info = SPDRunInfo.from_path(path) + return load_lm_component_model_from_run_info(run_info) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index 108d0b520..ce68d3409 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -10,13 +10,23 @@ from jaxtyping import Float from torch import Tensor, nn +from spd.configs import Config from spd.experiments.resid_mlp.configs import ( ResidMLPModelConfig, ResidMLPTrainConfig, ) +from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.component_model import ( + ComponentModel, + SPDRunInfo, + handle_deprecated_state_dict_keys_, +) from spd.spd_types import ModelPath -from spd.utils.module_utils import init_param_ +from spd.utils.module_utils import expand_module_patterns, init_param_ + +ResidMLPBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] +ResidMLPOutput = Float[Tensor, "... n_features"] @dataclass @@ -89,9 +99,10 @@ def __init__(self, config: ResidMLPModelConfig): @override def forward( self, - x: Float[Tensor, "... n_features"], + batch: ResidMLPBatch | Float[Tensor, "... n_features"], return_residual: bool = False, ) -> Float[Tensor, "... n_features"] | Float[Tensor, "... d_embed"]: + x = batch[0] if isinstance(batch, tuple) else batch residual = einops.einsum(x, self.W_E, "... n_features, n_features d_embed -> ... d_embed") for layer in self.layers: out = layer(residual) @@ -121,3 +132,45 @@ def from_pretrained(cls, path: ModelPath) -> "ResidMLP": """Fetch a pretrained model from wandb or a local path to a checkpoint.""" run_info = ResidMLPTargetRunInfo.from_path(path) return cls.from_run_info(run_info) + + +def load_resid_mlp_component_model_from_run_info( + run_info: RunInfo[Config], +) -> ComponentModel[ResidMLPBatch, ResidMLPOutput]: + """Load a trained ResidMLP ComponentModel from a run info object.""" + config = run_info.config + assert config.pretrained_model_path is not None + + target_model = ResidMLP.from_pretrained(config.pretrained_model_path) + target_model.eval() + target_model.requires_grad_(False) + + if config.identity_module_info is not None: + insert_identity_operations_( + target_model, + identity_module_info=config.identity_module_info, + ) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + comp_model: ComponentModel[ResidMLPBatch, ResidMLPOutput] = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_fn_type=config.ci_fn_type, + sigmoid_type=config.sigmoid_type, + ) + + comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) + handle_deprecated_state_dict_keys_(comp_model_weights) + comp_model.load_state_dict(comp_model_weights) + + return comp_model + + +def load_resid_mlp_component_model( + path: ModelPath, +) -> ComponentModel[ResidMLPBatch, ResidMLPOutput]: + """Load a trained ResidMLP ComponentModel from a wandb or local path.""" + run_info = SPDRunInfo.from_path(path) + return load_resid_mlp_component_model_from_run_info(run_info) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index d27b742ae..f49615c20 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,7 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger -from spd.models.component_model import pass_first_tuple_element_to_model, recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -109,8 +109,6 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=pass_first_tuple_element_to_model, reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/resid_mlp/resid_mlp_interp.py b/spd/experiments/resid_mlp/resid_mlp_interp.py index 209ef9747..93a8d7618 100644 --- a/spd/experiments/resid_mlp/resid_mlp_interp.py +++ b/spd/experiments/resid_mlp/resid_mlp_interp.py @@ -9,10 +9,14 @@ from PIL import Image from torch import Tensor -from spd.experiments.resid_mlp.models import MLP, ResidMLP +from spd.experiments.resid_mlp.models import ( + MLP, + ResidMLP, + load_resid_mlp_component_model_from_run_info, +) from spd.experiments.tms.models import TMSModel from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.models.component_model import SPDRunInfo from spd.models.components import Components from spd.plotting import plot_causal_importance_vals from spd.registry import EXPERIMENT_REGISTRY @@ -35,7 +39,7 @@ def extract_ci_val_figures( Dictionary containing causal importances data and metadata """ run_info = SPDRunInfo.from_path(run_id) - model = ComponentModel.from_run_info(run_info) + model = load_resid_mlp_component_model_from_run_info(run_info) model.to(device) config = run_info.config @@ -478,7 +482,7 @@ def main(out_dir: Path, device: str): wandb_id = path.split("/")[-1] run_info = SPDRunInfo.from_path(path) - model = ComponentModel.from_run_info(run_info) + model = load_resid_mlp_component_model_from_run_info(run_info) config = run_info.config assert isinstance(model.target_model, ResidMLP) model.target_model.to(device) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index f8643e225..93e0187d5 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -6,9 +6,20 @@ from torch import Tensor, nn from torch.nn import functional as F +from spd.configs import Config from spd.experiments.tms.configs import TMSModelConfig, TMSTrainConfig +from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.component_model import ( + ComponentModel, + SPDRunInfo, + handle_deprecated_state_dict_keys_, +) from spd.spd_types import ModelPath +from spd.utils.module_utils import expand_module_patterns + +TMSBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] +TMSOutput = Float[Tensor, "... n_features"] @dataclass @@ -53,8 +64,9 @@ def to(self, *args: Any, **kwargs: Any) -> Self: @override def forward( - self, x: Float[Tensor, "... n_features"], **_: Any + self, batch: TMSBatch | Float[Tensor, "... n_features"], **_: Any ) -> Float[Tensor, "... n_features"]: + x = batch[0] if isinstance(batch, tuple) else batch hidden = self.linear1(x) if self.hidden_layers is not None: for layer in self.hidden_layers: @@ -80,3 +92,43 @@ def from_pretrained(cls, path: ModelPath) -> "TMSModel": """Fetch a pretrained model from wandb or a local path to a checkpoint.""" run_info = TMSTargetRunInfo.from_path(path) return cls.from_run_info(run_info) + + +def load_tms_component_model_from_run_info( + run_info: RunInfo[Config], +) -> ComponentModel[TMSBatch, TMSOutput]: + """Load a trained TMS ComponentModel from a run info object.""" + config = run_info.config + assert config.pretrained_model_path is not None + + target_model = TMSModel.from_pretrained(config.pretrained_model_path) + target_model.eval() + target_model.requires_grad_(False) + + if config.identity_module_info is not None: + insert_identity_operations_( + target_model, + identity_module_info=config.identity_module_info, + ) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + comp_model: ComponentModel[TMSBatch, TMSOutput] = ComponentModel( + target_model=target_model, + module_path_info=module_path_info, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_fn_type=config.ci_fn_type, + sigmoid_type=config.sigmoid_type, + ) + + comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) + handle_deprecated_state_dict_keys_(comp_model_weights) + comp_model.load_state_dict(comp_model_weights) + + return comp_model + + +def load_tms_component_model(path: ModelPath) -> ComponentModel[TMSBatch, TMSOutput]: + """Load a trained TMS ComponentModel from a wandb or local path.""" + run_info = SPDRunInfo.from_path(path) + return load_tms_component_model_from_run_info(run_info) diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index daa7ec57b..ca2094810 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -20,9 +20,8 @@ from matplotlib.figure import Figure from torch import Tensor -from spd.experiments.tms.models import TMSModel +from spd.experiments.tms.models import TMSModel, load_tms_component_model from spd.log import logger -from spd.models.component_model import ComponentModel from spd.models.components import Components from spd.settings import REPO_ROOT @@ -981,7 +980,7 @@ def main(): out_dir.mkdir(parents=True, exist_ok=True) # Load models - model = ComponentModel.from_pretrained(run_id) + model = load_tms_component_model(run_id) assert isinstance(model.target_model, TMSModel) # Get custom config and name for this run diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 48a51abef..2471c1830 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,7 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger -from spd.models.component_model import pass_first_tuple_element_to_model, recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -105,8 +105,6 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=pass_first_tuple_element_to_model, reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index a246bc5e2..8e82d7af2 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -24,6 +24,7 @@ from torch import Tensor from spd.data import train_loader_and_tokenizer +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.lib.harvester import Harvester, HarvesterState from spd.harvest.schemas import ( ActivationExample, @@ -202,7 +203,7 @@ def harvest_activation_contexts( logger.info(f"Loading model on {device}") run_info = SPDRunInfo.from_path(config.wandb_path) - model = ComponentModel.from_run_info(run_info).to(device) + model = load_lm_component_model_from_run_info(run_info).to(device) model.eval() spd_config = run_info.config diff --git a/spd/losses.py b/spd/losses.py index 6e36ea2d3..a756c7d06 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -34,6 +34,7 @@ stochastic_recon_subset_loss, unmasked_recon_loss, ) +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.utils.general_utils import get_obj_device @@ -50,6 +51,7 @@ def compute_total_loss[BatchT, OutputT]( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], dict[str, float]]: """Compute weighted total loss and per-term raw values using new loss primitives. @@ -79,6 +81,7 @@ def compute_total_loss[BatchT, OutputT]( model=model, batch=batch, target_out=target_out, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconSubsetLossConfig(): loss = ci_masked_recon_subset_loss( @@ -87,6 +90,7 @@ def compute_total_loss[BatchT, OutputT]( target_out=target_out, ci=ci.lower_leaky, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLayerwiseLossConfig(): loss = ci_masked_recon_layerwise_loss( @@ -94,6 +98,7 @@ def compute_total_loss[BatchT, OutputT]( batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=reconstruction_loss, ) case CIMaskedReconLossConfig(): loss = ci_masked_recon_loss( @@ -101,6 +106,7 @@ def compute_total_loss[BatchT, OutputT]( batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLayerwiseLossConfig(): loss = stochastic_recon_layerwise_loss( @@ -111,6 +117,7 @@ def compute_total_loss[BatchT, OutputT]( target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLossConfig(): loss = stochastic_recon_loss( @@ -121,6 +128,7 @@ def compute_total_loss[BatchT, OutputT]( target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetLossConfig(): loss = stochastic_recon_subset_loss( @@ -132,6 +140,7 @@ def compute_total_loss[BatchT, OutputT]( ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLossConfig(): loss = pgd_recon_loss( @@ -141,6 +150,7 @@ def compute_total_loss[BatchT, OutputT]( ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case PGDReconSubsetLossConfig(): loss = pgd_recon_subset_loss( @@ -151,6 +161,7 @@ def compute_total_loss[BatchT, OutputT]( weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, routing=cfg.routing, + reconstruction_loss=reconstruction_loss, ) case PGDReconLayerwiseLossConfig(): loss = pgd_recon_layerwise_loss( @@ -160,6 +171,7 @@ def compute_total_loss[BatchT, OutputT]( ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, pgd_config=cfg, + reconstruction_loss=reconstruction_loss, ) case StochasticHiddenActsReconLossConfig(): loss = stochastic_hidden_acts_recon_loss( diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index 5db109845..44539c236 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -6,6 +6,7 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce @@ -17,13 +18,14 @@ def _ci_masked_recon_layerwise_loss_update[BatchT, OutputT]( batch: BatchT, target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: sum_loss = torch.tensor(0.0, device=get_obj_device(model)) sum_n_examples = 0 mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) for module_name, mask_info in mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss, n_examples = model.reconstruction_loss(out, target_out) + loss, n_examples = reconstruction_loss(out, target_out) sum_loss += loss sum_n_examples += n_examples return sum_loss, sum_n_examples @@ -40,12 +42,14 @@ def ci_masked_recon_layerwise_loss[BatchT, OutputT]( batch: BatchT, target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=model, batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=reconstruction_loss, ) return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) @@ -59,8 +63,10 @@ def __init__( self, model: ComponentModel[BatchT, OutputT], device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.sum_n_examples = torch.tensor(0, device=device) @@ -78,6 +84,7 @@ def update( batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.sum_n_examples += sum_n_examples diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index c085eb0e9..a54ceb6e2 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -6,6 +6,7 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce @@ -16,10 +17,11 @@ def _ci_masked_recon_loss_update[BatchT, OutputT]( batch: BatchT, target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) out = model(batch, mask_infos=mask_infos) - return model.reconstruction_loss(out, target_out) + return reconstruction_loss(out, target_out) def _ci_masked_recon_loss_compute( @@ -33,12 +35,14 @@ def ci_masked_recon_loss[BatchT, OutputT]( batch: BatchT, target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_loss_update( model=model, batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=reconstruction_loss, ) return _ci_masked_recon_loss_compute(sum_loss, n_examples) @@ -52,8 +56,10 @@ def __init__( self, model: ComponentModel[BatchT, OutputT], device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -71,6 +77,7 @@ def update( batch=batch, target_out=target_out, ci=ci.lower_leaky, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index bbd40a81b..6b0fba016 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -7,6 +7,7 @@ from spd.configs import SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.routing import Router, get_subset_router @@ -20,6 +21,7 @@ def _ci_masked_recon_subset_loss_update[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], router: Router, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: subset_routing_masks = router.get_masks( module_names=model.target_module_paths, @@ -31,7 +33,7 @@ def _ci_masked_recon_subset_loss_update[BatchT, OutputT]( weight_deltas_and_masks=None, ) out = model(batch, mask_infos=mask_infos) - return model.reconstruction_loss(out, target_out) + return reconstruction_loss(out, target_out) def _ci_masked_recon_subset_loss_compute( @@ -46,6 +48,7 @@ def ci_masked_recon_subset_loss[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=model, @@ -53,6 +56,7 @@ def ci_masked_recon_subset_loss[BatchT, OutputT]( target_out=target_out, ci=ci, router=get_subset_router(routing, device=get_obj_device(model)), + reconstruction_loss=reconstruction_loss, ) return _ci_masked_recon_subset_loss_compute(sum_loss, n_examples) @@ -67,10 +71,11 @@ def __init__( model: ComponentModel[BatchT, OutputT], device: str, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.router = get_subset_router(routing, device) - + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -89,6 +94,7 @@ def update( target_out=target_out, ci=ci.lower_leaky, router=self.router, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index 1c21fa175..cd24a3c5d 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -8,6 +8,7 @@ from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import LayerRouter from spd.utils.distributed_utils import all_reduce @@ -21,6 +22,7 @@ def _pgd_recon_layerwise_loss_update[BatchT, OutputT]( ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], Int[Tensor, ""]]: device = next(iter(ci.values())).device sum_loss = torch.tensor(0.0, device=device) @@ -34,6 +36,7 @@ def _pgd_recon_layerwise_loss_update[BatchT, OutputT]( target_out=target_out, router=LayerRouter(device=device, layer_name=layer), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) sum_loss += sum_loss_layer n_examples += n_examples_layer @@ -48,6 +51,7 @@ def pgd_recon_layerwise_loss[BatchT, OutputT]( ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = _pgd_recon_layerwise_loss_update( model=model, @@ -56,6 +60,7 @@ def pgd_recon_layerwise_loss[BatchT, OutputT]( ci=ci, weight_deltas=weight_deltas, pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -72,10 +77,12 @@ def __init__( pgd_config: PGDConfig, device: str, use_delta_component: bool, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config self.use_delta_component: bool = use_delta_component + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -96,6 +103,7 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index b02413526..b82763b75 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -8,6 +8,7 @@ from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.distributed_utils import all_reduce @@ -21,6 +22,7 @@ def pgd_recon_loss[BatchT, OutputT]( ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -30,6 +32,7 @@ def pgd_recon_loss[BatchT, OutputT]( target_out=target_out, router=AllLayersRouter(), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -46,10 +49,12 @@ def __init__( device: str, pgd_config: PGDConfig, use_delta_component: bool, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config self.use_delta_component: bool = use_delta_component + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -71,6 +76,7 @@ def update( target_out=target_out, router=AllLayersRouter(), pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index a34c3677a..1dfdf22e9 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -8,6 +8,7 @@ from spd.configs import PGDConfig, SubsetRoutingType from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import get_subset_router from spd.utils.distributed_utils import all_reduce @@ -23,6 +24,7 @@ def pgd_recon_subset_loss[BatchT, OutputT]( weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -32,6 +34,7 @@ def pgd_recon_subset_loss[BatchT, OutputT]( target_out=target_out, router=get_subset_router(routing, device=get_obj_device(model)), pgd_config=pgd_config, + reconstruction_loss=reconstruction_loss, ) return sum_loss / n_examples @@ -49,11 +52,13 @@ def __init__( use_delta_component: bool, pgd_config: PGDConfig, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config self.use_delta_component: bool = use_delta_component self.router = get_subset_router(routing, device=get_obj_device(model)) + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -76,6 +81,7 @@ def update( target_out=target_out, router=self.router, pgd_config=self.pgd_config, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index 8f97ac816..5a0e4b32e 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -9,6 +9,7 @@ from spd.configs import PGDConfig, PGDInitStrategy, PGDMultiBatchConfig, SamplingType from spd.log import logger +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import RoutingMasks, make_mask_infos from spd.routing import Router @@ -24,6 +25,7 @@ def pgd_masked_recon_loss_update[BatchT, OutputT]( target_out: OutputT, router: Router, pgd_config: PGDConfig, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: """Central implementation of PGD masked reconstruction loss. @@ -57,6 +59,7 @@ def pgd_masked_recon_loss_update[BatchT, OutputT]( routing_masks=routing_masks, target_out=target_out, batch_dims=batch_dims, + reconstruction_loss=reconstruction_loss, ) for _ in range(pgd_config.n_steps): @@ -90,6 +93,7 @@ def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( sampling: SamplingType, use_delta_component: bool, device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: """PGD masked reconstruction loss with gradient accumulation over multiple batches. @@ -102,10 +106,10 @@ def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( create_data_iter: Function to create an iterator over batches. This function should return an iterator which behaves identically each time. Specifically in terms of data ordering and shuffling. - output_loss_type: Loss type for reconstruction ("mse" or "kl") router: Router to use for routing masks sampling: Sampling mode for causal importance calculation use_delta_component: Whether to include weight delta component + reconstruction_loss: Function to compute reconstruction loss Returns: Final reconstruction loss after PGD optimization """ @@ -136,6 +140,7 @@ def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( device=device, sampling=sampling, router=router, + reconstruction_loss=reconstruction_loss, ) for _ in range(pgd_config.n_steps): @@ -160,6 +165,7 @@ def _forward_with_adv_sources[BatchT, OutputT]( routing_masks: RoutingMasks, target_out: OutputT, batch_dims: tuple[int, ...], + reconstruction_loss: ReconstructionLoss[OutputT], ): expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]] @@ -180,7 +186,7 @@ def _forward_with_adv_sources[BatchT, OutputT]( ) out = model(batch, mask_infos=mask_infos) - sum_loss, n_examples = model.reconstruction_loss(out, target_out) + sum_loss, n_examples = reconstruction_loss(out, target_out) return sum_loss, n_examples @@ -194,6 +200,7 @@ def _multibatch_pgd_fwd_bwd[BatchT, OutputT]( device: torch.device | str, router: Router, sampling: SamplingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int, dict[str, Float[Tensor, "*ones mask_c"]]]: """Perform a forward and backward pass over multiple batches with gradient accumulation. @@ -239,6 +246,7 @@ def _multibatch_pgd_fwd_bwd[BatchT, OutputT]( routing_masks=routing_masks, target_out=target_model_output.output, batch_dims=batch_dims, + reconstruction_loss=reconstruction_loss, ) pgd_step_accum_sum_loss += batch_sum_loss diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index da8df9f0d..6d93ae430 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -7,6 +7,7 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info @@ -22,6 +23,7 @@ def _stochastic_recon_layerwise_loss_update[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -41,7 +43,7 @@ def _stochastic_recon_layerwise_loss_update[BatchT, OutputT]( for stochastic_mask_infos in stochastic_mask_infos_list: for module_name, mask_info in stochastic_mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss, batch_n_examples = model.reconstruction_loss(out, target_out) + loss, batch_n_examples = reconstruction_loss(out, target_out) sum_loss += loss sum_n_examples += batch_n_examples return sum_loss, sum_n_examples @@ -61,6 +63,7 @@ def stochastic_recon_layerwise_loss[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=model, @@ -70,6 +73,7 @@ def stochastic_recon_layerwise_loss[BatchT, OutputT]( target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) @@ -86,11 +90,13 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.sum_n_examples = torch.tensor(0, device=device) @@ -112,6 +118,7 @@ def update( target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.sum_n_examples += sum_n_examples diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 42601b514..893ca432e 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -7,6 +7,7 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info @@ -22,6 +23,7 @@ def _stochastic_recon_loss_update[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -36,7 +38,7 @@ def _stochastic_recon_loss_update[BatchT, OutputT]( router=AllLayersRouter(), ) out = model(batch, mask_infos=stoch_mask_infos) - loss, n_examples = model.reconstruction_loss(out, target_out) + loss, n_examples = reconstruction_loss(out, target_out) sum_loss += loss sum_n_examples += n_examples @@ -57,6 +59,7 @@ def stochastic_recon_loss[BatchT, OutputT]( target_out: OutputT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _stochastic_recon_loss_update( model=model, @@ -66,6 +69,7 @@ def stochastic_recon_loss[BatchT, OutputT]( target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) @@ -82,11 +86,13 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.sum_n_examples = torch.tensor(0, device=device) @@ -108,6 +114,7 @@ def update( target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.sum_n_examples += sum_n_examples diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index ae8147126..ea7cbd5f1 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -7,6 +7,7 @@ from spd.configs import SamplingType, SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import Router, get_subset_router from spd.utils.component_utils import calc_stochastic_component_mask_info @@ -23,6 +24,7 @@ def _stochastic_recon_subset_loss_update[BatchT, OutputT]( ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, router: Router, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -41,7 +43,7 @@ def _stochastic_recon_subset_loss_update[BatchT, OutputT]( for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss, batch_n_examples = model.reconstruction_loss(out, target_out) + loss, batch_n_examples = reconstruction_loss(out, target_out) sum_loss += loss n_examples += batch_n_examples return sum_loss, n_examples @@ -62,6 +64,7 @@ def stochastic_recon_subset_loss[BatchT, OutputT]( ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_subset_loss_update( model=model, @@ -72,6 +75,7 @@ def stochastic_recon_subset_loss[BatchT, OutputT]( ci=ci, weight_deltas=weight_deltas, router=get_subset_router(routing, device=get_obj_device(model)), + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) @@ -89,12 +93,14 @@ def __init__( use_delta_component: bool, n_mask_samples: int, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples self.router = get_subset_router(routing, device) + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -117,6 +123,7 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, router=self.router, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index 1c6681a20..f9bc34545 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -6,6 +6,7 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce @@ -16,6 +17,7 @@ def _unmasked_recon_loss_update[BatchT, OutputT]( model: ComponentModel[BatchT, OutputT], batch: BatchT, target_out: OutputT, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> tuple[Float[Tensor, ""], int]: all_ones_mask_infos = make_mask_infos( # (C,) will broadcast to (B, S, C) @@ -25,7 +27,7 @@ def _unmasked_recon_loss_update[BatchT, OutputT]( } ) out = model(batch, mask_infos=all_ones_mask_infos) - return model.reconstruction_loss(out, target_out) + return reconstruction_loss(out, target_out) def _unmasked_recon_loss_compute( @@ -38,11 +40,13 @@ def unmasked_recon_loss[BatchT, OutputT]( model: ComponentModel[BatchT, OutputT], batch: BatchT, target_out: OutputT, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> Float[Tensor, ""]: sum_loss, n_examples = _unmasked_recon_loss_update( model, batch, target_out, + reconstruction_loss, ) return _unmasked_recon_loss_compute(sum_loss, n_examples) @@ -56,8 +60,10 @@ def __init__( self, model: ComponentModel[BatchT, OutputT], device: str, + reconstruction_loss: ReconstructionLoss[OutputT], ) -> None: self.model = model + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -73,6 +79,7 @@ def update( model=self.model, batch=batch, target_out=target_out, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py new file mode 100644 index 000000000..90914e263 --- /dev/null +++ b/spd/models/batch_and_loss_fns.py @@ -0,0 +1,39 @@ +"""Batch handling and reconstruction loss functions for different model types. + +These functions parameterize ComponentModel and training for different target model architectures. +""" + +from typing import Protocol + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + + +class ReconstructionLoss[OutputT](Protocol): + """Protocol for computing reconstruction loss between predictions and targets.""" + + def __call__(self, pred: OutputT, target: OutputT) -> tuple[Float[Tensor, ""], int]: ... + + +def recon_loss_mse( + pred: Float[Tensor, "... d"], + target: Float[Tensor, "... d"], +) -> tuple[Float[Tensor, ""], int]: + """MSE reconstruction loss. Returns (sum_of_squared_errors, n_elements).""" + assert pred.shape == target.shape + squared_errors = (pred - target) ** 2 + return squared_errors.sum(), pred.numel() + + +def recon_loss_kl( + pred: Float[Tensor, "... vocab"], + target: Float[Tensor, "... vocab"], +) -> tuple[Float[Tensor, ""], int]: + """KL divergence reconstruction loss for logits. Returns (sum_of_kl, n_positions).""" + assert pred.shape == target.shape + log_q = torch.log_softmax(pred, dim=-1) # log Q + p = torch.softmax(target, dim=-1) # P + kl_per_position = F.kl_div(log_q, p, reduction="none").sum(dim=-1) # P · (log P − log Q) + return kl_per_position.sum(), pred[..., 0].numel() diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 93f194265..84c9c5a9e 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,28 +1,17 @@ -from abc import ABC -from collections.abc import Callable, Generator +from collections.abc import Callable, Generator, Iterator from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Literal, NamedTuple, Protocol, Self, overload, override +from typing import Any, Literal, NamedTuple, Protocol, overload, override import torch -import torch.nn.functional as F from jaxtyping import Float, Int from torch import Tensor, nn from torch.utils.hooks import RemovableHandle from transformers.pytorch_utils import Conv1D as RadfordConv1D -from spd.configs import ( - Config, - IHTaskConfig, - LMTaskConfig, - ResidMLPTaskConfig, - SamplingType, - TaskConfig, - TMSTaskConfig, -) -from spd.identity_insertion import insert_identity_operations_ -from spd.interfaces import LoadableModule, RunInfo +from spd.configs import Config, SamplingType +from spd.interfaces import RunInfo from spd.models.components import ( Components, ComponentsMaskInfo, @@ -34,9 +23,9 @@ VectorSharedMLPCiFn, ) from spd.models.sigmoids import SIGMOID_TYPES, SigmoidType -from spd.spd_types import CiFnType, ModelPath -from spd.utils.general_utils import resolve_class, runtime_cast -from spd.utils.module_utils import ModulePathInfo, expand_module_patterns +from spd.spd_types import CiFnType +from spd.utils.general_utils import runtime_cast +from spd.utils.module_utils import ModulePathInfo @dataclass @@ -62,15 +51,19 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class RunBatch[BatchT, OutputT](Protocol): - def __call__(self, target_model: nn.Module, batch: BatchT) -> OutputT: ... +class TargetModel[BatchT, OutputT](Protocol): + # def __call__(self, batch: BatchT) -> OutputT: ... + + def __call__(self, batch: BatchT) -> OutputT: ... + + def get_submodule(self, target: str) -> nn.Module: ... + def named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]: ... -class ReconstructionLoss[OutputT](Protocol): - def __call__(self, pred: OutputT, target: OutputT) -> tuple[Float[Tensor, ""], int]: ... + # def named_modules(self) -> Generator[tuple[str, nn.Module]]: ... -class ComponentModel[BatchT, OutputT](LoadableModule, ABC): +class ComponentModel[BatchT, OutputT](nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. The underlying *base model* can be any subclass of `nn.Module` (e.g. @@ -91,13 +84,11 @@ class ComponentModel[BatchT, OutputT](LoadableModule, ABC): def __init__( self, - target_model: nn.Module, + target_model: TargetModel[BatchT, OutputT], module_path_info: list[ModulePathInfo], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, - run_batch: RunBatch[BatchT, OutputT], - reconstruction_loss: ReconstructionLoss[OutputT], ): super().__init__() @@ -137,9 +128,6 @@ def __init__( self.lower_leaky_fn = SIGMOID_TYPES[sigmoid_type] self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] - self._run_batch = run_batch - self.reconstruction_loss = reconstruction_loss - def target_weight(self, module_name: str) -> Float[Tensor, "rows cols"]: target_module = self.target_model.get_submodule(module_name) @@ -196,7 +184,7 @@ def _create_component( @staticmethod def _create_components( - target_model: nn.Module, + target_model: TargetModel[BatchT, OutputT], module_to_c: dict[str, int], ) -> dict[str, Components]: components: dict[str, Components] = {} @@ -239,7 +227,7 @@ def _create_ci_fn( @staticmethod def _create_ci_fns( - target_model: nn.Module, + target_model: TargetModel[BatchT, OutputT], module_to_c: dict[str, int], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], @@ -314,7 +302,7 @@ def forward( """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. - return self._run_batch(self.target_model, batch) + return self.target_model(batch) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -335,7 +323,7 @@ def forward( ) with self._attach_forward_hooks(hooks): - out: OutputT = self._run_batch(self.target_model, batch) + out: OutputT = self.target_model(batch) match cache_type: case "input" | "component_acts": @@ -418,67 +406,6 @@ def _attach_forward_hooks(self, hooks: dict[str, Callable[..., Any]]) -> Generat for handle in handles: handle.remove() - @classmethod - @override - def from_run_info(cls, run_info: RunInfo[Config]) -> Self: - """Load a trained ComponentModel checkpoint from a run info object.""" - config = run_info.config - - # Load the target model - model_class = resolve_class(config.pretrained_model_class) - if config.pretrained_model_name is not None: - assert hasattr(model_class, "from_pretrained"), ( - f"Model class {model_class} should have a `from_pretrained` method" - ) - target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] - else: - assert issubclass(model_class, LoadableModule), ( - f"Model class {model_class} should be a subclass of LoadableModule which " - "defines a `from_pretrained` method" - ) - assert run_info.config.pretrained_model_path is not None - target_model = model_class.from_pretrained(run_info.config.pretrained_model_path) - - target_model.eval() - target_model.requires_grad_(False) - - if config.identity_module_info is not None: - insert_identity_operations_( - target_model, - identity_module_info=config.identity_module_info, - ) - - module_path_info = expand_module_patterns(target_model, config.all_module_info) - - run_batch = get_run_batch(config.task_config, config.pretrained_model_output_attr) - reconstruction_loss = get_reconstruction_loss(config.task_config) - - comp_model = cls( - target_model=target_model, - module_path_info=module_path_info, - run_batch=run_batch, - reconstruction_loss=reconstruction_loss, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - sigmoid_type=config.sigmoid_type, - ) - - comp_model_weights = torch.load( - run_info.checkpoint_path, map_location="cpu", weights_only=True - ) - - handle_deprecated_state_dict_keys_(comp_model_weights) - - comp_model.load_state_dict(comp_model_weights) - return comp_model - - @classmethod - @override - def from_pretrained(cls, path: ModelPath) -> Self: - """Load a trained ComponentModel checkpoint from a local or wandb path.""" - run_info = SPDRunInfo.from_path(path) - return cls.from_run_info(run_info) - def calc_causal_importances( self, pre_weight_acts: dict[str, Float[Tensor, "... d_in"] | Int[Tensor, "... pos"]], @@ -591,78 +518,3 @@ def handle_deprecated_state_dict_keys_(state_dict: dict[str, Tensor]) -> None: # replace if modified if new_key != key: state_dict[new_key] = state_dict.pop(key) - - -def pass_first_tuple_element_to_model[BatchT: tuple[Any, ...], OutputT]( - target_model: nn.Module, - batch: BatchT, # pyright: ignore[reportInvalidTypeVarUse] -) -> OutputT: # pyright: ignore[reportInvalidTypeVarUse] - return target_model(batch[0]) - - -def pass_batch_directly_to_model[BatchT, OutputT]( - target_model: nn.Module, - batch: BatchT, # pyright: ignore[reportInvalidTypeVarUse] -) -> OutputT: # pyright: ignore[reportInvalidTypeVarUse] - return target_model(batch) - - -def run_batch_extract_idx(idx: int, target_model: nn.Module, batch: Any) -> Any: - return target_model(batch)[idx] - - -def run_batch_extract_attr(attr: str, target_model: nn.Module, batch: Any) -> Any: - return getattr(target_model(batch), attr) - - -def make_run_batch_lm(output_attr: str | None) -> RunBatch[Any, Any]: - if output_attr is None: - return pass_batch_directly_to_model - if output_attr.startswith("idx_"): - idx = int(output_attr.removeprefix("idx_")) - return partial(run_batch_extract_idx, idx) - return partial(run_batch_extract_attr, output_attr) - - -def get_run_batch(task_config: TaskConfig, output_attr: str | None = None) -> RunBatch[Any, Any]: - match task_config: - case TMSTaskConfig() | ResidMLPTaskConfig(): - assert output_attr is None, ( - "output_attr not supported for TMSTaskConfig and ResidMLPTaskConfig" - ) - return pass_first_tuple_element_to_model - case LMTaskConfig() | IHTaskConfig(): - return make_run_batch_lm(output_attr) - - -# the following recon loss functions should return pre-mean values - - -def recon_loss_mse( - pred: Float[Tensor, "... d"], - target: Float[Tensor, "... d"], -) -> tuple[Float[Tensor, ""], int]: - assert pred.shape == target.shape - squared_errors = (pred - target) ** 2 - return squared_errors.sum(), pred.numel() - - -def recon_loss_kl( - pred: Float[Tensor, "... vocab"], - target: Float[Tensor, "... vocab"], -) -> tuple[Float[Tensor, ""], int]: - assert pred.shape == target.shape - log_q = torch.log_softmax(pred, dim=-1) # log Q - p = torch.softmax(target, dim=-1) # P - kl_per_position = F.kl_div(log_q, p, reduction="none").sum(dim=-1) # P · (log P − log Q) - return kl_per_position.sum(), pred.numel() - - -def get_reconstruction_loss( - task_config: TaskConfig, -) -> ReconstructionLoss[Any]: - match task_config: - case TMSTaskConfig() | ResidMLPTaskConfig(): - return recon_loss_mse - case LMTaskConfig() | IHTaskConfig(): - return recon_loss_kl diff --git a/spd/run_spd.py b/spd/run_spd.py index d045e3004..38a919f7c 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -30,7 +30,8 @@ from spd.log import logger from spd.losses import compute_total_loss from spd.metrics import faithfulness_loss -from spd.models.component_model import ComponentModel, OutputWithCache, ReconstructionLoss, RunBatch +from spd.models.batch_and_loss_fns import ReconstructionLoss +from spd.models.component_model import ComponentModel, OutputWithCache, TargetModel from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( avg_metrics_across_ranks, @@ -38,7 +39,7 @@ is_main_process, sync_across_processes, ) -from spd.utils.general_utils import dict_safe_update_, get_scheduled_value +from spd.utils.general_utils import dict_safe_update_, get_scheduled_value, runtime_cast from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns, replace_std_values_in_layernorm from spd.utils.run_utils import save_file @@ -105,13 +106,11 @@ def get_unique_metric_configs( def optimize[BatchT, OutputT]( - target_model: nn.Module, + target_model: TargetModel[BatchT, OutputT], config: Config, device: str, train_loader: DataLoader[BatchT], eval_loader: DataLoader[BatchT], - n_eval_steps: int, - run_batch: RunBatch[BatchT, OutputT], reconstruction_loss: ReconstructionLoss[OutputT], out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, @@ -132,13 +131,15 @@ def create_pgd_data_iter() -> Iterator[BatchT]: if config.identity_module_info is not None: insert_identity_operations_( - target_model, + runtime_cast(nn.Module, target_model), identity_module_info=config.identity_module_info, ) - target_model.requires_grad_(False) + cast(nn.Module, target_model).requires_grad_(False) - module_path_info = expand_module_patterns(target_model, config.all_module_info) + module_path_info = expand_module_patterns( + runtime_cast(nn.Module, target_model), config.all_module_info + ) model = ComponentModel( target_model=target_model, @@ -146,8 +147,6 @@ def create_pgd_data_iter() -> Iterator[BatchT]: ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, - run_batch=run_batch, - reconstruction_loss=reconstruction_loss, ) if ln_stds is not None: @@ -258,6 +257,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: sampling=config.sampling, use_delta_component=config.use_delta_component, n_mask_samples=config.n_mask_samples, + reconstruction_loss=reconstruction_loss, ) microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() @@ -309,6 +309,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: create_data_iter=create_pgd_data_iter, config=config, device=device, + reconstruction_loss=reconstruction_loss, ) metrics = evaluate( @@ -318,8 +319,9 @@ def create_pgd_data_iter() -> Iterator[BatchT]: device=device, run_config=config, slow_step=slow_step, - n_eval_steps=n_eval_steps, + n_eval_steps=config.n_eval_steps, current_frac_of_training=step / config.steps, + reconstruction_loss=reconstruction_loss, ) dict_safe_update_(metrics, multibatch_pgd_metrics) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index f41a56783..f1b25e608 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -22,6 +22,7 @@ from spd.base_config import BaseConfig from spd.configs import Config +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device @@ -82,7 +83,8 @@ def __init__(self, config: CompareModelsConfig): def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any, Any], Config]: """Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) - model = ComponentModel.from_run_info(run_info) + # TODO(oli): this should actually be generic (one of the only instances of this I think) + model = load_lm_component_model_from_run_info(run_info) model.to(self.device) model.eval() model.requires_grad_(False) diff --git a/spd/simple_trainer.py b/spd/simple_trainer.py new file mode 100644 index 000000000..bcf92945c --- /dev/null +++ b/spd/simple_trainer.py @@ -0,0 +1,251 @@ +"""Run SPD on a model.""" + +import gc +from collections import defaultdict +from pathlib import Path +from typing import cast + +import torch +import torch.nn as nn +import torch.nn.parallel +import wandb +from PIL import Image +from torch import optim +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.configs import Config +from spd.eval import evaluate +from spd.log import logger +from spd.losses import compute_total_loss +from spd.models.batch_and_loss_fns import ReconstructionLoss +from spd.models.component_model import ComponentModel, OutputWithCache, TargetModel +from spd.run_spd import get_unique_metric_configs, run_faithfulness_warmup +from spd.utils.component_utils import calc_ci_l_zero +from spd.utils.distributed_utils import ( + avg_metrics_across_ranks, + get_distributed_state, + is_main_process, + sync_across_processes, +) +from spd.utils.general_utils import dict_safe_update_, get_scheduled_value, runtime_cast +from spd.utils.logging_utils import get_grad_norms_dict, local_log +from spd.utils.module_utils import expand_module_patterns +from spd.utils.run_utils import save_file +from spd.utils.wandb_utils import try_wandb + + +def optimize[BatchT, OutputT]( + target_model: TargetModel[BatchT, OutputT], + config: Config, + device: str, + train_loader: DataLoader[BatchT], + eval_loader: DataLoader[BatchT], + n_eval_steps: int, + reconstruction_loss: ReconstructionLoss[OutputT], + out_dir: Path | None, +) -> None: + """Run the optimization loop for LM decomposition.""" + train_iterator = iter(train_loader) + eval_iterator = iter(eval_loader) + + runtime_cast(nn.Module, target_model).requires_grad_(False) + model = ComponentModel( + target_model=target_model, + module_path_info=expand_module_patterns( + runtime_cast(nn.Module, target_model), config.all_module_info + ), + ci_fn_type=config.ci_fn_type, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + sigmoid_type=config.sigmoid_type, + ) + model.to(device) + + # Wrap model with DDP if distributed + dist_state = get_distributed_state() + wrapped_model: nn.Module = model + + component_model: ComponentModel[BatchT, OutputT] + if dist_state is not None: + if dist_state.backend == "nccl": + device_id = dist_state.local_rank + wrapped_model = torch.nn.parallel.DistributedDataParallel( + model, + device_ids=[device_id], + output_device=device_id, + ) + else: + # For CPU, don't pass device_ids or output_device + wrapped_model = torch.nn.parallel.DistributedDataParallel(model) + # Access the underlying module for component operations + component_model = wrapped_model.module # type: ignore[attr-defined] + else: + component_model = model + assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" + + component_params: list[torch.nn.Parameter] = [] + ci_fn_params: list[torch.nn.Parameter] = [] + for name in component_model.target_module_paths: + component_params.extend(component_model.components[name].parameters()) + ci_fn_params.extend(component_model.ci_fns[name].parameters()) + + assert len(component_params) > 0, "No parameters found in components to optimize" + + optimizer = optim.AdamW( + component_params + ci_fn_params, + lr=config.lr_schedule.start_val, + weight_decay=0, + ) + + logger.info(f"LR scheduler: {config.lr_schedule.fn_type}") + + if config.faithfulness_warmup_steps > 0: + run_faithfulness_warmup(component_model, component_params, config) + + eval_metric_configs = get_unique_metric_configs( + loss_configs=config.loss_metric_configs, eval_configs=config.eval_metric_configs + ) + + for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): + optimizer.zero_grad() + + step_lr = get_scheduled_value( + step=step, total_steps=config.steps, config=config.lr_schedule + ) + for group in optimizer.param_groups: + group["lr"] = step_lr + + weight_deltas = component_model.calc_weight_deltas() + + microbatch_log_data: defaultdict[str, float] = defaultdict(float) + + for _ in range(config.gradient_accumulation_steps): + microbatch = next(train_iterator) + + # NOTE: we need to call the wrapped_model at least once each step in order to setup + # the DDP gradient syncing for all parameters in the component model. Gradients will + # sync regardless of whether the parameters are used in this call to wrapped_model. + target_model_output: OutputWithCache[OutputT] = wrapped_model( + microbatch, cache_type="input" + ) + + ci = component_model.calc_causal_importances( + pre_weight_acts=target_model_output.cache, + detach_inputs=False, + sampling=config.sampling, + ) + + microbatch_total_loss, microbatch_loss_terms = compute_total_loss( + loss_metric_configs=config.loss_metric_configs, + model=component_model, + batch=microbatch, + ci=ci, + target_out=target_model_output.output, + weight_deltas=weight_deltas, + pre_weight_acts=target_model_output.cache, + current_frac_of_training=step / config.steps, + sampling=config.sampling, + use_delta_component=config.use_delta_component, + n_mask_samples=config.n_mask_samples, + reconstruction_loss=reconstruction_loss, + ) + microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() + + for loss_name, loss_value in microbatch_loss_terms.items(): + microbatch_log_data[f"train/{loss_name}"] += ( + loss_value / config.gradient_accumulation_steps + ) + + for layer_name, layer_ci in ci.lower_leaky.items(): + l0_val = calc_ci_l_zero(layer_ci, config.ci_alive_threshold) + microbatch_log_data[f"train/l0/{layer_name}"] += ( + l0_val / config.gradient_accumulation_steps + ) + + # --- Train Logging --- # + if step % config.train_log_freq == 0: + avg_metrics = avg_metrics_across_ranks(microbatch_log_data, device=device) + microbatch_log_data = cast(defaultdict[str, float], avg_metrics) + + grad_norms = get_grad_norms_dict(component_model, device) + dict_safe_update_( + microbatch_log_data, {f"train/grad_norms/{k}": v for k, v in grad_norms.items()} + ) + + microbatch_log_data["train/schedules/lr"] = step_lr + + if is_main_process(): + assert out_dir is not None + tqdm.write(f"--- Step {step} ---") + tqdm.write(f"LR: {step_lr:.6f}") + for name, value in microbatch_log_data.items(): + tqdm.write(f"{name}: {value:.15f}") + local_log(microbatch_log_data, step, out_dir) + if config.wandb_project: + try_wandb(wandb.log, microbatch_log_data, step=step) + + # --- Evaluation --- # + if step % config.eval_freq == 0: + with torch.no_grad(): + slow_step: bool = ( + config.slow_eval_on_first_step + if step == 0 + else step % config.slow_eval_freq == 0 + ) + + metrics = evaluate( + eval_metric_configs=eval_metric_configs, + model=component_model, # No backward passes so DDP wrapped_model not needed + eval_iterator=eval_iterator, + device=device, + run_config=config, + slow_step=slow_step, + n_eval_steps=n_eval_steps, + current_frac_of_training=step / config.steps, + reconstruction_loss=reconstruction_loss, + ) + + if is_main_process(): + assert out_dir is not None + for k, v in metrics.items(): + tqdm.write(f"eval/{k}: {v}") + local_log(metrics, step, out_dir) + if config.wandb_project: + wandb_logs = { + f"eval/{k}": wandb.Image(v) if isinstance(v, Image.Image) else v + for k, v in metrics.items() + } + try_wandb(wandb.log, wandb_logs, step=step) + + del metrics + + gc.collect() + torch.cuda.empty_cache() + + # --- Saving Checkpoint --- # + if ( + (config.save_freq is not None and step % config.save_freq == 0 and step > 0) + or step == config.steps + ) and is_main_process(): + assert out_dir is not None + # Save the state dict of the underlying module (not DDP wrapper) + save_file(component_model.state_dict(), out_dir / f"model_{step}.pth") + logger.info(f"Saved model, optimizer, and out_dir to {out_dir}") + if config.wandb_project: + try_wandb( + wandb.save, + str(out_dir / f"model_{step}.pth"), + base_path=str(out_dir), + policy="now", + ) + + sync_across_processes() + if config.grad_clip_norm_components is not None: + clip_grad_norm_(component_params, config.grad_clip_norm_components) + if config.grad_clip_norm_ci_fns is not None: + clip_grad_norm_(ci_fn_params, config.grad_clip_norm_ci_fns) + optimizer.step() + + if is_main_process(): + logger.info("Finished training loop.") diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 04cb6727c..55efbb3d9 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -593,6 +593,7 @@ def create_view_and_report( _n_try_wandb_comm_errors = 0 +# this exists to stop infra issues from crashing training runs def try_wandb[**P, T](wandb_fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T | None: """Attempts to call `wandb_fn` and if it fails with a wandb CommError, logs a warning and returns None. The choice of wandb CommError is to catch issues communicating with the wandb server but diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 7a402ea54..992edd6ab 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -22,7 +22,8 @@ from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, RunState, StateManager from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig -from spd.models.component_model import ComponentModel, make_run_batch_lm, recon_loss_kl +from spd.experiments.lm.loaders import LogitsOnlyWrapper +from spd.models.component_model import ComponentModel from spd.utils.module_utils import expand_module_patterns DEVICE = "cpu" @@ -112,14 +113,14 @@ def app_with_state(): ), ) module_path_info = expand_module_patterns(target_model, config.module_info) + # Wrap the model to extract logits from (logits, loss) tuple outputs + wrapped_model = LogitsOnlyWrapper(target_model) model = ComponentModel( - target_model=target_model, + target_model=wrapped_model, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), - reconstruction_loss=recon_loss_kl, ) model.eval() sources_by_target = get_sources_by_target( diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index b10e2a405..ae97199ff 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -7,14 +7,10 @@ from jaxtyping import Float from torch import Tensor -from spd.models.component_model import ComponentModel, recon_loss_mse +from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo -def _test_run_batch(target_model: nn.Module, batch: Tensor) -> Tensor: - return target_model(batch) - - class OneLayerLinearModel(nn.Module): """One-layer linear model for testing.""" @@ -63,8 +59,6 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo ci_fn_hidden_dims=[2], ci_fn_type="mlp", sigmoid_type="leaky_hard", - run_batch=_test_run_batch, - reconstruction_loss=recon_loss_mse, ) return comp_model @@ -101,8 +95,6 @@ def make_two_layer_component_model( ci_fn_hidden_dims=[2], ci_fn_type="mlp", sigmoid_type="leaky_hard", - run_batch=_test_run_batch, - reconstruction_loss=recon_loss_mse, ) return comp_model diff --git a/tests/metrics/test_ci_masked_recon_layerwise_loss.py b/tests/metrics/test_ci_masked_recon_layerwise_loss.py index 045b8aca0..04ef2609e 100644 --- a/tests/metrics/test_ci_masked_recon_layerwise_loss.py +++ b/tests/metrics/test_ci_masked_recon_layerwise_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_layerwise_loss, ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -43,7 +44,11 @@ def test_two_layer_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_layerwise_loss( - model=model, batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -59,9 +64,19 @@ def test_layerwise_vs_all_layer(self: object) -> None: target_out = torch.randn(1, 2, dtype=torch.float32) ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} - loss_all = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) + loss_all = ci_masked_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, + ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same diff --git a/tests/metrics/test_ci_masked_recon_loss.py b/tests/metrics/test_ci_masked_recon_loss.py index 63a94d50e..7635d2757 100644 --- a/tests/metrics/test_ci_masked_recon_loss.py +++ b/tests/metrics/test_ci_masked_recon_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -25,7 +26,13 @@ def test_manual_calculation(self: object) -> None: expected_loss = torch.nn.functional.mse_loss(out, target_out) # Calculate actual loss - actual_loss = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) + actual_loss = ci_masked_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, + ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( f"Expected {expected_loss}, got {actual_loss}" @@ -43,10 +50,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, batch=batch, target_out=target_out, ci=ci_full + model=model, + batch=batch, + target_out=target_out, + ci=ci_full, + reconstruction_loss=recon_loss_mse, ) loss_half = ci_masked_recon_loss( - model=model, batch=batch, target_out=target_out, ci=ci_half + model=model, + batch=batch, + target_out=target_out, + ci=ci_half, + reconstruction_loss=recon_loss_mse, ) # Different CI values should produce different losses diff --git a/tests/metrics/test_ci_masked_recon_subset_loss.py b/tests/metrics/test_ci_masked_recon_subset_loss.py index 6adb7e836..9c5661b3f 100644 --- a/tests/metrics/test_ci_masked_recon_subset_loss.py +++ b/tests/metrics/test_ci_masked_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import UniformKSubsetRoutingConfig from spd.metrics import ci_masked_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -81,6 +82,7 @@ def mock_sample_uniform_k_subset_routing_masks( target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) actual_losses.append(actual_loss.item()) diff --git a/tests/metrics/test_stochastic_recon_layerwise_loss.py b/tests/metrics/test_stochastic_recon_layerwise_loss.py index e22a61c5f..2f4d97515 100644 --- a/tests/metrics/test_stochastic_recon_layerwise_loss.py +++ b/tests/metrics/test_stochastic_recon_layerwise_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_layerwise_loss, stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -109,6 +110,7 @@ def mock_calc_stochastic_component_mask_info( target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( @@ -133,6 +135,7 @@ def test_layerwise_vs_full_loss(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) loss_layerwise = stochastic_recon_layerwise_loss( model=model, @@ -142,6 +145,7 @@ def test_layerwise_vs_full_loss(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same diff --git a/tests/metrics/test_stochastic_recon_loss.py b/tests/metrics/test_stochastic_recon_loss.py index 1ee7b723c..e20e25f84 100644 --- a/tests/metrics/test_stochastic_recon_loss.py +++ b/tests/metrics/test_stochastic_recon_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -82,6 +83,7 @@ def mock_calc_stochastic_component_mask_info( target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( diff --git a/tests/metrics/test_stochastic_recon_subset_loss.py b/tests/metrics/test_stochastic_recon_subset_loss.py index 0d889448c..54c928c39 100644 --- a/tests/metrics/test_stochastic_recon_subset_loss.py +++ b/tests/metrics/test_stochastic_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType, UniformKSubsetRoutingConfig from spd.metrics import stochastic_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -97,6 +98,7 @@ def mock_calc_stochastic_component_mask_info( ci=ci, weight_deltas=None, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert torch.allclose(actual_loss, torch.tensor(expected_loss), rtol=1e-5), ( diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 2ee5e79f4..eee72c77c 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -20,8 +20,6 @@ from spd.models.component_model import ( ComponentModel, SPDRunInfo, - pass_batch_directly_to_model, - recon_loss_mse, ) from spd.models.components import ( ComponentsMaskInfo, @@ -94,8 +92,6 @@ def test_correct_parameters_require_grad(): ci_fn_type="mlp", ci_fn_hidden_dims=[4], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) for module_path, components in component_model.components.items(): @@ -179,17 +175,37 @@ def test_from_run_info(): ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) save_file(cm.state_dict(), comp_model_dir / "model.pth") save_file(config.model_dump(mode="json"), comp_model_dir / "final_config.yaml") cm_run_info = SPDRunInfo.from_path(comp_model_dir / "model.pth") - cm_loaded = ComponentModel.from_run_info(cm_run_info) assert config == cm_run_info.config + + # Manually reconstruct component model and load state dict + assert cm_run_info.config.pretrained_model_path is not None + loaded_target = SimpleTestModel.from_pretrained(cm_run_info.config.pretrained_model_path) + loaded_target.eval() + loaded_target.requires_grad_(False) + if cm_run_info.config.identity_module_info is not None: + insert_identity_operations_( + loaded_target, + identity_module_info=cm_run_info.config.identity_module_info, + ) + loaded_module_path_info = expand_module_patterns( + loaded_target, cm_run_info.config.all_module_info + ) + cm_loaded = ComponentModel( + target_model=loaded_target, + module_path_info=loaded_module_path_info, + ci_fn_type=cm_run_info.config.ci_fn_type, + ci_fn_hidden_dims=cm_run_info.config.ci_fn_hidden_dims, + sigmoid_type=cm_run_info.config.sigmoid_type, + ) + cm_loaded.load_state_dict(torch.load(cm_run_info.checkpoint_path)) + for k, v in cm_loaded.state_dict().items(): torch.testing.assert_close(v, cm.state_dict()[k]) @@ -287,8 +303,6 @@ def test_full_weight_delta_matches_target_behaviour(): ci_fn_type="mlp", ci_fn_hidden_dims=[4], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) token_ids = torch.randint( @@ -320,8 +334,6 @@ def test_input_cache_captures_pre_weight_input(): ci_fn_type="mlp", ci_fn_hidden_dims=[2], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) # WHEN we forward the component model with input caching @@ -356,8 +368,6 @@ def test_weight_deltas(): ci_fn_type="mlp", ci_fn_hidden_dims=[2], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) # THEN the weight deltas match the target weight @@ -392,8 +402,6 @@ def forward(self, x: Tensor) -> Tensor: ci_fn_type="mlp", ci_fn_hidden_dims=[2], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) # WHEN we set the target model weights to be UV @@ -449,8 +457,6 @@ def forward(self, x: Tensor) -> Tensor: ci_fn_type="mlp", ci_fn_hidden_dims=[2], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) # and a random input @@ -500,8 +506,6 @@ def forward(self, x: Tensor) -> Tensor: ci_fn_type="mlp", ci_fn_hidden_dims=[2], sigmoid_type="leaky_hard", - run_batch=pass_batch_directly_to_model, - reconstruction_loss=recon_loss_mse, ) # and a random input diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index 86e772bad..308249248 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -16,7 +16,7 @@ ) from spd.data import DatasetConfig, create_data_loader from spd.identity_insertion import insert_identity_operations_ -from spd.models.component_model import make_run_batch_lm, recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.general_utils import resolve_class, set_seed @@ -151,8 +151,6 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index f489cade8..aa0122004 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -17,7 +17,7 @@ from spd.experiments.ih.configs import InductionModelConfig from spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ -from spd.models.component_model import make_run_batch_lm, recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -135,8 +135,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 935312519..f31dd122b 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,7 +13,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ -from spd.models.component_model import make_run_batch_lm, recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -131,8 +131,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 6631a56ca..b83a4fe3e 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -16,14 +16,11 @@ stochastic_recon_loss, stochastic_recon_subset_loss, ) -from spd.models.component_model import ComponentModel, recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo -def _test_run_batch(target_model: nn.Module, batch: Tensor) -> Tensor: - return target_model(batch) - - class TinyLinearModel(nn.Module): def __init__(self, d_in: int, d_out: int) -> None: super().__init__() @@ -47,8 +44,6 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel ci_fn_hidden_dims=[2], ci_fn_type="mlp", sigmoid_type="leaky_hard", - run_batch=_test_run_batch, - reconstruction_loss=recon_loss_mse, ) return comp_model @@ -287,6 +282,7 @@ def test_mse_loss_basic(self: object) -> None: batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_mse, ) # Since we're using a simple identity-like weight, and CI is 1, @@ -311,6 +307,7 @@ def test_kl_loss_basic(self: object) -> None: batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -327,10 +324,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, batch=batch, target_out=target_out, ci=ci_full + model=model, + batch=batch, + target_out=target_out, + ci=ci_full, + reconstruction_loss=recon_loss_mse, ) loss_half = ci_masked_recon_loss( - model=model, batch=batch, target_out=target_out, ci=ci_half + model=model, + batch=batch, + target_out=target_out, + ci=ci_half, + reconstruction_loss=recon_loss_mse, ) # Different CI values should produce different losses @@ -352,6 +357,7 @@ def test_layerwise_basic(self: object) -> None: batch=batch, target_out=target_out, ci=ci, + reconstruction_loss=recon_loss_mse, ) # Layerwise should produce a valid loss @@ -367,9 +373,19 @@ def test_layerwise_vs_all_layer(self: object) -> None: target_out = torch.tensor([[1.0, 2.0]], dtype=torch.float32) ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} - loss_all = ci_masked_recon_loss(model=model, batch=batch, target_out=target_out, ci=ci) + loss_all = ci_masked_recon_loss( + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, + ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, batch=batch, target_out=target_out, ci=ci + model=model, + batch=batch, + target_out=target_out, + ci=ci, + reconstruction_loss=recon_loss_mse, ) # For single layer, results should be the same @@ -392,6 +408,7 @@ def test_subset_basic(self: object) -> None: target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) # Subset routing should produce a valid loss @@ -414,6 +431,7 @@ def test_subset_stochastic_behavior(self: object) -> None: target_out=target_out, ci=ci, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) for _ in range(3) ] @@ -441,6 +459,7 @@ def test_continuous_sampling_basic(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -463,6 +482,7 @@ def test_binomial_sampling_basic(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -487,6 +507,7 @@ def test_multiple_mask_samples(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -508,6 +529,7 @@ def test_with_and_without_delta_component(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) loss_without_delta = stochastic_recon_loss( @@ -518,6 +540,7 @@ def test_with_and_without_delta_component(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=None, + reconstruction_loss=recon_loss_mse, ) # Both should be valid @@ -544,6 +567,7 @@ def test_layerwise_stochastic_basic(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -567,6 +591,7 @@ def test_layerwise_multiple_samples(self: object) -> None: target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -591,6 +616,7 @@ def test_subset_stochastic_basic(self: object) -> None: ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -614,6 +640,7 @@ def test_subset_with_binomial_sampling(self: object) -> None: ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) assert result >= 0.0 @@ -638,6 +665,7 @@ def test_subset_stochastic_variability(self: object) -> None: ci=ci, weight_deltas=weight_deltas, routing=UniformKSubsetRoutingConfig(), + reconstruction_loss=recon_loss_mse, ) for _ in range(3) ] diff --git a/tests/test_tms.py b/tests/test_tms.py index f695bf7d4..9b19cb0c2 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -18,7 +18,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ -from spd.models.component_model import make_run_batch_lm, recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -139,8 +139,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, - run_batch=make_run_batch_lm(config.pretrained_model_output_attr), reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights, diff --git a/tests/test_wandb_run_loading.py b/tests/test_wandb_run_loading.py index 463ed1247..be150d17e 100644 --- a/tests/test_wandb_run_loading.py +++ b/tests/test_wandb_run_loading.py @@ -9,22 +9,31 @@ import pytest +from spd.experiments.lm.loaders import load_lm_component_model_from_run_info +from spd.experiments.resid_mlp.models import load_resid_mlp_component_model_from_run_info +from spd.experiments.tms.models import load_tms_component_model_from_run_info from spd.models.component_model import ComponentModel, SPDRunInfo from spd.registry import EXPERIMENT_REGISTRY from spd.utils.wandb_utils import parse_wandb_run_path -def from_run_info(canonical_run: str) -> ComponentModel[Any, Any]: +def load_component_model(canonical_run: str, task_name: str) -> ComponentModel[Any, Any]: + """Load a ComponentModel using the appropriate experiment-specific loader.""" run_info = SPDRunInfo.from_path(canonical_run) - return ComponentModel.from_run_info(run_info) + loaders: dict[str, Any] = { + "tms": load_tms_component_model_from_run_info, + "resid_mlp": load_resid_mlp_component_model_from_run_info, + "lm": load_lm_component_model_from_run_info, + } -def from_pretrained(canonical_run: str) -> ComponentModel[Any, Any]: - return ComponentModel.from_pretrained(canonical_run) + loader = loaders.get(task_name) + assert loader is not None, f"No loader found for task_name: {task_name}" + return loader(run_info) CANONICAL_EXPS = [ - (exp_name, exp_config.canonical_run) + (exp_name, exp_config.canonical_run, exp_config.task_name) for exp_name, exp_config in EXPERIMENT_REGISTRY.items() if exp_config.canonical_run is not None ] @@ -32,19 +41,12 @@ def from_pretrained(canonical_run: str) -> ComponentModel[Any, Any]: @pytest.mark.requires_wandb @pytest.mark.slow -@pytest.mark.parametrize("exp_name, canonical_run", CANONICAL_EXPS) -def test_loading_from_wandb(exp_name: str, canonical_run: str) -> None: - # We put both from_run_info and from_pretrained in the same test to avoid distributed read - # errors from the same wandb cache +@pytest.mark.parametrize("exp_name, canonical_run, task_name", CANONICAL_EXPS) +def test_loading_from_wandb(exp_name: str, canonical_run: str, task_name: str) -> None: try: - from_run_info(canonical_run) + load_component_model(canonical_run, task_name) except Exception as e: - e.add_note(f"Error with from_run_info for {exp_name} from {canonical_run}") - raise e - try: - from_pretrained(canonical_run) - except Exception as e: - e.add_note(f"Error with from_pretrained for {exp_name} from {canonical_run}") + e.add_note(f"Error loading {exp_name} from {canonical_run} (task: {task_name})") raise e From e7125735dbc277b905eb0f071f911111a621f042 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Mon, 2 Feb 2026 18:59:09 +0000 Subject: [PATCH 03/16] >>>>>>> dev/app --- spd/app/backend/compute.py | 4 ++-- spd/app/backend/routers/dataset_attributions.py | 2 +- spd/dataset_attributions/harvester.py | 4 ++-- spd/experiments/lm/loaders.py | 10 +++++++--- spd/models/component_model.py | 6 +----- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 3b47d6203..7061c5839 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -166,7 +166,7 @@ def wte_hook( wte_cache["wte_post_detach"] = output return output - wte = getattr(model.target_model, "wte") + wte = model.target_model.wte assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) @@ -343,7 +343,7 @@ def compute_edges_from_ci( # Setup wte hook and run forward pass for gradient computation wte_hook, wte_cache = _setup_wte_hook() - wte = getattr(model.target_model, "wte") + wte = model.target_model.wte assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 06bc55725..fa38f5146 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -85,7 +85,7 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: """Get the unembedding matrix from the loaded model.""" - lm_head = getattr(loaded.model.target_model, "lm_head") + lm_head = loaded.model.target_model.lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" return lm_head.weight.T.detach() diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 1b86cc300..54915ce8f 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -74,7 +74,7 @@ def __init__( # For output targets: store attributions to output residual dimensions assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = getattr(model.target_model, "lm_head") + lm_head = model.target_model.lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" self.d_model = lm_head.in_features self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) @@ -143,7 +143,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = getattr(self.model.target_model, "wte") + wte = self.model.target_model.wte assert isinstance(wte, nn.Module) h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) diff --git a/spd/experiments/lm/loaders.py b/spd/experiments/lm/loaders.py index 242614a8f..8e9fe11c9 100644 --- a/spd/experiments/lm/loaders.py +++ b/spd/experiments/lm/loaders.py @@ -44,14 +44,18 @@ def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[tuple[str, Parameter]]: # Delegate to wrapped model so parameter names don't have "model." prefix - return self.model.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + return self.model.named_parameters( + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate + ) @override def named_modules( self, memo: set[nn.Module] | None = None, prefix: str = "", remove_duplicate: bool = True - ) -> Generator[tuple[str, nn.Module], None, None]: + ) -> Generator[tuple[str, nn.Module]]: # Delegate to wrapped model so module names don't have "model." prefix - yield from self.model.named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate) + yield from self.model.named_modules( + memo=memo, prefix=prefix, remove_duplicate=remove_duplicate + ) @override def __getattr__(self, name: str) -> Any: diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 84c9c5a9e..194d37286 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -52,16 +52,12 @@ class CIOutputs: class TargetModel[BatchT, OutputT](Protocol): - # def __call__(self, batch: BatchT) -> OutputT: ... - def __call__(self, batch: BatchT) -> OutputT: ... + # stubs of pytorch methods for the type checker. you almost certainly don't actually need to implement these. def get_submodule(self, target: str) -> nn.Module: ... - def named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]: ... - # def named_modules(self) -> Generator[tuple[str, nn.Module]]: ... - class ComponentModel[BatchT, OutputT](nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. From 2312fe6786c973fe55b132a8e04e7e6c96e5ab93 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Mon, 2 Feb 2026 19:00:08 +0000 Subject: [PATCH 04/16] Revert ">>>>>>> dev/app" This reverts commit e7125735dbc277b905eb0f071f911111a621f042. --- spd/app/backend/compute.py | 4 ++-- spd/app/backend/routers/dataset_attributions.py | 2 +- spd/dataset_attributions/harvester.py | 4 ++-- spd/experiments/lm/loaders.py | 10 +++------- spd/models/component_model.py | 6 +++++- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 7061c5839..3b47d6203 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -166,7 +166,7 @@ def wte_hook( wte_cache["wte_post_detach"] = output return output - wte = model.target_model.wte + wte = getattr(model.target_model, "wte") assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) @@ -343,7 +343,7 @@ def compute_edges_from_ci( # Setup wte hook and run forward pass for gradient computation wte_hook, wte_cache = _setup_wte_hook() - wte = model.target_model.wte + wte = getattr(model.target_model, "wte") assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index fa38f5146..06bc55725 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -85,7 +85,7 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: """Get the unembedding matrix from the loaded model.""" - lm_head = loaded.model.target_model.lm_head + lm_head = getattr(loaded.model.target_model, "lm_head") assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" return lm_head.weight.T.detach() diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 54915ce8f..1b86cc300 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -74,7 +74,7 @@ def __init__( # For output targets: store attributions to output residual dimensions assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = model.target_model.lm_head + lm_head = getattr(model.target_model, "lm_head") assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" self.d_model = lm_head.in_features self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) @@ -143,7 +143,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = self.model.target_model.wte + wte = getattr(self.model.target_model, "wte") assert isinstance(wte, nn.Module) h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) diff --git a/spd/experiments/lm/loaders.py b/spd/experiments/lm/loaders.py index 8e9fe11c9..242614a8f 100644 --- a/spd/experiments/lm/loaders.py +++ b/spd/experiments/lm/loaders.py @@ -44,18 +44,14 @@ def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[tuple[str, Parameter]]: # Delegate to wrapped model so parameter names don't have "model." prefix - return self.model.named_parameters( - prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate - ) + return self.model.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) @override def named_modules( self, memo: set[nn.Module] | None = None, prefix: str = "", remove_duplicate: bool = True - ) -> Generator[tuple[str, nn.Module]]: + ) -> Generator[tuple[str, nn.Module], None, None]: # Delegate to wrapped model so module names don't have "model." prefix - yield from self.model.named_modules( - memo=memo, prefix=prefix, remove_duplicate=remove_duplicate - ) + yield from self.model.named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate) @override def __getattr__(self, name: str) -> Any: diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 194d37286..84c9c5a9e 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -52,12 +52,16 @@ class CIOutputs: class TargetModel[BatchT, OutputT](Protocol): + # def __call__(self, batch: BatchT) -> OutputT: ... + def __call__(self, batch: BatchT) -> OutputT: ... - # stubs of pytorch methods for the type checker. you almost certainly don't actually need to implement these. def get_submodule(self, target: str) -> nn.Module: ... + def named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]: ... + # def named_modules(self) -> Generator[tuple[str, nn.Module]]: ... + class ComponentModel[BatchT, OutputT](nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. From 357898a44c750872067f8814b62d7cbb97ef0afc Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Mon, 2 Feb 2026 19:19:13 +0000 Subject: [PATCH 05/16] wip: Replace getattr with cast for type safety on model attributes --- spd/app/backend/compute.py | 4 ++-- spd/app/backend/routers/dataset_attributions.py | 4 ++-- spd/dataset_attributions/harvester.py | 6 +++--- spd/experiments/lm/loaders.py | 10 +++++++--- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 3b47d6203..f8f9f2814 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -166,7 +166,7 @@ def wte_hook( wte_cache["wte_post_detach"] = output return output - wte = getattr(model.target_model, "wte") + wte = cast(Any, model.target_model).wte assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) @@ -343,7 +343,7 @@ def compute_edges_from_ci( # Setup wte hook and run forward pass for gradient computation wte_hook, wte_cache = _setup_wte_hook() - wte = getattr(model.target_model, "wte") + wte = cast(Any, model.target_model).wte assert isinstance(wte, nn.Module), "wte is not a module" wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index 06bc55725..e32f2e372 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -4,7 +4,7 @@ over the full training dataset. """ -from typing import Annotated, Literal +from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, HTTPException, Query from jaxtyping import Float @@ -85,7 +85,7 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: """Get the unembedding matrix from the loaded model.""" - lm_head = getattr(loaded.model.target_model, "lm_head") + lm_head = cast(Any, loaded.model.target_model).lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" return lm_head.weight.T.detach() diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 1b86cc300..c6126ef17 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -10,7 +10,7 @@ Output attributions computed on-the-fly at query time via w_unembed """ -from typing import Any +from typing import Any, cast import torch from jaxtyping import Bool, Float, Int @@ -74,7 +74,7 @@ def __init__( # For output targets: store attributions to output residual dimensions assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = getattr(model.target_model, "lm_head") + lm_head = cast(Any, model.target_model).lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" self.d_model = lm_head.in_features self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) @@ -143,7 +143,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = getattr(self.model.target_model, "wte") + wte = cast(Any, self.model.target_model).wte assert isinstance(wte, nn.Module) h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) diff --git a/spd/experiments/lm/loaders.py b/spd/experiments/lm/loaders.py index 242614a8f..8e9fe11c9 100644 --- a/spd/experiments/lm/loaders.py +++ b/spd/experiments/lm/loaders.py @@ -44,14 +44,18 @@ def named_parameters( self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True ) -> Iterator[tuple[str, Parameter]]: # Delegate to wrapped model so parameter names don't have "model." prefix - return self.model.named_parameters(prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate) + return self.model.named_parameters( + prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate + ) @override def named_modules( self, memo: set[nn.Module] | None = None, prefix: str = "", remove_duplicate: bool = True - ) -> Generator[tuple[str, nn.Module], None, None]: + ) -> Generator[tuple[str, nn.Module]]: # Delegate to wrapped model so module names don't have "model." prefix - yield from self.model.named_modules(memo=memo, prefix=prefix, remove_duplicate=remove_duplicate) + yield from self.model.named_modules( + memo=memo, prefix=prefix, remove_duplicate=remove_duplicate + ) @override def __getattr__(self, name: str) -> Any: From 8fdfa81ff38caeafa7a97c5d344351668d0320fd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 6 Feb 2026 19:28:41 +0000 Subject: [PATCH 06/16] Remove accidentally added files --- .claude/.nfs2f6abdf93653d08500002cba | 22 ---------------------- .claude/.nfs582d9ab79662f72a00003620 | 23 ----------------------- 2 files changed, 45 deletions(-) delete mode 100644 .claude/.nfs2f6abdf93653d08500002cba delete mode 100644 .claude/.nfs582d9ab79662f72a00003620 diff --git a/.claude/.nfs2f6abdf93653d08500002cba b/.claude/.nfs2f6abdf93653d08500002cba deleted file mode 100644 index dbd210f9b..000000000 --- a/.claude/.nfs2f6abdf93653d08500002cba +++ /dev/null @@ -1,22 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(source:*)", - "Bash(npm run check:*)", - "Bash(make check-app:*)", - "Bash(npm run lint:*)", - "Bash(git stash push:*)", - "Bash(grep:*)", - "Bash(npm run format:*)", - "Bash(npm run build:*)", - "Bash(npx eslint:*)", - "Bash(npx prettier:*)", - "Bash(git add:*)", - "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nAdd \"Use as Prompt\" popup for selected text in dataset explorer\n\n- Select text within story content to show floating popup\n- \"Use as Prompt\" button creates a custom prompt from selection\n- Text is cleaned: newlines → spaces, whitespace collapsed, trimmed\n- Shows hint when no run is loaded\n- Only triggers on .story-text elements \\(not headers/tags\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", - "Bash(git revert:*)", - "Bash(python:*)", - "Bash(make:*)", - "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nOptimize random sampling and hide zero occurrence badges\n\n- Use random indices instead of shuffling entire dataset \\(~100x faster\\)\n- Hide occurrence badge when count is 0 \\(for random samples\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")" - ] - } -} diff --git a/.claude/.nfs582d9ab79662f72a00003620 b/.claude/.nfs582d9ab79662f72a00003620 deleted file mode 100644 index d67075d64..000000000 --- a/.claude/.nfs582d9ab79662f72a00003620 +++ /dev/null @@ -1,23 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(source:*)", - "Bash(npm run check:*)", - "Bash(make check-app:*)", - "Bash(npm run lint:*)", - "Bash(git stash push:*)", - "Bash(grep:*)", - "Bash(npm run format:*)", - "Bash(npm run build:*)", - "Bash(npx eslint:*)", - "Bash(npx prettier:*)", - "Bash(git add:*)", - "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nAdd \"Use as Prompt\" popup for selected text in dataset explorer\n\n- Select text within story content to show floating popup\n- \"Use as Prompt\" button creates a custom prompt from selection\n- Text is cleaned: newlines → spaces, whitespace collapsed, trimmed\n- Shows hint when no run is loaded\n- Only triggers on .story-text elements \\(not headers/tags\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", - "Bash(git revert:*)", - "Bash(python:*)", - "Bash(make:*)", - "Bash(SKIP=type git commit -m \"$\\(cat <<''EOF''\nOptimize random sampling and hide zero occurrence badges\n\n- Use random indices instead of shuffling entire dataset \\(~100x faster\\)\n- Hide occurrence badge when count is 0 \\(for random samples\\)\n\nCo-Authored-By: Claude Opus 4.5 \nEOF\n\\)\")", - "Bash(git commit:*)" - ] - } -} From 539edb2a29b8500d29fe210cd1e2b0a72b9e261f Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Fri, 6 Feb 2026 21:07:06 +0000 Subject: [PATCH 07/16] Remove LogitsOnlyWrapper and add back ComponentModel.from_run_info --- spd/app/backend/compute.py | 6 - spd/app/backend/routers/runs.py | 5 +- spd/autointerp/interpret.py | 5 +- spd/clustering/dataset.py | 6 +- spd/clustering/scripts/run_clustering.py | 5 +- spd/data.py | 9 +- spd/dataset_attributions/harvest.py | 3 +- spd/experiments/ih/model.py | 7 +- spd/experiments/lm/lm_decomposition.py | 4 +- spd/experiments/lm/loaders.py | 118 -------- spd/experiments/resid_mlp/models.py | 51 +--- spd/experiments/resid_mlp/resid_mlp_interp.py | 7 +- spd/experiments/tms/models.py | 48 ---- spd/experiments/tms/plotting.py | 5 +- spd/harvest/harvest.py | 3 +- spd/metrics/identity_ci_error.py | 5 +- spd/metrics/permuted_ci_plots.py | 5 +- spd/metrics/uv_plots.py | 5 +- spd/models/component_model.py | 72 ++++- spd/scripts/compare_models/compare_models.py | 3 +- spd/simple_trainer.py | 251 ------------------ tests/app/test_server_api.py | 5 +- tests/test_distributed.py | 9 +- tests/test_gpt2.py | 4 +- tests/test_wandb_run_loading.py | 32 +-- 25 files changed, 127 insertions(+), 546 deletions(-) delete mode 100644 spd/experiments/lm/loaders.py delete mode 100644 spd/simple_trainer.py diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index f8f9f2814..11a0637f3 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -770,12 +770,6 @@ def get_model_n_blocks(model: nn.Module) -> int: from simple_stories_train.models.llama_simple_mlp import LlamaSimpleMLP from transformers.models.gpt2 import GPT2LMHeadModel - from spd.experiments.lm.loaders import LogitsOnlyWrapper - - # Unwrap LogitsOnlyWrapper if present - if isinstance(model, LogitsOnlyWrapper): - model = model.model - match model: case GPT2LMHeadModel(): return len(model.transformer.h) diff --git a/spd/app/backend/routers/runs.py b/spd/app/backend/routers/runs.py index a7ce9a8b3..c62df6688 100644 --- a/spd/app/backend/routers/runs.py +++ b/spd/app/backend/routers/runs.py @@ -14,9 +14,8 @@ from spd.app.backend.dependencies import DepStateManager from spd.app.backend.state import HarvestCache, RunState from spd.app.backend.utils import build_token_lookup, log_errors -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger -from spd.models.component_model import SPDRunInfo +from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device from spd.utils.wandb_utils import parse_wandb_run_path @@ -93,7 +92,7 @@ def load_run(wandb_path: str, context_length: int, manager: DepStateManager): # Load the model logger.info(f"[API] Loading model for run {run.id}: {run.wandb_path}") - model = load_lm_component_model_from_run_info(run_info) + model = ComponentModel.from_run_info(run_info) model = model.to(DEVICE) model.eval() diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index 571ba3200..cbf892ad7 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -28,13 +28,12 @@ from spd.autointerp.prompt_template import INTERPRETATION_SCHEMA, format_prompt_template from spd.autointerp.schemas import ArchitectureInfo, InterpretationResult from spd.configs import LMTaskConfig -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.analysis import TokenPRLift, get_input_token_stats, get_output_token_stats from spd.harvest.harvest import HarvestResult from spd.harvest.schemas import ComponentData from spd.harvest.storage import TokenStatsStorage from spd.log import logger -from spd.models.component_model import SPDRunInfo +from spd.models.component_model import ComponentModel, SPDRunInfo # Retry config MAX_RETRIES = 8 @@ -337,7 +336,7 @@ async def process_one( def get_architecture_info(wandb_path: str) -> ArchitectureInfo: run_info = SPDRunInfo.from_path(wandb_path) - model = load_lm_component_model_from_run_info(run_info) + model = ComponentModel.from_run_info(run_info) n_blocks = get_model_n_blocks(cast(nn.Module, model.target_model)) config = run_info.config task_config = config.task_config diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index 85da48838..ae745a0c9 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -8,8 +8,8 @@ from spd.clustering.consts import BatchTensor from spd.configs import LMTaskConfig, ResidMLPTaskConfig from spd.data import DatasetConfig, create_data_loader -from spd.experiments.resid_mlp.models import ResidMLP, load_resid_mlp_component_model_from_run_info -from spd.models.component_model import SPDRunInfo +from spd.experiments.resid_mlp.models import ResidMLP +from spd.models.component_model import ComponentModel, SPDRunInfo from spd.spd_types import TaskName @@ -106,7 +106,7 @@ def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchT spd_run = SPDRunInfo.from_path(model_path) cfg = spd_run.config - component_model = load_resid_mlp_component_model_from_run_info(spd_run) + component_model = ComponentModel.from_run_info(spd_run) assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index e24c50827..f76e64d5a 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -49,9 +49,8 @@ from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase from spd.clustering.wandb_tensor_info import wandb_log_tensor -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger -from spd.models.component_model import SPDRunInfo +from spd.models.component_model import ComponentModel, SPDRunInfo from spd.spd_types import TaskName from spd.utils.distributed_utils import get_device from spd.utils.general_utils import replace_pydantic_model @@ -299,7 +298,7 @@ def main(run_config: ClusteringRunConfig) -> Path: # 3. Load model logger.info("Loading model") - model = load_lm_component_model_from_run_info(spd_run).to(device) + model = ComponentModel.from_run_info(spd_run).to(device) # 4. Compute activations logger.info("Computing activations") diff --git a/spd/data.py b/spd/data.py index 2519f9fda..0b694c1a3 100644 --- a/spd/data.py +++ b/spd/data.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Callable, Generator from typing import Any import numpy as np @@ -153,6 +153,7 @@ def create_data_loader( dist_state: DistributedState | None = None, global_seed: int = 0, to_lower: bool = True, + collate_fn: Callable[..., Any] | None = None, ) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]: """Create a DataLoader for the given dataset. @@ -263,10 +264,16 @@ def create_data_loader( ), drop_last=True, generator=generator, + collate_fn=collate_fn, ) return loader, tokenizer +def lm_collate_fn(batch: list[dict[str, Tensor]]) -> Tensor: + """Collate function that extracts input_ids tensors from HuggingFace dataset dicts.""" + return torch.stack([item["input_ids"] for item in batch]) + + def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]: """Loop over a dataloader, resetting the iterator when it is exhausted. diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index a16be5897..2068be277 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -26,7 +26,6 @@ from spd.dataset_attributions.harvester import AttributionHarvester from spd.dataset_attributions.loaders import get_attributions_dir from spd.dataset_attributions.storage import DatasetAttributionStorage -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.loaders import load_activation_contexts_summary from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo @@ -140,7 +139,7 @@ def harvest_attributions( _, _, run_id = parse_wandb_run_path(config.wandb_path) run_info = SPDRunInfo.from_path(config.wandb_path) - model = load_lm_component_model_from_run_info(run_info).to(device) + model = ComponentModel.from_run_info(run_info).to(device) model.eval() spd_config = run_info.config diff --git a/spd/experiments/ih/model.py b/spd/experiments/ih/model.py index 84babad42..143f4eefe 100644 --- a/spd/experiments/ih/model.py +++ b/spd/experiments/ih/model.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import override +from typing import Any, override import torch from jaxtyping import Float @@ -210,7 +210,10 @@ def __init__(self, cfg: InductionModelConfig): self.unembed = nn.Linear(cfg.d_model, adjusted_vocab_size, bias=False) @override - def forward(self, tokens: Float[Tensor, "B S"], **_): + def forward( + self, batch: tuple[Float[Tensor, "B S"], ...] | Float[Tensor, "B S"], **_: Any + ) -> Float[Tensor, "B S V"]: + tokens = batch[0] if isinstance(batch, tuple) else batch x = self.token_embed(tokens) for block in self.blocks: diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index fc2d98886..6172cffd2 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -8,7 +8,7 @@ from simple_stories_train.run_info import RunInfo as SSRunInfo from spd.configs import Config, LMTaskConfig -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.log import logger from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize @@ -141,6 +141,7 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed, dist_state=dist_state, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -170,6 +171,7 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, dist_state=dist_state, + collate_fn=lm_collate_fn, ) if is_main_process(): diff --git a/spd/experiments/lm/loaders.py b/spd/experiments/lm/loaders.py deleted file mode 100644 index 8e9fe11c9..000000000 --- a/spd/experiments/lm/loaders.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Loaders for LM ComponentModels.""" - -from collections.abc import Generator, Iterator -from typing import Any, override - -import torch -from torch import Tensor, nn -from torch.nn import Parameter - -from spd.configs import Config -from spd.identity_insertion import insert_identity_operations_ -from spd.interfaces import LoadableModule, RunInfo -from spd.models.component_model import ( - ComponentModel, - SPDRunInfo, - handle_deprecated_state_dict_keys_, -) -from spd.spd_types import ModelPath -from spd.utils.general_utils import resolve_class -from spd.utils.module_utils import expand_module_patterns - - -class LogitsOnlyWrapper(nn.Module): - """Wrapper that extracts logits from models that return (logits, loss) tuples.""" - - def __init__(self, model: nn.Module): - super().__init__() - self.model = model - - @override - def forward(self, *args: Any, **kwargs: Any) -> Tensor: - out = self.model(*args, **kwargs) - if isinstance(out, tuple): - return out[0] - return out - - @override - def get_submodule(self, target: str) -> nn.Module: - # Delegate to wrapped model so paths don't need "model." prefix - return self.model.get_submodule(target) - - @override - def named_parameters( - self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True - ) -> Iterator[tuple[str, Parameter]]: - # Delegate to wrapped model so parameter names don't have "model." prefix - return self.model.named_parameters( - prefix=prefix, recurse=recurse, remove_duplicate=remove_duplicate - ) - - @override - def named_modules( - self, memo: set[nn.Module] | None = None, prefix: str = "", remove_duplicate: bool = True - ) -> Generator[tuple[str, nn.Module]]: - # Delegate to wrapped model so module names don't have "model." prefix - yield from self.model.named_modules( - memo=memo, prefix=prefix, remove_duplicate=remove_duplicate - ) - - @override - def __getattr__(self, name: str) -> Any: - # Delegate attribute access to the wrapped model for things like wte, lm_head, etc. - if name == "model": - return super().__getattr__(name) - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.model, name) - - -def load_lm_component_model_from_run_info( - run_info: RunInfo[Config], -) -> ComponentModel[Tensor, Tensor]: - """Load a trained LM ComponentModel from a run info object.""" - config = run_info.config - - model_class = resolve_class(config.pretrained_model_class) - if config.pretrained_model_name is not None: - assert hasattr(model_class, "from_pretrained") - target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] - else: - assert issubclass(model_class, LoadableModule) - assert config.pretrained_model_path is not None - target_model = model_class.from_pretrained(config.pretrained_model_path) - - target_model.eval() - target_model.requires_grad_(False) - - if config.identity_module_info is not None: - insert_identity_operations_( - target_model, - identity_module_info=config.identity_module_info, - ) - - # Wrap the model to extract logits from (logits, loss) tuple outputs - wrapped_model = LogitsOnlyWrapper(target_model) - - module_path_info = expand_module_patterns(target_model, config.all_module_info) - - comp_model: ComponentModel[Tensor, Tensor] = ComponentModel( - target_model=wrapped_model, - module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - sigmoid_type=config.sigmoid_type, - ) - - comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) - handle_deprecated_state_dict_keys_(comp_model_weights) - comp_model.load_state_dict(comp_model_weights) - - return comp_model - - -def load_lm_component_model(path: ModelPath) -> ComponentModel[Tensor, Tensor]: - """Load a trained LM ComponentModel from a wandb or local path.""" - run_info = SPDRunInfo.from_path(path) - return load_lm_component_model_from_run_info(run_info) diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index ce68d3409..df0605a1b 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -10,20 +10,13 @@ from jaxtyping import Float from torch import Tensor, nn -from spd.configs import Config from spd.experiments.resid_mlp.configs import ( ResidMLPModelConfig, ResidMLPTrainConfig, ) -from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo -from spd.models.component_model import ( - ComponentModel, - SPDRunInfo, - handle_deprecated_state_dict_keys_, -) from spd.spd_types import ModelPath -from spd.utils.module_utils import expand_module_patterns, init_param_ +from spd.utils.module_utils import init_param_ ResidMLPBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] ResidMLPOutput = Float[Tensor, "... n_features"] @@ -132,45 +125,3 @@ def from_pretrained(cls, path: ModelPath) -> "ResidMLP": """Fetch a pretrained model from wandb or a local path to a checkpoint.""" run_info = ResidMLPTargetRunInfo.from_path(path) return cls.from_run_info(run_info) - - -def load_resid_mlp_component_model_from_run_info( - run_info: RunInfo[Config], -) -> ComponentModel[ResidMLPBatch, ResidMLPOutput]: - """Load a trained ResidMLP ComponentModel from a run info object.""" - config = run_info.config - assert config.pretrained_model_path is not None - - target_model = ResidMLP.from_pretrained(config.pretrained_model_path) - target_model.eval() - target_model.requires_grad_(False) - - if config.identity_module_info is not None: - insert_identity_operations_( - target_model, - identity_module_info=config.identity_module_info, - ) - - module_path_info = expand_module_patterns(target_model, config.all_module_info) - - comp_model: ComponentModel[ResidMLPBatch, ResidMLPOutput] = ComponentModel( - target_model=target_model, - module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - sigmoid_type=config.sigmoid_type, - ) - - comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) - handle_deprecated_state_dict_keys_(comp_model_weights) - comp_model.load_state_dict(comp_model_weights) - - return comp_model - - -def load_resid_mlp_component_model( - path: ModelPath, -) -> ComponentModel[ResidMLPBatch, ResidMLPOutput]: - """Load a trained ResidMLP ComponentModel from a wandb or local path.""" - run_info = SPDRunInfo.from_path(path) - return load_resid_mlp_component_model_from_run_info(run_info) diff --git a/spd/experiments/resid_mlp/resid_mlp_interp.py b/spd/experiments/resid_mlp/resid_mlp_interp.py index 93a8d7618..506989176 100644 --- a/spd/experiments/resid_mlp/resid_mlp_interp.py +++ b/spd/experiments/resid_mlp/resid_mlp_interp.py @@ -12,11 +12,10 @@ from spd.experiments.resid_mlp.models import ( MLP, ResidMLP, - load_resid_mlp_component_model_from_run_info, ) from spd.experiments.tms.models import TMSModel from spd.log import logger -from spd.models.component_model import SPDRunInfo +from spd.models.component_model import ComponentModel, SPDRunInfo from spd.models.components import Components from spd.plotting import plot_causal_importance_vals from spd.registry import EXPERIMENT_REGISTRY @@ -39,7 +38,7 @@ def extract_ci_val_figures( Dictionary containing causal importances data and metadata """ run_info = SPDRunInfo.from_path(run_id) - model = load_resid_mlp_component_model_from_run_info(run_info) + model = ComponentModel.from_run_info(run_info) model.to(device) config = run_info.config @@ -482,7 +481,7 @@ def main(out_dir: Path, device: str): wandb_id = path.split("/")[-1] run_info = SPDRunInfo.from_path(path) - model = load_resid_mlp_component_model_from_run_info(run_info) + model = ComponentModel.from_run_info(run_info) config = run_info.config assert isinstance(model.target_model, ResidMLP) model.target_model.to(device) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 93e0187d5..95327b51d 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -6,17 +6,9 @@ from torch import Tensor, nn from torch.nn import functional as F -from spd.configs import Config from spd.experiments.tms.configs import TMSModelConfig, TMSTrainConfig -from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo -from spd.models.component_model import ( - ComponentModel, - SPDRunInfo, - handle_deprecated_state_dict_keys_, -) from spd.spd_types import ModelPath -from spd.utils.module_utils import expand_module_patterns TMSBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] TMSOutput = Float[Tensor, "... n_features"] @@ -92,43 +84,3 @@ def from_pretrained(cls, path: ModelPath) -> "TMSModel": """Fetch a pretrained model from wandb or a local path to a checkpoint.""" run_info = TMSTargetRunInfo.from_path(path) return cls.from_run_info(run_info) - - -def load_tms_component_model_from_run_info( - run_info: RunInfo[Config], -) -> ComponentModel[TMSBatch, TMSOutput]: - """Load a trained TMS ComponentModel from a run info object.""" - config = run_info.config - assert config.pretrained_model_path is not None - - target_model = TMSModel.from_pretrained(config.pretrained_model_path) - target_model.eval() - target_model.requires_grad_(False) - - if config.identity_module_info is not None: - insert_identity_operations_( - target_model, - identity_module_info=config.identity_module_info, - ) - - module_path_info = expand_module_patterns(target_model, config.all_module_info) - - comp_model: ComponentModel[TMSBatch, TMSOutput] = ComponentModel( - target_model=target_model, - module_path_info=module_path_info, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - ci_fn_type=config.ci_fn_type, - sigmoid_type=config.sigmoid_type, - ) - - comp_model_weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) - handle_deprecated_state_dict_keys_(comp_model_weights) - comp_model.load_state_dict(comp_model_weights) - - return comp_model - - -def load_tms_component_model(path: ModelPath) -> ComponentModel[TMSBatch, TMSOutput]: - """Load a trained TMS ComponentModel from a wandb or local path.""" - run_info = SPDRunInfo.from_path(path) - return load_tms_component_model_from_run_info(run_info) diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index ca2094810..daa7ec57b 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -20,8 +20,9 @@ from matplotlib.figure import Figure from torch import Tensor -from spd.experiments.tms.models import TMSModel, load_tms_component_model +from spd.experiments.tms.models import TMSModel from spd.log import logger +from spd.models.component_model import ComponentModel from spd.models.components import Components from spd.settings import REPO_ROOT @@ -980,7 +981,7 @@ def main(): out_dir.mkdir(parents=True, exist_ok=True) # Load models - model = load_tms_component_model(run_id) + model = ComponentModel.from_pretrained(run_id) assert isinstance(model.target_model, TMSModel) # Get custom config and name for this run diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index 8e82d7af2..a246bc5e2 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -24,7 +24,6 @@ from torch import Tensor from spd.data import train_loader_and_tokenizer -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.harvest.lib.harvester import Harvester, HarvesterState from spd.harvest.schemas import ( ActivationExample, @@ -203,7 +202,7 @@ def harvest_activation_contexts( logger.info(f"Loading model on {device}") run_info = SPDRunInfo.from_path(config.wandb_path) - model = load_lm_component_model_from_run_info(run_info).to(device) + model = ComponentModel.from_run_info(run_info).to(device) model.eval() spd_config = run_info.config diff --git a/spd/metrics/identity_ci_error.py b/spd/metrics/identity_ci_error.py index b5771d7e9..b47ab8f8d 100644 --- a/spd/metrics/identity_ci_error.py +++ b/spd/metrics/identity_ci_error.py @@ -32,9 +32,10 @@ def __init__( self.batch_shape: tuple[int, ...] | None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, float]: diff --git a/spd/metrics/permuted_ci_plots.py b/spd/metrics/permuted_ci_plots.py index f0b340b70..f81b70fb1 100644 --- a/spd/metrics/permuted_ci_plots.py +++ b/spd/metrics/permuted_ci_plots.py @@ -30,9 +30,10 @@ def __init__( self.batch_shape: tuple[int, ...] | None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, Image.Image]: diff --git a/spd/metrics/uv_plots.py b/spd/metrics/uv_plots.py index 2c1071ec3..8e5849df3 100644 --- a/spd/metrics/uv_plots.py +++ b/spd/metrics/uv_plots.py @@ -30,9 +30,10 @@ def __init__( self.batch_shape: tuple[int, ...] | None = None @override - def update(self, *, batch: Tensor, **_: Any) -> None: + def update(self, *, batch: Tensor | tuple[Tensor, ...], **_: Any) -> None: if self.batch_shape is None: - self.batch_shape = tuple(batch.shape) + input_tensor = batch[0] if isinstance(batch, tuple) else batch + self.batch_shape = tuple(input_tensor.shape) @override def compute(self) -> dict[str, Image.Image]: diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 84c9c5a9e..2c96fd4b4 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -11,7 +11,8 @@ from transformers.pytorch_utils import Conv1D as RadfordConv1D from spd.configs import Config, SamplingType -from spd.interfaces import RunInfo +from spd.identity_insertion import insert_identity_operations_ +from spd.interfaces import LoadableModule, RunInfo from spd.models.components import ( Components, ComponentsMaskInfo, @@ -23,9 +24,9 @@ VectorSharedMLPCiFn, ) from spd.models.sigmoids import SIGMOID_TYPES, SigmoidType -from spd.spd_types import CiFnType -from spd.utils.general_utils import runtime_cast -from spd.utils.module_utils import ModulePathInfo +from spd.spd_types import CiFnType, ModelPath +from spd.utils.general_utils import resolve_class, runtime_cast +from spd.utils.module_utils import ModulePathInfo, expand_module_patterns @dataclass @@ -128,6 +129,50 @@ def __init__( self.lower_leaky_fn = SIGMOID_TYPES[sigmoid_type] self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] + @classmethod + def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": + """Load a trained ComponentModel from a run info object.""" + config = run_info.config + + model_class = resolve_class(config.pretrained_model_class) + if config.pretrained_model_name is not None: + assert hasattr(model_class, "from_pretrained") + target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] + else: + assert issubclass(model_class, LoadableModule) + assert config.pretrained_model_path is not None + target_model = model_class.from_pretrained(config.pretrained_model_path) + + target_model.eval() + target_model.requires_grad_(False) + + if config.identity_module_info is not None: + insert_identity_operations_( + target_model, + identity_module_info=config.identity_module_info, + ) + + module_path_info = expand_module_patterns(target_model, config.all_module_info) + + comp_model: ComponentModel[Any, Any] = cls( + target_model=target_model, + module_path_info=module_path_info, + ci_fn_hidden_dims=config.ci_fn_hidden_dims, + ci_fn_type=config.ci_fn_type, + sigmoid_type=config.sigmoid_type, + ) + + weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) + handle_deprecated_state_dict_keys_(weights) + comp_model.load_state_dict(weights) + return comp_model + + @classmethod + def from_pretrained(cls, path: ModelPath) -> "ComponentModel[Any, Any]": + """Load a trained ComponentModel from a wandb or local path.""" + run_info = SPDRunInfo.from_path(path) + return cls.from_run_info(run_info) + def target_weight(self, module_name: str) -> Float[Tensor, "rows cols"]: target_module = self.target_model.get_submodule(module_name) @@ -302,7 +347,7 @@ def forward( """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. - return self.target_model(batch) + return self._get_first_element_if_tuple(self.target_model(batch)) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -323,7 +368,7 @@ def forward( ) with self._attach_forward_hooks(hooks): - out: OutputT = self.target_model(batch) + out: OutputT = self._get_first_element_if_tuple(self.target_model(batch)) match cache_type: case "input" | "component_acts": @@ -331,6 +376,21 @@ def forward( case "none": return out + @staticmethod + def _get_first_element_if_tuple(out: Any) -> Any: + """Extract primary output from various model output formats. + + Handles: + - Tuple outputs: returns first element (e.g. (logits, past_key_values, ...)) + - HuggingFace ModelOutput: extracts .logits attribute + - Raw tensors: returned as-is + """ + if isinstance(out, tuple): + return out[0] + if hasattr(out, "logits"): + return out.logits + return out + def _components_and_cache_hook( self, _module: nn.Module, diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index f1b25e608..1b4cb177e 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -22,7 +22,6 @@ from spd.base_config import BaseConfig from spd.configs import Config -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device @@ -84,7 +83,7 @@ def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any, A """Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) # TODO(oli): this should actually be generic (one of the only instances of this I think) - model = load_lm_component_model_from_run_info(run_info) + model = ComponentModel.from_run_info(run_info) model.to(self.device) model.eval() model.requires_grad_(False) diff --git a/spd/simple_trainer.py b/spd/simple_trainer.py deleted file mode 100644 index bcf92945c..000000000 --- a/spd/simple_trainer.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Run SPD on a model.""" - -import gc -from collections import defaultdict -from pathlib import Path -from typing import cast - -import torch -import torch.nn as nn -import torch.nn.parallel -import wandb -from PIL import Image -from torch import optim -from torch.nn.utils import clip_grad_norm_ -from torch.utils.data import DataLoader -from tqdm import tqdm - -from spd.configs import Config -from spd.eval import evaluate -from spd.log import logger -from spd.losses import compute_total_loss -from spd.models.batch_and_loss_fns import ReconstructionLoss -from spd.models.component_model import ComponentModel, OutputWithCache, TargetModel -from spd.run_spd import get_unique_metric_configs, run_faithfulness_warmup -from spd.utils.component_utils import calc_ci_l_zero -from spd.utils.distributed_utils import ( - avg_metrics_across_ranks, - get_distributed_state, - is_main_process, - sync_across_processes, -) -from spd.utils.general_utils import dict_safe_update_, get_scheduled_value, runtime_cast -from spd.utils.logging_utils import get_grad_norms_dict, local_log -from spd.utils.module_utils import expand_module_patterns -from spd.utils.run_utils import save_file -from spd.utils.wandb_utils import try_wandb - - -def optimize[BatchT, OutputT]( - target_model: TargetModel[BatchT, OutputT], - config: Config, - device: str, - train_loader: DataLoader[BatchT], - eval_loader: DataLoader[BatchT], - n_eval_steps: int, - reconstruction_loss: ReconstructionLoss[OutputT], - out_dir: Path | None, -) -> None: - """Run the optimization loop for LM decomposition.""" - train_iterator = iter(train_loader) - eval_iterator = iter(eval_loader) - - runtime_cast(nn.Module, target_model).requires_grad_(False) - model = ComponentModel( - target_model=target_model, - module_path_info=expand_module_patterns( - runtime_cast(nn.Module, target_model), config.all_module_info - ), - ci_fn_type=config.ci_fn_type, - ci_fn_hidden_dims=config.ci_fn_hidden_dims, - sigmoid_type=config.sigmoid_type, - ) - model.to(device) - - # Wrap model with DDP if distributed - dist_state = get_distributed_state() - wrapped_model: nn.Module = model - - component_model: ComponentModel[BatchT, OutputT] - if dist_state is not None: - if dist_state.backend == "nccl": - device_id = dist_state.local_rank - wrapped_model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[device_id], - output_device=device_id, - ) - else: - # For CPU, don't pass device_ids or output_device - wrapped_model = torch.nn.parallel.DistributedDataParallel(model) - # Access the underlying module for component operations - component_model = wrapped_model.module # type: ignore[attr-defined] - else: - component_model = model - assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" - - component_params: list[torch.nn.Parameter] = [] - ci_fn_params: list[torch.nn.Parameter] = [] - for name in component_model.target_module_paths: - component_params.extend(component_model.components[name].parameters()) - ci_fn_params.extend(component_model.ci_fns[name].parameters()) - - assert len(component_params) > 0, "No parameters found in components to optimize" - - optimizer = optim.AdamW( - component_params + ci_fn_params, - lr=config.lr_schedule.start_val, - weight_decay=0, - ) - - logger.info(f"LR scheduler: {config.lr_schedule.fn_type}") - - if config.faithfulness_warmup_steps > 0: - run_faithfulness_warmup(component_model, component_params, config) - - eval_metric_configs = get_unique_metric_configs( - loss_configs=config.loss_metric_configs, eval_configs=config.eval_metric_configs - ) - - for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): - optimizer.zero_grad() - - step_lr = get_scheduled_value( - step=step, total_steps=config.steps, config=config.lr_schedule - ) - for group in optimizer.param_groups: - group["lr"] = step_lr - - weight_deltas = component_model.calc_weight_deltas() - - microbatch_log_data: defaultdict[str, float] = defaultdict(float) - - for _ in range(config.gradient_accumulation_steps): - microbatch = next(train_iterator) - - # NOTE: we need to call the wrapped_model at least once each step in order to setup - # the DDP gradient syncing for all parameters in the component model. Gradients will - # sync regardless of whether the parameters are used in this call to wrapped_model. - target_model_output: OutputWithCache[OutputT] = wrapped_model( - microbatch, cache_type="input" - ) - - ci = component_model.calc_causal_importances( - pre_weight_acts=target_model_output.cache, - detach_inputs=False, - sampling=config.sampling, - ) - - microbatch_total_loss, microbatch_loss_terms = compute_total_loss( - loss_metric_configs=config.loss_metric_configs, - model=component_model, - batch=microbatch, - ci=ci, - target_out=target_model_output.output, - weight_deltas=weight_deltas, - pre_weight_acts=target_model_output.cache, - current_frac_of_training=step / config.steps, - sampling=config.sampling, - use_delta_component=config.use_delta_component, - n_mask_samples=config.n_mask_samples, - reconstruction_loss=reconstruction_loss, - ) - microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() - - for loss_name, loss_value in microbatch_loss_terms.items(): - microbatch_log_data[f"train/{loss_name}"] += ( - loss_value / config.gradient_accumulation_steps - ) - - for layer_name, layer_ci in ci.lower_leaky.items(): - l0_val = calc_ci_l_zero(layer_ci, config.ci_alive_threshold) - microbatch_log_data[f"train/l0/{layer_name}"] += ( - l0_val / config.gradient_accumulation_steps - ) - - # --- Train Logging --- # - if step % config.train_log_freq == 0: - avg_metrics = avg_metrics_across_ranks(microbatch_log_data, device=device) - microbatch_log_data = cast(defaultdict[str, float], avg_metrics) - - grad_norms = get_grad_norms_dict(component_model, device) - dict_safe_update_( - microbatch_log_data, {f"train/grad_norms/{k}": v for k, v in grad_norms.items()} - ) - - microbatch_log_data["train/schedules/lr"] = step_lr - - if is_main_process(): - assert out_dir is not None - tqdm.write(f"--- Step {step} ---") - tqdm.write(f"LR: {step_lr:.6f}") - for name, value in microbatch_log_data.items(): - tqdm.write(f"{name}: {value:.15f}") - local_log(microbatch_log_data, step, out_dir) - if config.wandb_project: - try_wandb(wandb.log, microbatch_log_data, step=step) - - # --- Evaluation --- # - if step % config.eval_freq == 0: - with torch.no_grad(): - slow_step: bool = ( - config.slow_eval_on_first_step - if step == 0 - else step % config.slow_eval_freq == 0 - ) - - metrics = evaluate( - eval_metric_configs=eval_metric_configs, - model=component_model, # No backward passes so DDP wrapped_model not needed - eval_iterator=eval_iterator, - device=device, - run_config=config, - slow_step=slow_step, - n_eval_steps=n_eval_steps, - current_frac_of_training=step / config.steps, - reconstruction_loss=reconstruction_loss, - ) - - if is_main_process(): - assert out_dir is not None - for k, v in metrics.items(): - tqdm.write(f"eval/{k}: {v}") - local_log(metrics, step, out_dir) - if config.wandb_project: - wandb_logs = { - f"eval/{k}": wandb.Image(v) if isinstance(v, Image.Image) else v - for k, v in metrics.items() - } - try_wandb(wandb.log, wandb_logs, step=step) - - del metrics - - gc.collect() - torch.cuda.empty_cache() - - # --- Saving Checkpoint --- # - if ( - (config.save_freq is not None and step % config.save_freq == 0 and step > 0) - or step == config.steps - ) and is_main_process(): - assert out_dir is not None - # Save the state dict of the underlying module (not DDP wrapper) - save_file(component_model.state_dict(), out_dir / f"model_{step}.pth") - logger.info(f"Saved model, optimizer, and out_dir to {out_dir}") - if config.wandb_project: - try_wandb( - wandb.save, - str(out_dir / f"model_{step}.pth"), - base_path=str(out_dir), - policy="now", - ) - - sync_across_processes() - if config.grad_clip_norm_components is not None: - clip_grad_norm_(component_params, config.grad_clip_norm_components) - if config.grad_clip_norm_ci_fns is not None: - clip_grad_norm_(ci_fn_params, config.grad_clip_norm_ci_fns) - optimizer.step() - - if is_main_process(): - logger.info("Finished training loop.") diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 992edd6ab..958620dd8 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -22,7 +22,6 @@ from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, RunState, StateManager from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig -from spd.experiments.lm.loaders import LogitsOnlyWrapper from spd.models.component_model import ComponentModel from spd.utils.module_utils import expand_module_patterns @@ -113,10 +112,8 @@ def app_with_state(): ), ) module_path_info = expand_module_patterns(target_model, config.module_info) - # Wrap the model to extract logits from (logits, loss) tuple outputs - wrapped_model = LogitsOnlyWrapper(target_model) model = ComponentModel( - target_model=wrapped_model, + target_model=target_model, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 0289173ab..e11cf591c 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -226,11 +226,16 @@ def _compare_saved_models( self, dp1_out_dir: Path, dp2_out_dir: Path, - atol: float = 1e-6, - rtol: float = 1e-5, + atol: float = 2e-4, + rtol: float = 1e-3, ) -> None: """Compare saved model parameters between dp=1 and dp=2 runs. + Tolerances are relatively loose because CI-masked reconstruction losses use hard + masking: tiny allreduce rounding differences can push a CI value across the mask + threshold, causing a different gradient path that compounds over training steps. + Empirically, across many seeds, max parameter diffs stay below ~1.5e-4. + Args: dp1_out_dir: Output directory for dp=1 run dp2_out_dir: Output directory for dp=2 run diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index 308249248..2da46b86f 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -14,7 +14,7 @@ StochasticReconLayerwiseLossConfig, StochasticReconLossConfig, ) -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.identity_insertion import insert_identity_operations_ from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize @@ -125,6 +125,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -142,6 +143,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, + collate_fn=lm_collate_fn, ) # Run optimize function diff --git a/tests/test_wandb_run_loading.py b/tests/test_wandb_run_loading.py index be150d17e..3c8fd2efd 100644 --- a/tests/test_wandb_run_loading.py +++ b/tests/test_wandb_run_loading.py @@ -5,35 +5,14 @@ the canonical configs, and update the registry with your new run(s). """ -from typing import Any - import pytest -from spd.experiments.lm.loaders import load_lm_component_model_from_run_info -from spd.experiments.resid_mlp.models import load_resid_mlp_component_model_from_run_info -from spd.experiments.tms.models import load_tms_component_model_from_run_info from spd.models.component_model import ComponentModel, SPDRunInfo from spd.registry import EXPERIMENT_REGISTRY from spd.utils.wandb_utils import parse_wandb_run_path - -def load_component_model(canonical_run: str, task_name: str) -> ComponentModel[Any, Any]: - """Load a ComponentModel using the appropriate experiment-specific loader.""" - run_info = SPDRunInfo.from_path(canonical_run) - - loaders: dict[str, Any] = { - "tms": load_tms_component_model_from_run_info, - "resid_mlp": load_resid_mlp_component_model_from_run_info, - "lm": load_lm_component_model_from_run_info, - } - - loader = loaders.get(task_name) - assert loader is not None, f"No loader found for task_name: {task_name}" - return loader(run_info) - - CANONICAL_EXPS = [ - (exp_name, exp_config.canonical_run, exp_config.task_name) + (exp_name, exp_config.canonical_run) for exp_name, exp_config in EXPERIMENT_REGISTRY.items() if exp_config.canonical_run is not None ] @@ -41,12 +20,13 @@ def load_component_model(canonical_run: str, task_name: str) -> ComponentModel[A @pytest.mark.requires_wandb @pytest.mark.slow -@pytest.mark.parametrize("exp_name, canonical_run, task_name", CANONICAL_EXPS) -def test_loading_from_wandb(exp_name: str, canonical_run: str, task_name: str) -> None: +@pytest.mark.parametrize("exp_name, canonical_run", CANONICAL_EXPS) +def test_loading_from_wandb(exp_name: str, canonical_run: str) -> None: try: - load_component_model(canonical_run, task_name) + run_info = SPDRunInfo.from_path(canonical_run) + ComponentModel.from_run_info(run_info) except Exception as e: - e.add_note(f"Error loading {exp_name} from {canonical_run} (task: {task_name})") + e.add_note(f"Error loading {exp_name} from {canonical_run}") raise e From d2c5465e68c9191acf02e3cf50e88ef5410d315a Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 9 Feb 2026 21:35:48 +0000 Subject: [PATCH 08/16] Add extract_tensor_output config arg --- spd/configs.py | 14 +++++- spd/experiments/lm/gpt2_config.yaml | 2 +- .../lm/pile_llama_simple_mlp-2L.yaml | 2 +- .../lm/pile_llama_simple_mlp-4L.yaml | 2 +- spd/experiments/lm/ss_gpt2_config.yaml | 2 +- spd/experiments/lm/ss_gpt2_simple-1L.yaml | 2 +- spd/experiments/lm/ss_gpt2_simple-2L.yaml | 2 +- spd/experiments/lm/ss_gpt2_simple_config.yaml | 2 +- .../lm/ss_gpt2_simple_noln_config.yaml | 2 +- spd/experiments/lm/ss_llama_simple-1L.yaml | 2 +- spd/experiments/lm/ss_llama_simple-2L.yaml | 2 +- .../lm/ss_llama_simple_config.yaml | 2 +- .../lm/ss_llama_simple_mlp-1L.yaml | 2 +- .../lm/ss_llama_simple_mlp-2L-wide.yaml | 2 +- .../lm/ss_llama_simple_mlp-2L.yaml | 2 +- spd/experiments/lm/ss_llama_simple_mlp.yaml | 2 +- spd/experiments/lm/ts_config.yaml | 2 +- spd/models/component_model.py | 49 +++++++++++++------ spd/run_spd.py | 1 + tests/app/test_server_api.py | 3 +- tests/test_component_model.py | 46 +++++++++++++++++ tests/test_distributed.py | 2 +- tests/test_gpt2.py | 2 +- tests/test_ih_transformer.py | 2 +- tests/test_resid_mlp.py | 2 +- tests/test_tms.py | 2 +- 26 files changed, 115 insertions(+), 40 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index f71eec08e..3deccee0b 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -566,9 +566,9 @@ def microbatch_size(self) -> PositiveInt: default=None, description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'", ) - pretrained_model_output_attr: str | None = Field( + extract_tensor_output: str | None = Field( default=None, - description="Name of the attribute on the forward output that contains logits or activations", + description="Accessor path for extracting tensor from model output, e.g. '[0]' or '.logits'", ) tokenizer_name: str | None = Field( default=None, @@ -652,6 +652,16 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, "simple_stories_train.models.", "spd.pretrain.models.", 1 ) + # Migrate old pretrained_model_output_attr to extract_tensor_output + if "pretrained_model_output_attr" in config_dict: + old_val = config_dict.pop("pretrained_model_output_attr") + if old_val is not None: + accessor = "[0]" if old_val == "idx_0" else f".{old_val}" + config_dict["extract_tensor_output"] = accessor + logger.info( + f"Migrated pretrained_model_output_attr={old_val!r} to extract_tensor_output={accessor!r}" + ) + if "eval_batch_size" not in config_dict: config_dict["eval_batch_size"] = config_dict["batch_size"] if "train_log_freq" not in config_dict: diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index 3aa180fe5..3a2061df9 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -64,7 +64,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: openai-community/gpt2 -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: openai-community/gpt2 # --- Task Specific --- diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index fc179119f..e78b13819 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -115,7 +115,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-bd02d372 -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index c3c3a106a..becda3651 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -119,7 +119,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index 9fcbeec47..ef1b2a6a2 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -64,7 +64,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: SimpleStories/test-SimpleStories-gpt2-1.25M -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index 790002de7..e80a03a8a 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -92,7 +92,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/3qhd7rnb # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index 4080b1634..ffb197d19 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -94,7 +94,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/wr1su18m # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index ed6c497ee..f9a24a7dc 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -97,7 +97,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/rvu66183 # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index 72dba01f0..5f9ede6a3 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -94,7 +94,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/xi36b9az # No ln -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # We'll load this from wandb in future # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index f9324bb2d..0442c6a0f 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -92,7 +92,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tfacbi70 # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 512d0d3a1..3f0cc95e3 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -94,7 +94,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tb8373uo # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 92d9eced4..1fbb6a661 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -94,7 +94,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_name: wandb:goodfire/spd/runs/erq48r3w # 100k steps 4019 tok -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index 9cb54c2de..c8a0eb7a8 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -86,7 +86,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gvbmdt9w -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 712a288da..8a3327212 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -94,7 +94,7 @@ ci_alive_threshold: 0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: buffer_size: 1000 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index 72cf4937e..7e2a19097 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -92,7 +92,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/7pt957pf # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index 79622689d..c8a33b540 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -117,7 +117,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/9de1zu65 # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index 4f8c966ac..2c23bdce0 100644 --- a/spd/experiments/lm/ts_config.yaml +++ b/spd/experiments/lm/ts_config.yaml @@ -65,7 +65,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.AutoModelForCausalLM pretrained_model_name: roneneldan/TinyStories-1M -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: EleutherAI/gpt-neo-125M # --- Task Specific --- diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 2c96fd4b4..fa6c7434a 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,3 +1,4 @@ +import re from collections.abc import Callable, Generator, Iterator from contextlib import contextmanager from dataclasses import dataclass @@ -28,6 +29,29 @@ from spd.utils.general_utils import resolve_class, runtime_cast from spd.utils.module_utils import ModulePathInfo, expand_module_patterns +_ACCESSOR_TOKEN_RE = re.compile(r'\.\w+|\[\d+\]|\["\w+"\]') + + +def extract_with_accessor(obj: Any, accessor: str) -> Any: + """Navigate a nested object using an accessor path string. + + Supports attribute access (.attr), integer indexing ([i]), and string-key + dictionary access (["key"]). + Examples: "[0]", ".logits", "[0].logits[2]", '["hidden_states"]' + """ + assert accessor, "Accessor must be non-empty (use None for no extraction)" + tokens = _ACCESSOR_TOKEN_RE.findall(accessor) + assert "".join(tokens) == accessor, f"Invalid accessor: {accessor!r}" + result = obj + for token in tokens: + if token.startswith('["'): + result = result[token[2:-2]] + elif token.startswith("["): + result = result[int(token[1:-1])] + else: + result = getattr(result, token[1:]) + return result + @dataclass class SPDRunInfo(RunInfo[Config]): @@ -90,8 +114,10 @@ def __init__( ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, + extract_tensor_output: str | None = None, ): super().__init__() + self.extract_tensor_output = extract_tensor_output for name, param in target_model.named_parameters(): assert not param.requires_grad, ( @@ -160,6 +186,7 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": ci_fn_hidden_dims=config.ci_fn_hidden_dims, ci_fn_type=config.ci_fn_type, sigmoid_type=config.sigmoid_type, + extract_tensor_output=config.extract_tensor_output, ) weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) @@ -347,7 +374,7 @@ def forward( """ if mask_infos is None and cache_type == "none": # No hooks needed. Do a regular forward pass of the target model. - return self._get_first_element_if_tuple(self.target_model(batch)) + return self._extract_output(self.target_model(batch)) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -368,7 +395,7 @@ def forward( ) with self._attach_forward_hooks(hooks): - out: OutputT = self._get_first_element_if_tuple(self.target_model(batch)) + out: OutputT = self._extract_output(self.target_model(batch)) match cache_type: case "input" | "component_acts": @@ -376,20 +403,10 @@ def forward( case "none": return out - @staticmethod - def _get_first_element_if_tuple(out: Any) -> Any: - """Extract primary output from various model output formats. - - Handles: - - Tuple outputs: returns first element (e.g. (logits, past_key_values, ...)) - - HuggingFace ModelOutput: extracts .logits attribute - - Raw tensors: returned as-is - """ - if isinstance(out, tuple): - return out[0] - if hasattr(out, "logits"): - return out.logits - return out + def _extract_output(self, raw_output: Any) -> Any: + if self.extract_tensor_output is None: + return raw_output + return extract_with_accessor(raw_output, self.extract_tensor_output) def _components_and_cache_hook( self, diff --git a/spd/run_spd.py b/spd/run_spd.py index 2e258946b..11d345027 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -151,6 +151,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, + extract_tensor_output=config.extract_tensor_output, ) model.to(device) diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 1d62c0448..8224d2fd5 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -91,7 +91,7 @@ def app_with_state(): ModulePatternInfoConfig(module_pattern=p, C=C) for p in target_module_patterns ], pretrained_model_class="spd.pretrain.models.gpt2_simple.GPT2Simple", - pretrained_model_output_attr="idx_0", + extract_tensor_output="[0]", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", output_loss_type="kl", lr_schedule=ScheduleConfig(start_val=1e-3), @@ -118,6 +118,7 @@ def app_with_state(): ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, + extract_tensor_output=config.extract_tensor_output, ) model.eval() sources_by_target = get_sources_by_target( diff --git a/tests/test_component_model.py b/tests/test_component_model.py index ab706ea44..37e8d3655 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -20,6 +20,7 @@ from spd.models.component_model import ( ComponentModel, SPDRunInfo, + extract_with_accessor, ) from spd.models.components import ( ComponentsMaskInfo, @@ -538,3 +539,48 @@ def forward(self, x: Tensor) -> Tensor: # but it should be the same for the second example (where it's not routed to components) assert torch.allclose(cm_routed_out[1], target_out[1]) + + +class TestExtractWithAccessor: + def test_integer_index(self): + obj = ("a", "b", "c") + assert extract_with_accessor(obj, "[0]") == "a" + assert extract_with_accessor(obj, "[2]") == "c" + + def test_attribute_access(self): + class Obj: + logits = 42 + + assert extract_with_accessor(Obj(), ".logits") == 42 + + def test_string_key_dict_access(self): + obj = {"hidden_states": "hs", "logits": "lg"} + assert extract_with_accessor(obj, '["hidden_states"]') == "hs" + assert extract_with_accessor(obj, '["logits"]') == "lg" + + def test_chained_accessors(self): + class Inner: + value = 99 + + obj = ({"data": Inner()},) + assert extract_with_accessor(obj, '[0]["data"].value') == 99 + + def test_index_then_attribute(self): + class Out: + logits = torch.tensor([1.0, 2.0]) + + obj = (Out(),) + result = extract_with_accessor(obj, "[0].logits") + assert torch.equal(result, torch.tensor([1.0, 2.0])) + + def test_invalid_accessor_raises(self): + with pytest.raises(AssertionError, match="Invalid accessor"): + extract_with_accessor({}, "invalid") + + def test_invalid_accessor_with_special_chars(self): + with pytest.raises(AssertionError, match="Invalid accessor"): + extract_with_accessor({}, '[" spaces "]') + + def test_empty_accessor_raises(self): + with pytest.raises(AssertionError, match="non-empty"): + extract_with_accessor({}, "") diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 34e1535b3..04ad8b0f2 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -54,7 +54,7 @@ # --- Pretrained model info --- "pretrained_model_class": "transformers.LlamaForCausalLM", "pretrained_model_name": "SimpleStories/SimpleStories-1.25M", - "pretrained_model_output_attr": "logits", + "extract_tensor_output": ".logits", "tokenizer_name": "SimpleStories/SimpleStories-1.25M", # --- Task Specific --- "task_config": { diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index 7d666a28e..fa965b0b1 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -79,7 +79,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="transformers.GPT2LMHeadModel", pretrained_model_path=None, pretrained_model_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - pretrained_model_output_attr="logits", + extract_tensor_output=".logits", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", # Task Specific task_config=LMTaskConfig( diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 61ca8773b..bff6a25e6 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -96,7 +96,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.ih.model.InductionTransformer", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=IHTaskConfig( diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 77fb294df..be02eae6d 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -83,7 +83,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.resid_mlp.models.ResidMLP", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=ResidMLPTaskConfig( diff --git a/tests/test_tms.py b/tests/test_tms.py index 2cdfea373..9117992bf 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -92,7 +92,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.tms.models.TMSModel", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=TMSTaskConfig( From 52e9d19160d2bc4760a60b09ac5e04e46740b983 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Tue, 10 Feb 2026 14:07:04 +0000 Subject: [PATCH 09/16] Replace accessor DSL with RunBatch protocol (#375) --- spd/app/backend/compute.py | 2 +- spd/autointerp/interpret.py | 4 +- spd/configs.py | 40 ++++++++---- spd/experiments/ih/ih_config.yaml | 1 - spd/experiments/ih/ih_decomposition.py | 3 +- spd/experiments/lm/gpt2_config.yaml | 5 +- spd/experiments/lm/lm_decomposition.py | 4 +- .../lm/pile_llama_simple_mlp-2L.yaml | 5 +- .../lm/pile_llama_simple_mlp-4L.yaml | 5 +- spd/experiments/lm/ss_gpt2_config.yaml | 5 +- spd/experiments/lm/ss_gpt2_simple-1L.yaml | 5 +- spd/experiments/lm/ss_gpt2_simple-2L.yaml | 5 +- spd/experiments/lm/ss_gpt2_simple_config.yaml | 5 +- .../lm/ss_gpt2_simple_noln_config.yaml | 5 +- spd/experiments/lm/ss_llama_simple-1L.yaml | 5 +- spd/experiments/lm/ss_llama_simple-2L.yaml | 5 +- .../lm/ss_llama_simple_config.yaml | 5 +- .../lm/ss_llama_simple_mlp-1L.yaml | 5 +- .../lm/ss_llama_simple_mlp-2L-wide.yaml | 5 +- .../lm/ss_llama_simple_mlp-2L.yaml | 5 +- spd/experiments/lm/ss_llama_simple_mlp.yaml | 5 +- spd/experiments/lm/ts_config.yaml | 5 +- .../resid_mlp/resid_mlp1_config.yaml | 1 - .../resid_mlp/resid_mlp2_config.yaml | 1 - .../resid_mlp/resid_mlp3_config.yaml | 1 - .../resid_mlp/resid_mlp_decomposition.py | 3 +- spd/experiments/tms/tms_40-10-id_config.yaml | 1 - spd/experiments/tms/tms_40-10_config.yaml | 1 - spd/experiments/tms/tms_5-2-id_config.yaml | 1 - spd/experiments/tms/tms_5-2_config.yaml | 1 - spd/experiments/tms/tms_decomposition.py | 3 +- spd/models/batch_and_loss_fns.py | 32 ++++++++- spd/models/component_model.py | 65 ++++--------------- spd/run_spd.py | 18 +++-- tests/app/test_server_api.py | 14 ++-- tests/metrics/fixtures.py | 3 + tests/scripts_run/test_grid_search.py | 3 - tests/test_component_model.py | 57 +++------------- tests/test_distributed.py | 3 +- tests/test_gpt2.py | 8 ++- tests/test_ih_transformer.py | 5 +- tests/test_resid_mlp.py | 5 +- tests/test_spd_losses.py | 3 +- tests/test_tms.py | 5 +- 44 files changed, 173 insertions(+), 195 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 39a169aa2..8b1fe0606 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -193,7 +193,7 @@ def wte_hook( "mlp.c_fc", "mlp.down_proj", ] - n_blocks = get_model_n_blocks(cast(nn.Module, model.target_model)) + n_blocks = get_model_n_blocks(model.target_model) for i in range(n_blocks): layers.extend([f"h.{i}.{layer_name}" for layer_name in component_layer_names]) diff --git a/spd/autointerp/interpret.py b/spd/autointerp/interpret.py index cbf892ad7..5e38e4985 100644 --- a/spd/autointerp/interpret.py +++ b/spd/autointerp/interpret.py @@ -5,10 +5,8 @@ from dataclasses import asdict, dataclass from enum import StrEnum from pathlib import Path -from typing import cast import httpx -import torch.nn as nn from openrouter import OpenRouter from openrouter.components import JSONSchemaConfig, MessageTypedDict, ResponseFormatJSONSchema from openrouter.errors import ( @@ -337,7 +335,7 @@ async def process_one( def get_architecture_info(wandb_path: str) -> ArchitectureInfo: run_info = SPDRunInfo.from_path(wandb_path) model = ComponentModel.from_run_info(run_info) - n_blocks = get_model_n_blocks(cast(nn.Module, model.target_model)) + n_blocks = get_model_n_blocks(model.target_model) config = run_info.config task_config = config.task_config assert isinstance(task_config, LMTaskConfig) diff --git a/spd/configs.py b/spd/configs.py index 3deccee0b..8d362caf3 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -148,6 +148,19 @@ class LMTaskConfig(BaseConfig): ) +class IndexOutputExtract(BaseConfig): + type: Literal["index"] = "index" + index: int + + +class AttrOutputExtract(BaseConfig): + type: Literal["attr"] = "attr" + attr: str + + +OutputExtractConfig = IndexOutputExtract | AttrOutputExtract + + class ModulePatternInfoConfig(BaseConfig): """Configuration for a module pattern with its number of components. @@ -461,11 +474,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]: ), ) ) - output_loss_type: Literal["mse", "kl"] = Field( - ..., - description="Metric used to measure recon error between model outputs and targets", - ) - # --- Training --- lr_schedule: ScheduleConfig = Field(..., description="Learning rate schedule configuration") steps: NonNegativeInt = Field(..., description="Total number of optimisation steps") @@ -566,9 +574,9 @@ def microbatch_size(self) -> PositiveInt: default=None, description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'", ) - extract_tensor_output: str | None = Field( + output_extract: Annotated[OutputExtractConfig, Field(discriminator="type")] | None = Field( default=None, - description="Accessor path for extracting tensor from model output, e.g. '[0]' or '.logits'", + description="How to extract tensor from model output. None = raw output. Note that you can ignore this field if you plan to create your own `run_batch` function to pass to run_spd.optimize().", ) tokenizer_name: str | None = Field( default=None, @@ -607,6 +615,7 @@ def microbatch_size(self) -> PositiveInt: "lr_exponential_halflife", "out_dir", "n_examples_until_dead", + "output_loss_type", ] RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = { "grad_clip_norm": "grad_clip_norm_components", @@ -652,15 +661,18 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, "simple_stories_train.models.", "spd.pretrain.models.", 1 ) - # Migrate old pretrained_model_output_attr to extract_tensor_output + # Migrate old pretrained_model_output_attr to output_extract if "pretrained_model_output_attr" in config_dict: old_val = config_dict.pop("pretrained_model_output_attr") - if old_val is not None: - accessor = "[0]" if old_val == "idx_0" else f".{old_val}" - config_dict["extract_tensor_output"] = accessor - logger.info( - f"Migrated pretrained_model_output_attr={old_val!r} to extract_tensor_output={accessor!r}" - ) + match old_val: + case None: + pass + case "idx_0": + config_dict["output_extract"] = {"type": "index", "index": 0} + case "logits": + config_dict["output_extract"] = {"type": "attr", "attr": "logits"} + case _: + raise ValueError(f"Unknown pretrained_model_output_attr: {old_val!r}") if "eval_batch_size" not in config_dict: config_dict["eval_batch_size"] = config_dict["batch_size"] diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml index 9c844723a..6391b5dd9 100644 --- a/spd/experiments/ih/ih_config.yaml +++ b/spd/experiments/ih/ih_config.yaml @@ -33,7 +33,6 @@ ci_recon_layerwise_coeff: null stochastic_recon_layerwise_coeff: 1 importance_minimality_coeff: 1e-2 pnorm: 0.1 -output_loss_type: kl ci_fn_type: "vector_mlp" ci_fn_hidden_dims: [128] diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 0fef3ff71..121747466 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,7 +7,7 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device @@ -97,6 +97,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index 3a2061df9..a58bd4d1b 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -25,7 +25,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 2 @@ -64,7 +63,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: openai-community/gpt2 -extract_tensor_output: ".logits" +output_extract: + type: attr + attr: logits tokenizer_name: openai-community/gpt2 # --- Task Specific --- diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 889b9c98f..e189ef1e2 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -9,7 +9,7 @@ from spd.configs import Config, LMTaskConfig from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_kl +from spd.models.batch_and_loss_fns import make_run_batch, recon_loss_kl from spd.pretrain.run_info import PretrainRunInfo from spd.run_spd import optimize from spd.utils.distributed_utils import ( @@ -177,12 +177,14 @@ def main( if is_main_process(): logger.info("Starting optimization...") + assert config.output_extract is not None, "LM models require output_extract" optimize( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=make_run_batch(config.output_extract), reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index e78b13819..81529ec37 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -67,7 +67,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -115,7 +114,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-bd02d372 -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index becda3651..787f10998 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -67,7 +67,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -119,7 +118,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index ef1b2a6a2..0ac2214e1 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -25,7 +25,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 16 @@ -64,7 +63,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: SimpleStories/test-SimpleStories-gpt2-1.25M -extract_tensor_output: ".logits" +output_extract: + type: attr + attr: logits tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index e80a03a8a..aab53d5c7 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -92,7 +91,9 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/3qhd7rnb # 100k steps. 4019 tokenizer -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index ffb197d19..1b2dd89fa 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -94,7 +93,9 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/wr1su18m # 100k steps. 4019 tokenizer -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index f9a24a7dc..9032cc46b 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -38,7 +38,6 @@ loss_metric_configs: routing: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl # --- Training --- batch_size: 256 @@ -97,7 +96,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/rvu66183 # 100k steps. 4019 tokenizer -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index 5f9ede6a3..212e9c961 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -36,7 +36,6 @@ loss_metric_configs: coeff: 2.0 - classname: "StochasticReconLoss" coeff: 0.2 -output_loss_type: kl # --- Training --- batch_size: 48 @@ -94,7 +93,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/xi36b9az # No ln -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # We'll load this from wandb in future # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index 0442c6a0f..c7e3f1711 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -48,7 +48,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -92,7 +91,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tfacbi70 # 100k steps -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 3f0cc95e3..c1807a7e1 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -48,7 +48,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -94,7 +93,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tb8373uo # 100k steps -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 1fbb6a661..71ea7a034 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -42,7 +42,6 @@ loss_metric_configs: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl # --- Training --- batch_size: 256 @@ -94,7 +93,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_name: wandb:goodfire/spd/runs/erq48r3w # 100k steps 4019 tok -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index c8a0eb7a8..e7cabb6a4 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -42,7 +42,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 64 gradient_accumulation_steps: 1 @@ -86,7 +85,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gvbmdt9w -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 8a3327212..24406f216 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -46,7 +46,6 @@ loss_metric_configs: type: uniform_k_subset - classname: FaithfulnessLoss coeff: 1000000 -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -94,7 +93,9 @@ ci_alive_threshold: 0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: buffer_size: 1000 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index 7e2a19097..e0933f4d1 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -46,7 +46,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -92,7 +91,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/7pt957pf # 100k steps -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index c8a33b540..1cb007f87 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -42,7 +42,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 100000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -117,7 +116,9 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/9de1zu65 # 100k steps -extract_tensor_output: "[0]" +output_extract: + type: index + index: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index 2c23bdce0..1e252c946 100644 --- a/spd/experiments/lm/ts_config.yaml +++ b/spd/experiments/lm/ts_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 4 @@ -65,7 +64,9 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.AutoModelForCausalLM pretrained_model_name: roneneldan/TinyStories-1M -extract_tensor_output: ".logits" +output_extract: + type: attr + attr: logits tokenizer_name: EleutherAI/gpt-neo-125M # --- Task Specific --- diff --git a/spd/experiments/resid_mlp/resid_mlp1_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_config.yaml index 1d8bc7fed..093c8e554 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp2_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_config.yaml index bae662b6f..b9d7e48fb 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_config.yaml @@ -34,7 +34,6 @@ loss_metric_configs: mask_scope: shared_across_batch - classname: "FaithfulnessLoss" coeff: 0.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp3_config.yaml b/spd/experiments/resid_mlp/resid_mlp3_config.yaml index dac4f9c10..c3dbda014 100644 --- a/spd/experiments/resid_mlp/resid_mlp3_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp3_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index f49615c20..09214c396 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,7 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -109,6 +109,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/tms/tms_40-10-id_config.yaml b/spd/experiments/tms/tms_40-10-id_config.yaml index e3e40d5fc..f3b21c094 100644 --- a/spd/experiments/tms/tms_40-10-id_config.yaml +++ b/spd/experiments/tms/tms_40-10-id_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_40-10_config.yaml b/spd/experiments/tms/tms_40-10_config.yaml index a4aeb6a97..7d264c77d 100644 --- a/spd/experiments/tms/tms_40-10_config.yaml +++ b/spd/experiments/tms/tms_40-10_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2-id_config.yaml b/spd/experiments/tms/tms_5-2-id_config.yaml index c9b2234e8..11670b63d 100644 --- a/spd/experiments/tms/tms_5-2-id_config.yaml +++ b/spd/experiments/tms/tms_5-2-id_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 34c92fa08..cc5d6a668 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -26,7 +26,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 2471c1830..ee4258629 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,7 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -105,6 +105,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py index 90914e263..fb0ce4b8b 100644 --- a/spd/models/batch_and_loss_fns.py +++ b/spd/models/batch_and_loss_fns.py @@ -3,12 +3,20 @@ These functions parameterize ComponentModel and training for different target model architectures. """ -from typing import Protocol +from typing import Any, Protocol import torch import torch.nn.functional as F from jaxtyping import Float -from torch import Tensor +from torch import Tensor, nn + +from spd.configs import AttrOutputExtract, IndexOutputExtract, OutputExtractConfig + + +class RunBatch[BatchT, OutputT](Protocol): + """Protocol for running a batch through a model and returning the output.""" + + def __call__(self, model: nn.Module, batch: BatchT) -> OutputT: ... class ReconstructionLoss[OutputT](Protocol): @@ -17,6 +25,26 @@ class ReconstructionLoss[OutputT](Protocol): def __call__(self, pred: OutputT, target: OutputT) -> tuple[Float[Tensor, ""], int]: ... +def run_batch_passthrough(model: nn.Module, batch: Any) -> Any: + return model(batch) + + +def make_run_batch(output_extract: OutputExtractConfig | None) -> RunBatch[Any, Any]: + """creates a RunBatch function for a given configuration. + + Note that if you plan to override the RunBatch functionality, you can simply pass + a custom RunBatch function into optimize and do not need to use this function at + all. + """ + match output_extract: + case None: + return run_batch_passthrough + case IndexOutputExtract(index=idx): + return lambda model, batch: model(batch)[idx] + case AttrOutputExtract(attr=attr): + return lambda model, batch: getattr(model(batch), attr) + + def recon_loss_mse( pred: Float[Tensor, "... d"], target: Float[Tensor, "... d"], diff --git a/spd/models/component_model.py b/spd/models/component_model.py index fa6c7434a..2f7038ac7 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,9 +1,8 @@ -import re -from collections.abc import Callable, Generator, Iterator +from collections.abc import Callable, Generator from contextlib import contextmanager from dataclasses import dataclass from functools import partial -from typing import Any, Literal, NamedTuple, Protocol, overload, override +from typing import Any, Literal, NamedTuple, overload, override import torch from jaxtyping import Float, Int @@ -14,6 +13,7 @@ from spd.configs import Config, SamplingType from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.batch_and_loss_fns import RunBatch, make_run_batch from spd.models.components import ( Components, ComponentsMaskInfo, @@ -29,29 +29,6 @@ from spd.utils.general_utils import resolve_class, runtime_cast from spd.utils.module_utils import ModulePathInfo, expand_module_patterns -_ACCESSOR_TOKEN_RE = re.compile(r'\.\w+|\[\d+\]|\["\w+"\]') - - -def extract_with_accessor(obj: Any, accessor: str) -> Any: - """Navigate a nested object using an accessor path string. - - Supports attribute access (.attr), integer indexing ([i]), and string-key - dictionary access (["key"]). - Examples: "[0]", ".logits", "[0].logits[2]", '["hidden_states"]' - """ - assert accessor, "Accessor must be non-empty (use None for no extraction)" - tokens = _ACCESSOR_TOKEN_RE.findall(accessor) - assert "".join(tokens) == accessor, f"Invalid accessor: {accessor!r}" - result = obj - for token in tokens: - if token.startswith('["'): - result = result[token[2:-2]] - elif token.startswith("["): - result = result[int(token[1:-1])] - else: - result = getattr(result, token[1:]) - return result - @dataclass class SPDRunInfo(RunInfo[Config]): @@ -76,18 +53,6 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class TargetModel[BatchT, OutputT](Protocol): - # def __call__(self, batch: BatchT) -> OutputT: ... - - def __call__(self, batch: BatchT) -> OutputT: ... - - def get_submodule(self, target: str) -> nn.Module: ... - - def named_parameters(self) -> Iterator[tuple[str, nn.Parameter]]: ... - - # def named_modules(self) -> Generator[tuple[str, nn.Module]]: ... - - class ComponentModel[BatchT, OutputT](nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. @@ -109,15 +74,15 @@ class ComponentModel[BatchT, OutputT](nn.Module): def __init__( self, - target_model: TargetModel[BatchT, OutputT], + target_model: nn.Module, + run_batch: RunBatch[BatchT, OutputT], module_path_info: list[ModulePathInfo], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, - extract_tensor_output: str | None = None, ): super().__init__() - self.extract_tensor_output = extract_tensor_output + self._run_batch: RunBatch[BatchT, OutputT] = run_batch for name, param in target_model.named_parameters(): assert not param.requires_grad, ( @@ -180,13 +145,15 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": module_path_info = expand_module_patterns(target_model, config.all_module_info) + run_batch = make_run_batch(config.output_extract) + comp_model: ComponentModel[Any, Any] = cls( target_model=target_model, + run_batch=run_batch, module_path_info=module_path_info, ci_fn_hidden_dims=config.ci_fn_hidden_dims, ci_fn_type=config.ci_fn_type, sigmoid_type=config.sigmoid_type, - extract_tensor_output=config.extract_tensor_output, ) weights = torch.load(run_info.checkpoint_path, map_location="cpu", weights_only=True) @@ -256,7 +223,7 @@ def _create_component( @staticmethod def _create_components( - target_model: TargetModel[BatchT, OutputT], + target_model: nn.Module, module_to_c: dict[str, int], ) -> dict[str, Components]: components: dict[str, Components] = {} @@ -299,7 +266,7 @@ def _create_ci_fn( @staticmethod def _create_ci_fns( - target_model: TargetModel[BatchT, OutputT], + target_model: nn.Module, module_to_c: dict[str, int], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], @@ -373,8 +340,7 @@ def forward( model output tensor. """ if mask_infos is None and cache_type == "none": - # No hooks needed. Do a regular forward pass of the target model. - return self._extract_output(self.target_model(batch)) + return self._run_batch(self.target_model, batch) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -395,7 +361,7 @@ def forward( ) with self._attach_forward_hooks(hooks): - out: OutputT = self._extract_output(self.target_model(batch)) + out: OutputT = self._run_batch(self.target_model, batch) match cache_type: case "input" | "component_acts": @@ -403,11 +369,6 @@ def forward( case "none": return out - def _extract_output(self, raw_output: Any) -> Any: - if self.extract_tensor_output is None: - return raw_output - return extract_with_accessor(raw_output, self.extract_tensor_output) - def _components_and_cache_hook( self, _module: nn.Module, diff --git a/spd/run_spd.py b/spd/run_spd.py index 11d345027..66ef51032 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -30,8 +30,8 @@ from spd.log import logger from spd.losses import compute_total_loss from spd.metrics import faithfulness_loss -from spd.models.batch_and_loss_fns import ReconstructionLoss -from spd.models.component_model import ComponentModel, OutputWithCache, TargetModel +from spd.models.batch_and_loss_fns import ReconstructionLoss, RunBatch +from spd.models.component_model import ComponentModel, OutputWithCache from spd.utils.component_utils import calc_ci_l_zero from spd.utils.distributed_utils import ( avg_metrics_across_ranks, @@ -43,7 +43,6 @@ bf16_autocast, dict_safe_update_, get_scheduled_value, - runtime_cast, ) from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns @@ -111,11 +110,12 @@ def get_unique_metric_configs( def optimize[BatchT, OutputT]( - target_model: TargetModel[BatchT, OutputT], + target_model: nn.Module, config: Config, device: str, train_loader: DataLoader[BatchT], eval_loader: DataLoader[BatchT], + run_batch: RunBatch[BatchT, OutputT], reconstruction_loss: ReconstructionLoss[OutputT], out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, @@ -135,23 +135,21 @@ def create_pgd_data_iter() -> Iterator[BatchT]: if config.identity_module_info is not None: insert_identity_operations_( - runtime_cast(nn.Module, target_model), + target_model, identity_module_info=config.identity_module_info, ) - cast(nn.Module, target_model).requires_grad_(False) + target_model.requires_grad_(False) - module_path_info = expand_module_patterns( - runtime_cast(nn.Module, target_model), config.all_module_info - ) + module_path_info = expand_module_patterns(target_model, config.all_module_info) model = ComponentModel( target_model=target_model, + run_batch=run_batch, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, - extract_tensor_output=config.extract_tensor_output, ) model.to(device) diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 8224d2fd5..70f3d62c2 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -20,7 +20,14 @@ from spd.app.backend.routers import runs as runs_router from spd.app.backend.server import app from spd.app.backend.state import HarvestCache, RunState, StateManager -from spd.configs import Config, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig +from spd.configs import ( + Config, + IndexOutputExtract, + LMTaskConfig, + ModulePatternInfoConfig, + ScheduleConfig, +) +from spd.models.batch_and_loss_fns import make_run_batch from spd.models.component_model import ComponentModel from spd.pretrain.models.gpt2_simple import GPT2Simple, GPT2SimpleConfig from spd.utils.module_utils import expand_module_patterns @@ -91,9 +98,8 @@ def app_with_state(): ModulePatternInfoConfig(module_pattern=p, C=C) for p in target_module_patterns ], pretrained_model_class="spd.pretrain.models.gpt2_simple.GPT2Simple", - extract_tensor_output="[0]", + output_extract=IndexOutputExtract(index=0), tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - output_loss_type="kl", lr_schedule=ScheduleConfig(start_val=1e-3), steps=1, batch_size=1, @@ -114,11 +120,11 @@ def app_with_state(): module_path_info = expand_module_patterns(target_model, config.module_info) model = ComponentModel( target_model=target_model, + run_batch=make_run_batch(config.output_extract), module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, sigmoid_type=config.sigmoid_type, - extract_tensor_output=config.extract_tensor_output, ) model.eval() sources_by_target = get_sources_by_target( diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index ae97199ff..7c35e2473 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -7,6 +7,7 @@ from jaxtyping import Float from torch import Tensor +from spd.models.batch_and_loss_fns import run_batch_passthrough from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo @@ -55,6 +56,7 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", @@ -88,6 +90,7 @@ def make_two_layer_component_model( comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ ModulePathInfo(module_path="fc1", C=1), ModulePathInfo(module_path="fc2", C=1), diff --git a/tests/scripts_run/test_grid_search.py b/tests/scripts_run/test_grid_search.py index a5f59c65d..51bd84618 100644 --- a/tests/scripts_run/test_grid_search.py +++ b/tests/scripts_run/test_grid_search.py @@ -336,7 +336,6 @@ def test_tms_config_with_loss_sweep(self): "coeff": 1.0, }, ], - "output_loss_type": "mse", "lr": 0.001, "steps": 1000, "batch_size": 32, @@ -386,7 +385,6 @@ def test_lm_config_with_loss_sweep(self): "eps": 1e-12, } ], - "output_loss_type": "kl", "lr": 0.001, "steps": 1000, "batch_size": 32, @@ -451,7 +449,6 @@ def test_full_sweep_workflow(self): "eps": 1e-12, } ], - "output_loss_type": "mse", "lr": 0.01, # Will be overridden "steps": 1000, "batch_size": 32, diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 37e8d3655..9868749f6 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -17,10 +17,10 @@ ) from spd.identity_insertion import insert_identity_operations_ from spd.interfaces import LoadableModule, RunInfo +from spd.models.batch_and_loss_fns import run_batch_passthrough from spd.models.component_model import ( ComponentModel, SPDRunInfo, - extract_with_accessor, ) from spd.models.components import ( ComponentsMaskInfo, @@ -83,6 +83,7 @@ def test_correct_parameters_require_grad(): component_model = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ ModulePathInfo(module_path="linear1", C=4), ModulePathInfo(module_path="linear2", C=8), @@ -152,7 +153,6 @@ def test_from_run_info(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -171,6 +171,7 @@ def test_from_run_info(): module_path_info = expand_module_patterns(target_model, config.all_module_info) cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=module_path_info, ci_fn_type=config.ci_fn_type, ci_fn_hidden_dims=config.ci_fn_hidden_dims, @@ -199,6 +200,7 @@ def test_from_run_info(): ) cm_loaded = ComponentModel( target_model=loaded_target, + run_batch=run_batch_passthrough, module_path_info=loaded_module_path_info, ci_fn_type=cm_run_info.config.ci_fn_type, ci_fn_hidden_dims=cm_run_info.config.ci_fn_hidden_dims, @@ -299,6 +301,7 @@ def test_full_weight_delta_matches_target_behaviour(): C = 4 cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[4], @@ -330,6 +333,7 @@ def test_input_cache_captures_pre_weight_input(): cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], @@ -364,6 +368,7 @@ def test_weight_deltas(): target_module_paths = ["embed", "mlp", "out"] cm = ComponentModel( target_model=target_model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths], ci_fn_type="mlp", ci_fn_hidden_dims=[2], @@ -398,6 +403,7 @@ def forward(self, x: Tensor) -> Tensor: cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], @@ -453,6 +459,7 @@ def forward(self, x: Tensor) -> Tensor: # wrapped in a component model that decomposes the prepended identity layer cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], @@ -502,6 +509,7 @@ def forward(self, x: Tensor) -> Tensor: # wrapped in a component model that decomposes the layer cm = ComponentModel( target_model=model, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_fn_type="mlp", ci_fn_hidden_dims=[2], @@ -539,48 +547,3 @@ def forward(self, x: Tensor) -> Tensor: # but it should be the same for the second example (where it's not routed to components) assert torch.allclose(cm_routed_out[1], target_out[1]) - - -class TestExtractWithAccessor: - def test_integer_index(self): - obj = ("a", "b", "c") - assert extract_with_accessor(obj, "[0]") == "a" - assert extract_with_accessor(obj, "[2]") == "c" - - def test_attribute_access(self): - class Obj: - logits = 42 - - assert extract_with_accessor(Obj(), ".logits") == 42 - - def test_string_key_dict_access(self): - obj = {"hidden_states": "hs", "logits": "lg"} - assert extract_with_accessor(obj, '["hidden_states"]') == "hs" - assert extract_with_accessor(obj, '["logits"]') == "lg" - - def test_chained_accessors(self): - class Inner: - value = 99 - - obj = ({"data": Inner()},) - assert extract_with_accessor(obj, '[0]["data"].value') == 99 - - def test_index_then_attribute(self): - class Out: - logits = torch.tensor([1.0, 2.0]) - - obj = (Out(),) - result = extract_with_accessor(obj, "[0].logits") - assert torch.equal(result, torch.tensor([1.0, 2.0])) - - def test_invalid_accessor_raises(self): - with pytest.raises(AssertionError, match="Invalid accessor"): - extract_with_accessor({}, "invalid") - - def test_invalid_accessor_with_special_chars(self): - with pytest.raises(AssertionError, match="Invalid accessor"): - extract_with_accessor({}, '[" spaces "]') - - def test_empty_accessor_raises(self): - with pytest.raises(AssertionError, match="non-empty"): - extract_with_accessor({}, "") diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 04ad8b0f2..b88f4a7b6 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -34,7 +34,6 @@ {"classname": "CIMaskedReconLayerwiseLoss", "coeff": 1.0}, {"classname": "CIMaskedReconLoss", "coeff": 1.0}, ], - "output_loss_type": "kl", # --- Training --- "batch_size": 2, "steps": 20, @@ -54,7 +53,7 @@ # --- Pretrained model info --- "pretrained_model_class": "transformers.LlamaForCausalLM", "pretrained_model_name": "SimpleStories/SimpleStories-1.25M", - "extract_tensor_output": ".logits", + "output_extract": {"type": "attr", "attr": "logits"}, "tokenizer_name": "SimpleStories/SimpleStories-1.25M", # --- Task Specific --- "task_config": { diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index fa965b0b1..0784a0e5d 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -4,6 +4,7 @@ from transformers import PreTrainedModel from spd.configs import ( + AttrOutputExtract, CI_L0Config, Config, FaithfulnessLossConfig, @@ -16,7 +17,7 @@ ) from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_kl +from spd.models.batch_and_loss_fns import make_run_batch, recon_loss_kl from spd.run_spd import optimize from spd.utils.general_utils import resolve_class, set_seed @@ -56,7 +57,6 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -79,7 +79,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="transformers.GPT2LMHeadModel", pretrained_model_path=None, pretrained_model_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - extract_tensor_output=".logits", + output_extract=AttrOutputExtract(attr="logits"), tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", # Task Specific task_config=LMTaskConfig( @@ -146,12 +146,14 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: ) # Run optimize function + assert config.output_extract is not None optimize( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=make_run_batch(config.output_extract), reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index bff6a25e6..4de05d05f 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -17,7 +17,7 @@ from spd.experiments.ih.configs import InductionModelConfig from spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_kl +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -72,7 +72,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -96,7 +95,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.ih.model.InductionTransformer", pretrained_model_path=None, pretrained_model_name=None, - extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=IHTaskConfig( @@ -134,6 +132,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index be02eae6d..4c502aabc 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,7 +13,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -63,7 +63,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: identity_module_info=[ ModulePatternInfoConfig(module_pattern="layers.*.mlp_in", C=10), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -83,7 +82,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.resid_mlp.models.ResidMLP", pretrained_model_path=None, pretrained_model_name=None, - extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=ResidMLPTaskConfig( @@ -130,6 +128,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index b83a4fe3e..8bd948b44 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -16,7 +16,7 @@ stochastic_recon_loss, stochastic_recon_subset_loss, ) -from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough from spd.models.component_model import ComponentModel from spd.utils.module_utils import ModulePathInfo @@ -40,6 +40,7 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel comp_model = ComponentModel( target_model=target, + run_batch=run_batch_passthrough, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_fn_hidden_dims=[2], ci_fn_type="mlp", diff --git a/tests/test_tms.py b/tests/test_tms.py index 9117992bf..e717afcb8 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -18,7 +18,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_mse +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -69,7 +69,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=1.0), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.0, final_val_frac=0.0 @@ -92,7 +91,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.tms.models.TMSModel", pretrained_model_path=None, pretrained_model_name=None, - extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=TMSTaskConfig( @@ -138,6 +136,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, + run_batch=run_batch_passthrough, reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights, From a94b580e3ed0acce4d7fbcf5ed7328955cf99dad Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 07:58:41 +0000 Subject: [PATCH 10/16] Fix typing --- spd/eval.py | 56 +++++++++++++++++---------------- spd/metrics/ce_and_kl_losses.py | 4 +-- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index daf281ef8..a2f724e75 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -36,28 +36,31 @@ UnmaskedReconLossConfig, UVPlotsConfig, ) -from spd.metrics import UnmaskedReconLoss +from spd.metrics import ( + CI_L0, + CEandKLLosses, + CIHistograms, + CIMaskedReconLayerwiseLoss, + CIMaskedReconLoss, + CIMaskedReconSubsetLoss, + CIMeanPerComponent, + ComponentActivationDensity, + FaithfulnessLoss, + IdentityCIError, + ImportanceMinimalityLoss, + PermutedCIPlots, + PGDReconLayerwiseLoss, + PGDReconLoss, + PGDReconSubsetLoss, + StochasticHiddenActsReconLoss, + StochasticReconLayerwiseLoss, + StochasticReconLoss, + StochasticReconSubsetLoss, + UnmaskedReconLoss, + UVPlots, +) from spd.metrics.base import Metric -from spd.metrics.ci_histograms import CIHistograms -from spd.metrics.ci_l0 import CI_L0 -from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss -from spd.metrics.ci_masked_recon_loss import CIMaskedReconLoss -from spd.metrics.ci_masked_recon_subset_loss import CIMaskedReconSubsetLoss -from spd.metrics.ci_mean_per_component import CIMeanPerComponent -from spd.metrics.component_activation_density import ComponentActivationDensity -from spd.metrics.faithfulness_loss import FaithfulnessLoss -from spd.metrics.identity_ci_error import IdentityCIError -from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss -from spd.metrics.permuted_ci_plots import PermutedCIPlots -from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss -from spd.metrics.pgd_masked_recon_loss import PGDReconLoss -from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss -from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss -from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss -from spd.metrics.stochastic_recon_loss import StochasticReconLoss -from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss -from spd.metrics.uv_plots import UVPlots from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.routing import AllLayersRouter, get_subset_router @@ -138,13 +141,12 @@ def init_metric[BatchT, OutputT]( device=device, ) case CEandKLLossesConfig(): - raise ValueError("fix this typing!") - # metric = CEandKLLosses( - # model=model, - # device=device, - # sampling=run_config.sampling, - # rounding_threshold=cfg.rounding_threshold, - # ) + metric = CEandKLLosses( + model=model, + device=device, + sampling=run_config.sampling, + rounding_threshold=cfg.rounding_threshold, + ) case CIHistogramsConfig(): metric = CIHistograms(model=model, n_batches_accum=cfg.n_batches_accum) case CI_L0Config(): diff --git a/spd/metrics/ce_and_kl_losses.py b/spd/metrics/ce_and_kl_losses.py index 872336954..79db5ec51 100644 --- a/spd/metrics/ce_and_kl_losses.py +++ b/spd/metrics/ce_and_kl_losses.py @@ -17,7 +17,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class CEandKLLosses(Metric[Tensor, Tensor]): +class CEandKLLosses(Metric[Any, Any]): """CE and KL losses for different masking strategies. NOTE: Assumes all batches and sequences are the same size. @@ -47,7 +47,7 @@ class CEandKLLosses(Metric[Tensor, Tensor]): def __init__( self, - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Any, Any], device: str, sampling: SamplingType, rounding_threshold: float, From ca22236f789937168bbdc8435a121181aaec2c90 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 08:03:53 +0000 Subject: [PATCH 11/16] Fix typing for StochasticReconSubsetCEAndKL --- spd/eval.py | 20 +++++++++---------- .../stochastic_recon_subset_ce_and_kl.py | 4 ++-- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index a2f724e75..12ce8ee12 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -55,6 +55,7 @@ StochasticHiddenActsReconLoss, StochasticReconLayerwiseLoss, StochasticReconLoss, + StochasticReconSubsetCEAndKL, StochasticReconSubsetLoss, UnmaskedReconLoss, UVPlots, @@ -249,16 +250,15 @@ def init_metric[BatchT, OutputT]( reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetCEAndKLConfig(): - raise ValueError("fix this typing!") - # metric = StochasticReconSubsetCEAndKL( - # model=model, - # device=device, - # sampling=run_config.sampling, - # use_delta_component=run_config.use_delta_component, - # n_mask_samples=run_config.n_mask_samples, - # include_patterns=cfg.include_patterns, - # exclude_patterns=cfg.exclude_patterns, - # ) + metric = StochasticReconSubsetCEAndKL( + model=model, + device=device, + sampling=run_config.sampling, + use_delta_component=run_config.use_delta_component, + n_mask_samples=run_config.n_mask_samples, + include_patterns=cfg.include_patterns, + exclude_patterns=cfg.exclude_patterns, + ) case StochasticHiddenActsReconLossConfig(): metric = StochasticHiddenActsReconLoss( model=model, diff --git a/spd/metrics/stochastic_recon_subset_ce_and_kl.py b/spd/metrics/stochastic_recon_subset_ce_and_kl.py index b1d67f86f..1dc0ee79f 100644 --- a/spd/metrics/stochastic_recon_subset_ce_and_kl.py +++ b/spd/metrics/stochastic_recon_subset_ce_and_kl.py @@ -19,7 +19,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class StochasticReconSubsetCEAndKL(Metric[Tensor, Tensor]): +class StochasticReconSubsetCEAndKL(Metric[Any, Any]): """Compute reconstruction loss for specific subsets of components. NOTE: Assumes all batches and sequences are the same size. @@ -29,7 +29,7 @@ class StochasticReconSubsetCEAndKL(Metric[Tensor, Tensor]): def __init__( self, - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Any, Any], device: str, sampling: SamplingType, use_delta_component: bool, From e695241a081966d97c51821a53d7fd273caac6c1 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 08:08:45 +0000 Subject: [PATCH 12/16] FIx non-deterministic test in CI --- tests/test_distributed.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index b88f4a7b6..216602cff 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -160,6 +160,10 @@ def _run_experiment( new_env = os.environ.copy() new_env["CUDA_VISIBLE_DEVICES"] = "" new_env["SPD_OUT_DIR"] = str(spd_out_dir) + # Force single-threaded execution so that within-rank float32 operations + # are deterministic across different machines/CI environments. + new_env["OMP_NUM_THREADS"] = "1" + new_env["MKL_NUM_THREADS"] = "1" result = subprocess.run(cmd, env=new_env, capture_output=True, text=True, timeout=300) From 8b35022a682b3a6636e14b8ac5c953ebf8c92c21 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 08:10:12 +0000 Subject: [PATCH 13/16] Remove explicit __call__ --- spd/eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/eval.py b/spd/eval.py index 12ce8ee12..4c9ab9a01 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -320,7 +320,7 @@ def evaluate[BatchT, OutputT]( for _ in range(n_eval_steps): batch = next(eval_iterator) - target_output: OutputWithCache[OutputT] = model.__call__(batch, cache_type="input") + target_output: OutputWithCache[OutputT] = model(batch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=target_output.cache, detach_inputs=False, From 4fe40e3f3a3fc8c67d85da0a0b384f8176c664c8 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 08:42:35 +0000 Subject: [PATCH 14/16] Remove OutputT as it was always a Tensor --- spd/app/backend/compute.py | 20 ++++++------- spd/app/backend/optim_cis.py | 12 ++++---- spd/app/backend/state.py | 2 +- spd/clustering/activations.py | 6 ++-- spd/dataset_attributions/harvest.py | 4 +-- spd/dataset_attributions/harvester.py | 4 +-- spd/eval.py | 24 +++++++-------- spd/harvest/harvest.py | 2 +- spd/losses.py | 8 ++--- spd/metrics/base.py | 6 ++-- spd/metrics/ce_and_kl_losses.py | 4 +-- spd/metrics/ci_histograms.py | 4 +-- spd/metrics/ci_l0.py | 4 +-- spd/metrics/ci_masked_recon_layerwise_loss.py | 24 +++++++-------- spd/metrics/ci_masked_recon_loss.py | 24 +++++++-------- spd/metrics/ci_masked_recon_subset_loss.py | 24 +++++++-------- spd/metrics/ci_mean_per_component.py | 4 +-- spd/metrics/component_activation_density.py | 6 ++-- spd/metrics/faithfulness_loss.py | 4 +-- spd/metrics/identity_ci_error.py | 4 +-- spd/metrics/importance_minimality_loss.py | 4 +-- spd/metrics/permuted_ci_plots.py | 4 +-- .../pgd_masked_recon_layerwise_loss.py | 24 +++++++-------- spd/metrics/pgd_masked_recon_loss.py | 16 +++++----- spd/metrics/pgd_masked_recon_subset_loss.py | 16 +++++----- spd/metrics/pgd_utils.py | 30 +++++++++---------- .../stochastic_hidden_acts_recon_loss.py | 12 ++++---- .../stochastic_recon_layerwise_loss.py | 24 +++++++-------- spd/metrics/stochastic_recon_loss.py | 24 +++++++-------- .../stochastic_recon_subset_ce_and_kl.py | 4 +-- spd/metrics/stochastic_recon_subset_loss.py | 24 +++++++-------- spd/metrics/unmasked_recon_loss.py | 24 +++++++-------- spd/metrics/uv_plots.py | 4 +-- spd/models/batch_and_loss_fns.py | 27 +++++++++++------ spd/models/component_model.py | 28 ++++++++--------- spd/plotting.py | 4 +-- spd/run_spd.py | 16 +++++----- spd/scripts/compare_models/compare_models.py | 4 +-- spd/utils/logging_utils.py | 2 +- tests/metrics/fixtures.py | 6 ++-- tests/metrics/test_faithfulness_loss.py | 2 +- tests/test_spd_losses.py | 4 +-- 42 files changed, 250 insertions(+), 243 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 8b1fe0606..b67d5ff71 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -127,7 +127,7 @@ def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: def get_sources_by_target( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], device: str, sampling: SamplingType, ) -> dict[str, list[str]]: @@ -142,7 +142,7 @@ def get_sources_by_target( batch: Float[Tensor, "batch seq"] = torch.zeros(2, 3, dtype=torch.long, device=device) with torch.no_grad(), bf16_autocast(): - output_with_cache: OutputWithCache[Any] = model(batch, cache_type="input") + output_with_cache: OutputWithCache = model(batch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, @@ -171,7 +171,7 @@ def wte_hook( wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) with torch.enable_grad(), bf16_autocast(): - comp_output_with_cache: OutputWithCache[Any] = model( + comp_output_with_cache: OutputWithCache = model( batch, mask_infos=mask_infos, cache_type="component_acts", @@ -306,7 +306,7 @@ def _compute_edges_for_target( def compute_edges_from_ci( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], pre_weight_acts: dict[str, Float[Tensor, "1 seq d_in"]], @@ -356,7 +356,7 @@ def compute_edges_from_ci( weight_deltas_and_masks=weight_deltas_and_masks, ) with torch.enable_grad(), bf16_autocast(): - comp_output_with_cache: OutputWithCache[Any] = model( + comp_output_with_cache: OutputWithCache = model( tokens, mask_infos=unmasked_masks, cache_type="component_acts" ) @@ -492,7 +492,7 @@ def filter_ci_to_included_nodes( def compute_prompt_attributions( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], output_prob_threshold: float, @@ -542,7 +542,7 @@ def compute_prompt_attributions( def compute_prompt_attributions_optimized( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, @@ -626,7 +626,7 @@ class CIOnlyResult: def compute_ci_only( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Float[Tensor, "1 seq"], sampling: SamplingType, ) -> CIOnlyResult: @@ -644,7 +644,7 @@ def compute_ci_only( CIOnlyResult containing CI values per layer, target model output probabilities, pre-weight activations, and component activations. """ with torch.no_grad(), bf16_autocast(): - output_with_cache: OutputWithCache[Any] = model(tokens, cache_type="input") + output_with_cache: OutputWithCache = model(tokens, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=sampling, @@ -791,7 +791,7 @@ class InterventionResult: def compute_intervention_forward( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] top_k: int, diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 76255df8a..79c109620 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -4,7 +4,7 @@ from collections.abc import Callable from dataclasses import dataclass from pathlib import Path -from typing import Any, Literal +from typing import Literal import torch import torch.nn.functional as F @@ -73,7 +73,7 @@ class OptimizableCIParams: ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values alive_info: AliveComponentInfo - def create_ci_outputs(self, model: ComponentModel[Tensor, Tensor], device: str) -> CIOutputs: + def create_ci_outputs(self, model: ComponentModel[Tensor], device: str) -> CIOutputs: """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" pre_sigmoid: dict[str, Tensor] = {} @@ -140,7 +140,7 @@ def create_optimizable_ci_params( def compute_label_prob( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Tensor, ci_lower_leaky: dict[str, Tensor], label_token: int, @@ -167,7 +167,7 @@ def compute_l0_stats( def compute_final_token_ce_kl( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], batch: Tensor, target_out: Tensor, ci: dict[str, Tensor], @@ -272,7 +272,7 @@ class OptimCIConfig: def optimize_ci_values( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], tokens: Tensor, config: OptimCIConfig, device: str, @@ -297,7 +297,7 @@ def optimize_ci_values( # Get initial CI values from the model with torch.no_grad(), bf16_autocast(): - output_with_cache: OutputWithCache[Any] = model(tokens, cache_type="input") + output_with_cache: OutputWithCache = model(tokens, cache_type="input") initial_ci_outputs = model.calc_causal_importances( pre_weight_acts=output_with_cache.cache, sampling=config.sampling, diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index ca178b688..bc898eab1 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -111,7 +111,7 @@ class RunState: """Runtime state for a loaded run (model, tokenizer, etc.)""" run: Run - model: ComponentModel[Tensor, Tensor] + model: ComponentModel[Tensor] tokenizer: PreTrainedTokenizerBase sources_by_target: dict[str, list[str]] config: Config diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 6c427d6b1..a933e070e 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from functools import cached_property -from typing import Any, Literal, NamedTuple +from typing import Literal, NamedTuple import torch from jaxtyping import Bool, Float, Float16, Int @@ -17,14 +17,14 @@ def component_activations( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], device: torch.device | str, batch: Int[Tensor, "batch_size n_ctx"], ) -> dict[str, ActivationsTensor]: """Get the component activations over a **single** batch.""" causal_importances: dict[str, ActivationsTensor] with torch.no_grad(): - model_output: OutputWithCache[Any] = model( + model_output: OutputWithCache = model( batch.to(device), cache_type="input", ) diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index 2068be277..fa95fcbe5 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -41,7 +41,7 @@ class DatasetAttributionConfig: ci_threshold: float -def _build_component_layer_keys(model: ComponentModel[Tensor, Tensor]) -> list[str]: +def _build_component_layer_keys(model: ComponentModel[Tensor]) -> list[str]: """Build list of component layer keys in canonical order. Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. @@ -56,7 +56,7 @@ def _build_component_layer_keys(model: ComponentModel[Tensor, Tensor]) -> list[s def _build_alive_masks( - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], run_id: str, ci_threshold: float, n_components: int, diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index f263dee7c..98a98ea77 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -46,7 +46,7 @@ class AttributionHarvester: def __init__( self, - model: ComponentModel[Tensor, Tensor], + model: ComponentModel[Tensor], sources_by_target: dict[str, list[str]], n_components: int, vocab_size: int, @@ -162,7 +162,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No # Forward pass with gradients with torch.enable_grad(), bf16_autocast(): - comp_output: OutputWithCache[Any] = self.model( + comp_output: OutputWithCache = self.model( tokens, mask_infos=mask_infos, cache_type="component_acts" ) diff --git a/spd/eval.py b/spd/eval.py index 4c9ab9a01..8bd622e21 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -118,13 +118,13 @@ def avg_eval_metrics_across_ranks(metrics: MetricOutType, device: str) -> DistMe return {**metrics, **avg_metrics} -def init_metric[BatchT, OutputT]( +def init_metric[BatchT]( cfg: MetricConfigType, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], run_config: Config, device: str, - reconstruction_loss: ReconstructionLoss[OutputT], -) -> Metric[BatchT, OutputT]: + reconstruction_loss: ReconstructionLoss, +) -> Metric[BatchT]: match cfg: case ImportanceMinimalityLossConfig(): metric = ImportanceMinimalityLoss( @@ -288,20 +288,20 @@ def init_metric[BatchT, OutputT]( return metric -def evaluate[BatchT, OutputT]( +def evaluate[BatchT]( eval_metric_configs: list[MetricConfigType], - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], eval_iterator: Iterator[BatchT], device: str, run_config: Config, slow_step: bool, n_eval_steps: int, current_frac_of_training: float, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" - metrics: list[Metric[BatchT, OutputT]] = [] + metrics: list[Metric[BatchT]] = [] for cfg in eval_metric_configs: metric = init_metric( cfg=cfg, @@ -320,7 +320,7 @@ def evaluate[BatchT, OutputT]( for _ in range(n_eval_steps): batch = next(eval_iterator) - target_output: OutputWithCache[OutputT] = model(batch, cache_type="input") + target_output: OutputWithCache = model(batch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=target_output.cache, detach_inputs=False, @@ -350,15 +350,15 @@ def evaluate[BatchT, OutputT]( return outputs -def evaluate_multibatch_pgd[BatchT, OutputT]( +def evaluate_multibatch_pgd[BatchT]( multibatch_pgd_eval_configs: list[ PGDMultiBatchReconLossConfig | PGDMultiBatchReconSubsetLossConfig ], - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], create_data_iter: CreateDataIter[BatchT], config: Config, device: str, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> dict[str, float]: """Calculate multibatch PGD metrics.""" weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index b1703c82c..9b10e9077 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -38,7 +38,7 @@ from spd.utils.general_utils import bf16_autocast -def _compute_u_norms(model: ComponentModel[Any, Any]) -> dict[str, Float[Tensor, " C"]]: +def _compute_u_norms(model: ComponentModel[Any]) -> dict[str, Float[Tensor, " C"]]: """Compute ||U[c,:]|| for each component c in each layer. Component activations (v_i^T @ a) have a scale invariance: scaling V by α and U by 1/α diff --git a/spd/losses.py b/spd/losses.py index a756c7d06..6531279c2 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -39,19 +39,19 @@ from spd.utils.general_utils import get_obj_device -def compute_total_loss[BatchT, OutputT]( +def compute_total_loss[BatchT]( loss_metric_configs: list[LossMetricConfigType], - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], batch: BatchT, ci: CIOutputs, - target_out: OutputT, + target_out: Tensor, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], pre_weight_acts: dict[str, Float[Tensor, "..."]], current_frac_of_training: float, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], dict[str, float]]: """Compute weighted total loss and per-term raw values using new loss primitives. diff --git a/spd/metrics/base.py b/spd/metrics/base.py index a7903c860..36c3de024 100644 --- a/spd/metrics/base.py +++ b/spd/metrics/base.py @@ -12,7 +12,7 @@ from spd.models.component_model import CIOutputs -class Metric[BatchT, OutputT](Protocol): +class Metric[BatchT](Protocol): """Interface for metrics that can be used in training and/or evaluation.""" slow: ClassVar[bool] = False @@ -22,11 +22,11 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, current_frac_of_training: float, - weight_deltas: dict[str, Float[Tensor, "... C"]], + weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], ) -> None: """Update metric state with a batch of data.""" ... diff --git a/spd/metrics/ce_and_kl_losses.py b/spd/metrics/ce_and_kl_losses.py index 79db5ec51..83c6e8766 100644 --- a/spd/metrics/ce_and_kl_losses.py +++ b/spd/metrics/ce_and_kl_losses.py @@ -17,7 +17,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class CEandKLLosses(Metric[Any, Any]): +class CEandKLLosses(Metric[Any]): """CE and KL losses for different masking strategies. NOTE: Assumes all batches and sequences are the same size. @@ -47,7 +47,7 @@ class CEandKLLosses(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], device: str, sampling: SamplingType, rounding_threshold: float, diff --git a/spd/metrics/ci_histograms.py b/spd/metrics/ci_histograms.py index 22e6f7386..518788572 100644 --- a/spd/metrics/ci_histograms.py +++ b/spd/metrics/ci_histograms.py @@ -12,13 +12,13 @@ from spd.utils.distributed_utils import gather_all_tensors -class CIHistograms(Metric[Any, Any]): +class CIHistograms(Metric[Any]): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], n_batches_accum: int | None = None, ): self.n_batches_accum = n_batches_accum diff --git a/spd/metrics/ci_l0.py b/spd/metrics/ci_l0.py index b534f5c9b..84290e96c 100644 --- a/spd/metrics/ci_l0.py +++ b/spd/metrics/ci_l0.py @@ -12,7 +12,7 @@ from spd.utils.distributed_utils import all_reduce -class CI_L0(Metric[Any, Any]): +class CI_L0(Metric[Any]): """L0 metric for CI values. NOTE: Assumes all batches and sequences are the same size. @@ -22,7 +22,7 @@ class CI_L0(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], device: str, ci_alive_threshold: float, groups: dict[str, list[str]] | None = None, diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index 44539c236..bfb5103e4 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -13,12 +13,12 @@ from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_layerwise_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _ci_masked_recon_layerwise_loss_update[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: sum_loss = torch.tensor(0.0, device=get_obj_device(model)) sum_n_examples = 0 @@ -37,12 +37,12 @@ def _ci_masked_recon_layerwise_loss_compute( return sum_loss / sum_n_examples -def ci_masked_recon_layerwise_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def ci_masked_recon_layerwise_loss[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _ci_masked_recon_layerwise_loss_update( model=model, @@ -54,16 +54,16 @@ def ci_masked_recon_layerwise_loss[BatchT, OutputT]( return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class CIMaskedReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class CIMaskedReconLayerwiseLoss[BatchT](Metric[BatchT]): """Recon loss when masking with CI values directly one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.reconstruction_loss = reconstruction_loss @@ -75,7 +75,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index a54ceb6e2..836460289 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -12,12 +12,12 @@ from spd.utils.distributed_utils import all_reduce -def _ci_masked_recon_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _ci_masked_recon_loss_update[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) out = model(batch, mask_infos=mask_infos) @@ -30,12 +30,12 @@ def _ci_masked_recon_loss_compute( return sum_loss / n_examples -def ci_masked_recon_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def ci_masked_recon_loss[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_loss_update( model=model, @@ -47,16 +47,16 @@ def ci_masked_recon_loss[BatchT, OutputT]( return _ci_masked_recon_loss_compute(sum_loss, n_examples) -class CIMaskedReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class CIMaskedReconLoss[BatchT](Metric[BatchT]): """Recon loss when masking with CI values directly on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.reconstruction_loss = reconstruction_loss @@ -68,7 +68,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index 6b0fba016..3e494a36c 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -15,13 +15,13 @@ from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_subset_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _ci_masked_recon_subset_loss_update[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], router: Router, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: subset_routing_masks = router.get_masks( module_names=model.target_module_paths, @@ -42,13 +42,13 @@ def _ci_masked_recon_subset_loss_compute( return sum_loss / n_examples -def ci_masked_recon_subset_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def ci_masked_recon_subset_loss[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=model, @@ -61,17 +61,17 @@ def ci_masked_recon_subset_loss[BatchT, OutputT]( return _ci_masked_recon_subset_loss_compute(sum_loss, n_examples) -class CIMaskedReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class CIMaskedReconSubsetLoss[BatchT](Metric[BatchT]): """Recon loss when masking with raw CI values and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.router = get_subset_router(routing, device) @@ -84,7 +84,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, **_: Any, ) -> None: diff --git a/spd/metrics/ci_mean_per_component.py b/spd/metrics/ci_mean_per_component.py index fb4373727..7fd9e0c7c 100644 --- a/spd/metrics/ci_mean_per_component.py +++ b/spd/metrics/ci_mean_per_component.py @@ -11,11 +11,11 @@ from spd.utils.distributed_utils import all_reduce -class CIMeanPerComponent(Metric[Any, Any]): +class CIMeanPerComponent(Metric[Any]): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__(self, model: ComponentModel[Any, Any], device: str) -> None: + def __init__(self, model: ComponentModel[Any], device: str) -> None: self.components = model.components self.component_ci_sums: dict[str, Tensor] = { module_name: torch.zeros(model.module_to_c[module_name], device=device) diff --git a/spd/metrics/component_activation_density.py b/spd/metrics/component_activation_density.py index eb56f83fd..7ff1a150d 100644 --- a/spd/metrics/component_activation_density.py +++ b/spd/metrics/component_activation_density.py @@ -13,15 +13,13 @@ from spd.utils.distributed_utils import all_reduce -class ComponentActivationDensity(Metric[Any, Any]): +class ComponentActivationDensity(Metric[Any]): """Activation density for each component.""" slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__( - self, model: ComponentModel[Any, Any], device: str, ci_alive_threshold: float - ) -> None: + def __init__(self, model: ComponentModel[Any], device: str, ci_alive_threshold: float) -> None: self.model = model self.ci_alive_threshold = ci_alive_threshold diff --git a/spd/metrics/faithfulness_loss.py b/spd/metrics/faithfulness_loss.py index 3c902b65c..751a2e84c 100644 --- a/spd/metrics/faithfulness_loss.py +++ b/spd/metrics/faithfulness_loss.py @@ -35,12 +35,12 @@ def faithfulness_loss(weight_deltas: dict[str, Float[Tensor, "d_out d_in"]]) -> return _faithfulness_loss_compute(sum_loss, total_params) -class FaithfulnessLoss(Metric[Any, Any]): +class FaithfulnessLoss(Metric[Any]): """MSE between the target weights and the sum of the components.""" metric_section: ClassVar[str] = "loss" - def __init__(self, model: ComponentModel[Any, Any], device: str) -> None: + def __init__(self, model: ComponentModel[Any], device: str) -> None: self.model = model self.sum_loss = torch.tensor(0.0, device=device) self.total_params = torch.tensor(0, device=device) diff --git a/spd/metrics/identity_ci_error.py b/spd/metrics/identity_ci_error.py index b47ab8f8d..ca1210ee5 100644 --- a/spd/metrics/identity_ci_error.py +++ b/spd/metrics/identity_ci_error.py @@ -9,7 +9,7 @@ from spd.utils.target_ci_solutions import compute_target_metrics, make_target_ci_solution -class IdentityCIError(Metric[Any, Any]): +class IdentityCIError(Metric[Any]): """Error between the CI values and an Identity or Dense CI pattern.""" slow: ClassVar[bool] = True @@ -19,7 +19,7 @@ class IdentityCIError(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], sampling: SamplingType, identity_ci: list[dict[str, str | int]] | None = None, dense_ci: list[dict[str, str | int]] | None = None, diff --git a/spd/metrics/importance_minimality_loss.py b/spd/metrics/importance_minimality_loss.py index d06f9e47e..5fc607797 100644 --- a/spd/metrics/importance_minimality_loss.py +++ b/spd/metrics/importance_minimality_loss.py @@ -144,7 +144,7 @@ def importance_minimality_loss( ) -class ImportanceMinimalityLoss(Metric[Any, Any]): +class ImportanceMinimalityLoss(Metric[Any]): """L_p loss on the sum of CI values. NOTE: We don't normalize over the number of layers because a change in the number of layers @@ -165,7 +165,7 @@ class ImportanceMinimalityLoss(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], device: str, pnorm: float, beta: float, diff --git a/spd/metrics/permuted_ci_plots.py b/spd/metrics/permuted_ci_plots.py index f81b70fb1..d859b4119 100644 --- a/spd/metrics/permuted_ci_plots.py +++ b/spd/metrics/permuted_ci_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals -class PermutedCIPlots(Metric[Any, Any]): +class PermutedCIPlots(Metric[Any]): slow: ClassVar[bool] = True input_magnitude: ClassVar[float] = 0.75 @@ -17,7 +17,7 @@ class PermutedCIPlots(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index cd24a3c5d..9a878fe96 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -14,15 +14,15 @@ from spd.utils.distributed_utils import all_reduce -def _pgd_recon_layerwise_loss_update[BatchT, OutputT]( +def _pgd_recon_layerwise_loss_update[BatchT]( *, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], Int[Tensor, ""]]: device = next(iter(ci.values())).device sum_loss = torch.tensor(0.0, device=device) @@ -43,15 +43,15 @@ def _pgd_recon_layerwise_loss_update[BatchT, OutputT]( return sum_loss, n_examples -def pgd_recon_layerwise_loss[BatchT, OutputT]( +def pgd_recon_layerwise_loss[BatchT]( *, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _pgd_recon_layerwise_loss_update( model=model, @@ -65,7 +65,7 @@ def pgd_recon_layerwise_loss[BatchT, OutputT]( return sum_loss / n_examples -class PGDReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class PGDReconLayerwiseLoss[BatchT](Metric[BatchT]): """Recon loss when masking with adversarially-optimized values and routing to one layer at a time.""" @@ -73,11 +73,11 @@ class PGDReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], pgd_config: PGDConfig, device: str, use_delta_component: bool, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config @@ -91,7 +91,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index b82763b75..a64a7b4f7 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -14,15 +14,15 @@ from spd.utils.distributed_utils import all_reduce -def pgd_recon_loss[BatchT, OutputT]( +def pgd_recon_loss[BatchT]( *, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -37,7 +37,7 @@ def pgd_recon_loss[BatchT, OutputT]( return sum_loss / n_examples -class PGDReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class PGDReconLoss[BatchT](Metric[BatchT]): """Recon loss when masking with adversarially-optimized values and routing to all component layers.""" @@ -45,11 +45,11 @@ class PGDReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, pgd_config: PGDConfig, use_delta_component: bool, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config @@ -63,7 +63,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index 1dfdf22e9..1fdd10180 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -15,16 +15,16 @@ from spd.utils.general_utils import get_obj_device -def pgd_recon_subset_loss[BatchT, OutputT]( +def pgd_recon_subset_loss[BatchT]( *, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = pgd_masked_recon_loss_update( model=model, @@ -39,7 +39,7 @@ def pgd_recon_subset_loss[BatchT, OutputT]( return sum_loss / n_examples -class PGDReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class PGDReconSubsetLoss[BatchT](Metric[BatchT]): """Recon loss when masking with adversarially-optimized values and routing to subsets of component layers.""" @@ -47,12 +47,12 @@ class PGDReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, use_delta_component: bool, pgd_config: PGDConfig, routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config @@ -68,7 +68,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index 5500955c6..d7eb3e873 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -17,15 +17,15 @@ from spd.utils.general_utils import get_obj_device -def pgd_masked_recon_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def pgd_masked_recon_loss_update[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - target_out: OutputT, + target_out: Tensor, router: Router, pgd_config: PGDConfig, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: """Central implementation of PGD masked reconstruction loss. @@ -84,16 +84,16 @@ class CreateDataIter[BatchT](Protocol): def __call__(self) -> Iterator[BatchT]: ... -def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( +def calc_multibatch_pgd_masked_recon_loss[BatchT]( pgd_config: PGDMultiBatchConfig, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, create_data_iter: CreateDataIter[BatchT], router: Router, sampling: SamplingType, use_delta_component: bool, device: str, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: """PGD masked reconstruction loss with gradient accumulation over multiple batches. @@ -156,16 +156,16 @@ def calc_multibatch_pgd_masked_recon_loss[BatchT, OutputT]( return final_loss / final_sum_n_examples -def _forward_with_adv_sources[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _forward_with_adv_sources[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing_masks: RoutingMasks, - target_out: OutputT, + target_out: Tensor, batch_dims: tuple[int, ...], - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ): expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} adv_sources_components: dict[str, Float[Tensor, "*batch_dims C"]] @@ -191,16 +191,16 @@ def _forward_with_adv_sources[BatchT, OutputT]( return sum_loss, n_examples -def _multibatch_pgd_fwd_bwd[BatchT, OutputT]( +def _multibatch_pgd_fwd_bwd[BatchT]( adv_sources: dict[str, Float[Tensor, "*ones mask_c"]], pgd_config: PGDMultiBatchConfig, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, data_iter: Iterator[BatchT], device: torch.device | str, router: Router, sampling: SamplingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int, dict[str, Float[Tensor, "*ones mask_c"]]]: """Perform a forward and backward pass over multiple batches with gradient accumulation. @@ -223,7 +223,7 @@ def _multibatch_pgd_fwd_bwd[BatchT, OutputT]( # NOTE: technically this is duplicated work across PGD steps, but that's the price we pay to # enable accumulating gradients over more microbatches than we'd be able to fit CI values in # memory for. In other words, you can't fit 100,000 microbatches worth of CI values in memory. - target_model_output: OutputWithCache[OutputT] = model(microbatch, cache_type="input") + target_model_output: OutputWithCache = model(microbatch, cache_type="input") ci = model.calc_causal_importances( pre_weight_acts=target_model_output.cache, sampling=sampling, diff --git a/spd/metrics/stochastic_hidden_acts_recon_loss.py b/spd/metrics/stochastic_hidden_acts_recon_loss.py index 7b889da93..f2fad6f33 100644 --- a/spd/metrics/stochastic_hidden_acts_recon_loss.py +++ b/spd/metrics/stochastic_hidden_acts_recon_loss.py @@ -14,8 +14,8 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_hidden_acts_recon_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _stochastic_hidden_acts_recon_loss_update[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, @@ -59,8 +59,8 @@ def _stochastic_hidden_acts_recon_loss_compute( return sum_mse / n_examples -def stochastic_hidden_acts_recon_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def stochastic_hidden_acts_recon_loss[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, @@ -80,14 +80,14 @@ def stochastic_hidden_acts_recon_loss[BatchT, OutputT]( return _stochastic_hidden_acts_recon_loss_compute(sum_mse, n_examples) -class StochasticHiddenActsReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class StochasticHiddenActsReconLoss[BatchT](Metric[BatchT]): """Reconstruction loss between target and stochastic hidden activations when sampling with stochastic masks.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, sampling: SamplingType, use_delta_component: bool, diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index 6d93ae430..de01b58a0 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -15,15 +15,15 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_layerwise_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _stochastic_recon_layerwise_loss_update[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -55,15 +55,15 @@ def _stochastic_recon_layerwise_loss_compute( return sum_loss / sum_n_examples -def stochastic_recon_layerwise_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def stochastic_recon_layerwise_loss[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _stochastic_recon_layerwise_loss_update( model=model, @@ -78,19 +78,19 @@ def stochastic_recon_layerwise_loss[BatchT, OutputT]( return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLayerwiseLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class StochasticReconLayerwiseLoss[BatchT](Metric[BatchT]): """Recon loss when sampling with stochastic masks one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling @@ -105,7 +105,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 893ca432e..59650c322 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -15,15 +15,15 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _stochastic_recon_loss_update[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -51,15 +51,15 @@ def _stochastic_recon_loss_compute( return sum_loss / sum_n_examples -def stochastic_recon_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def stochastic_recon_loss[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, sum_n_examples = _stochastic_recon_loss_update( model=model, @@ -74,19 +74,19 @@ def stochastic_recon_loss[BatchT, OutputT]( return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class StochasticReconLoss[BatchT](Metric[BatchT]): """Recon loss when sampling with stochastic masks on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling @@ -101,7 +101,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, diff --git a/spd/metrics/stochastic_recon_subset_ce_and_kl.py b/spd/metrics/stochastic_recon_subset_ce_and_kl.py index 1dc0ee79f..cdf52d183 100644 --- a/spd/metrics/stochastic_recon_subset_ce_and_kl.py +++ b/spd/metrics/stochastic_recon_subset_ce_and_kl.py @@ -19,7 +19,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class StochasticReconSubsetCEAndKL(Metric[Any, Any]): +class StochasticReconSubsetCEAndKL(Metric[Any]): """Compute reconstruction loss for specific subsets of components. NOTE: Assumes all batches and sequences are the same size. @@ -29,7 +29,7 @@ class StochasticReconSubsetCEAndKL(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], device: str, sampling: SamplingType, use_delta_component: bool, diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index ea7cbd5f1..0fc074fa9 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -15,16 +15,16 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_subset_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _stochastic_recon_subset_loss_update[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, router: Router, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -55,16 +55,16 @@ def _stochastic_recon_subset_loss_compute( return sum_loss / n_examples -def stochastic_recon_subset_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def stochastic_recon_subset_loss[BatchT]( + model: ComponentModel[BatchT], sampling: SamplingType, n_mask_samples: int, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_subset_loss_update( model=model, @@ -80,20 +80,20 @@ def stochastic_recon_subset_loss[BatchT, OutputT]( return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) -class StochasticReconSubsetLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class StochasticReconSubsetLoss[BatchT](Metric[BatchT]): """Recon loss when sampling with stochastic masks and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, routing: SubsetRoutingType, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling @@ -109,7 +109,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index f9bc34545..027022274 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -13,11 +13,11 @@ from spd.utils.general_utils import get_obj_device -def _unmasked_recon_loss_update[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def _unmasked_recon_loss_update[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, - reconstruction_loss: ReconstructionLoss[OutputT], + target_out: Tensor, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: all_ones_mask_infos = make_mask_infos( # (C,) will broadcast to (B, S, C) @@ -36,11 +36,11 @@ def _unmasked_recon_loss_compute( return sum_loss / n_examples -def unmasked_recon_loss[BatchT, OutputT]( - model: ComponentModel[BatchT, OutputT], +def unmasked_recon_loss[BatchT]( + model: ComponentModel[BatchT], batch: BatchT, - target_out: OutputT, - reconstruction_loss: ReconstructionLoss[OutputT], + target_out: Tensor, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _unmasked_recon_loss_update( model, @@ -51,16 +51,16 @@ def unmasked_recon_loss[BatchT, OutputT]( return _unmasked_recon_loss_compute(sum_loss, n_examples) -class UnmaskedReconLoss[BatchT, OutputT](Metric[BatchT, OutputT]): +class UnmaskedReconLoss[BatchT](Metric[BatchT]): """Recon loss using the unmasked components and without the delta component.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT, OutputT], + model: ComponentModel[BatchT], device: str, - reconstruction_loss: ReconstructionLoss[OutputT], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.reconstruction_loss = reconstruction_loss @@ -72,7 +72,7 @@ def update( self, *, batch: BatchT, - target_out: OutputT, + target_out: Tensor, **_: Any, ) -> None: sum_loss, n_examples = _unmasked_recon_loss_update( diff --git a/spd/metrics/uv_plots.py b/spd/metrics/uv_plots.py index 8e5849df3..5b2b4e775 100644 --- a/spd/metrics/uv_plots.py +++ b/spd/metrics/uv_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals, plot_UV_matrices -class UVPlots(Metric[Any, Any]): +class UVPlots(Metric[Any]): metric_section: ClassVar[str] = "figures" slow: ClassVar[bool] = True @@ -17,7 +17,7 @@ class UVPlots(Metric[Any, Any]): def __init__( self, - model: ComponentModel[Any, Any], + model: ComponentModel[Any], sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py index fb0ce4b8b..69ed3fb7e 100644 --- a/spd/models/batch_and_loss_fns.py +++ b/spd/models/batch_and_loss_fns.py @@ -11,25 +11,26 @@ from torch import Tensor, nn from spd.configs import AttrOutputExtract, IndexOutputExtract, OutputExtractConfig +from spd.utils.general_utils import runtime_cast -class RunBatch[BatchT, OutputT](Protocol): +class RunBatch[BatchT](Protocol): """Protocol for running a batch through a model and returning the output.""" - def __call__(self, model: nn.Module, batch: BatchT) -> OutputT: ... + def __call__(self, model: nn.Module, batch: BatchT) -> Tensor: ... -class ReconstructionLoss[OutputT](Protocol): +class ReconstructionLoss(Protocol): """Protocol for computing reconstruction loss between predictions and targets.""" - def __call__(self, pred: OutputT, target: OutputT) -> tuple[Float[Tensor, ""], int]: ... + def __call__(self, pred: Tensor, target: Tensor) -> tuple[Float[Tensor, ""], int]: ... -def run_batch_passthrough(model: nn.Module, batch: Any) -> Any: - return model(batch) +def run_batch_passthrough(model: nn.Module, batch: Any) -> Tensor: + return runtime_cast(Tensor, model(batch)) -def make_run_batch(output_extract: OutputExtractConfig | None) -> RunBatch[Any, Any]: +def make_run_batch(output_extract: OutputExtractConfig | None) -> RunBatch[Any]: """creates a RunBatch function for a given configuration. Note that if you plan to override the RunBatch functionality, you can simply pass @@ -40,9 +41,17 @@ def make_run_batch(output_extract: OutputExtractConfig | None) -> RunBatch[Any, case None: return run_batch_passthrough case IndexOutputExtract(index=idx): - return lambda model, batch: model(batch)[idx] + + def _run_index(model: nn.Module, batch: Any) -> Tensor: + return model(batch)[idx] + + return _run_index case AttrOutputExtract(attr=attr): - return lambda model, batch: getattr(model(batch), attr) + + def _run_attr(model: nn.Module, batch: Any) -> Tensor: + return getattr(model(batch), attr) + + return _run_attr def recon_loss_mse( diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 2f7038ac7..ccd373ba7 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -39,10 +39,10 @@ class SPDRunInfo(RunInfo[Config]): checkpoint_prefix = "model" -class OutputWithCache[OutputT](NamedTuple): +class OutputWithCache(NamedTuple): """Output tensor and cached activations.""" - output: OutputT + output: Tensor cache: dict[str, Tensor] @@ -53,7 +53,7 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class ComponentModel[BatchT, OutputT](nn.Module): +class ComponentModel[BatchT](nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. The underlying *base model* can be any subclass of `nn.Module` (e.g. @@ -75,14 +75,14 @@ class ComponentModel[BatchT, OutputT](nn.Module): def __init__( self, target_model: nn.Module, - run_batch: RunBatch[BatchT, OutputT], + run_batch: RunBatch[BatchT], module_path_info: list[ModulePathInfo], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, ): super().__init__() - self._run_batch: RunBatch[BatchT, OutputT] = run_batch + self._run_batch: RunBatch[BatchT] = run_batch for name, param in target_model.named_parameters(): assert not param.requires_grad, ( @@ -121,7 +121,7 @@ def __init__( self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] @classmethod - def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": + def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any]": """Load a trained ComponentModel from a run info object.""" config = run_info.config @@ -147,7 +147,7 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": run_batch = make_run_batch(config.output_extract) - comp_model: ComponentModel[Any, Any] = cls( + comp_model = ComponentModel( target_model=target_model, run_batch=run_batch, module_path_info=module_path_info, @@ -162,7 +162,7 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any, Any]": return comp_model @classmethod - def from_pretrained(cls, path: ModelPath) -> "ComponentModel[Any, Any]": + def from_pretrained(cls, path: ModelPath) -> "ComponentModel[Any]": """Load a trained ComponentModel from a wandb or local path.""" run_info = SPDRunInfo.from_path(path) return cls.from_run_info(run_info) @@ -288,7 +288,7 @@ def __call__( batch: BatchT, cache_type: Literal["component_acts"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, - ) -> OutputWithCache[OutputT]: ... + ) -> OutputWithCache: ... @overload def __call__( @@ -296,7 +296,7 @@ def __call__( batch: BatchT, cache_type: Literal["input"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, - ) -> OutputWithCache[OutputT]: ... + ) -> OutputWithCache: ... @overload def __call__( @@ -304,10 +304,10 @@ def __call__( batch: BatchT, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["none"] = "none", - ) -> OutputT: ... + ) -> Tensor: ... @override - def __call__(self, *args: Any, **kwargs: Any) -> OutputT | OutputWithCache[OutputT]: + def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: return super().__call__(*args, **kwargs) @override @@ -316,7 +316,7 @@ def forward( batch: BatchT, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["component_acts", "input", "none"] = "none", - ) -> OutputT | OutputWithCache[OutputT]: + ) -> Tensor | OutputWithCache: """Forward pass with optional component replacement and/or input caching. This method handles the following 4 cases: @@ -361,7 +361,7 @@ def forward( ) with self._attach_forward_hooks(hooks): - out: OutputT = self._run_batch(self.target_model, batch) + out: Tensor = self._run_batch(self.target_model, batch) match cache_type: case "input" | "component_acts": diff --git a/spd/plotting.py b/spd/plotting.py index e66c391d4..dffdff417 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -183,7 +183,7 @@ def plot_mean_component_cis_both_scales( def get_single_feature_causal_importances( - model: ComponentModel[Any, Any], + model: ComponentModel[Any], batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, @@ -217,7 +217,7 @@ def get_single_feature_causal_importances( def plot_causal_importance_vals( - model: ComponentModel[Any, Any], + model: ComponentModel[Any], batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, diff --git a/spd/run_spd.py b/spd/run_spd.py index 66ef51032..e7af079c6 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -51,7 +51,7 @@ def run_faithfulness_warmup( - component_model: ComponentModel[Any, Any], + component_model: ComponentModel[Any], component_params: list[torch.nn.Parameter], config: Config, ) -> None: @@ -109,14 +109,14 @@ def get_unique_metric_configs( return eval_metric_configs -def optimize[BatchT, OutputT]( +def optimize[BatchT]( target_model: nn.Module, config: Config, device: str, train_loader: DataLoader[BatchT], eval_loader: DataLoader[BatchT], - run_batch: RunBatch[BatchT, OutputT], - reconstruction_loss: ReconstructionLoss[OutputT], + run_batch: RunBatch[BatchT], + reconstruction_loss: ReconstructionLoss, out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, ) -> None: @@ -158,7 +158,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: dist_state = get_distributed_state() wrapped_model: nn.Module = model - component_model: ComponentModel[BatchT, OutputT] + component_model: ComponentModel[BatchT] if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank @@ -171,7 +171,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: # For CPU, don't pass device_ids or output_device wrapped_model = torch.nn.parallel.DistributedDataParallel(model) # Access the underlying module for component operations - component_model = wrapped_model.module # type: ignore[attr-defined] + component_model = cast(ComponentModel[BatchT], wrapped_model.module) # type: ignore[attr-defined] else: component_model = model assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" @@ -233,9 +233,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: # NOTE: we need to call the wrapped_model at least once each step in order to setup # the DDP gradient syncing for all parameters in the component model. Gradients will # sync regardless of whether the parameters are used in this call to wrapped_model. - target_model_output: OutputWithCache[OutputT] = wrapped_model( - microbatch, cache_type="input" - ) + target_model_output: OutputWithCache = wrapped_model(microbatch, cache_type="input") ci = component_model.calc_causal_importances( pre_weight_acts=target_model_output.cache, diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 1b4cb177e..0acefeb53 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -79,7 +79,7 @@ def __init__(self, config: CompareModelsConfig): config.reference_model_path ) - def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any, Any], Config]: + def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any], Config]: """Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) # TODO(oli): this should actually be generic (one of the only instances of this I think) @@ -234,7 +234,7 @@ def _create_ih_data_loader(self) -> Iterator[Any]: ) def compute_activation_densities( - self, model: ComponentModel[Any, Any], eval_iterator: Iterator[Any], n_steps: int + self, model: ComponentModel[Any], eval_iterator: Iterator[Any], n_steps: int ) -> dict[str, Float[Tensor, " C"]]: """Compute activation densities using same logic as ComponentActivationDensity.""" diff --git a/spd/utils/logging_utils.py b/spd/utils/logging_utils.py index d006f082d..3555cadc7 100644 --- a/spd/utils/logging_utils.py +++ b/spd/utils/logging_utils.py @@ -40,7 +40,7 @@ def local_log(data: dict[str, Any], step: int, out_dir: Path) -> None: def get_grad_norms_dict( - component_model: ComponentModel[Any, Any], device: torch.device | str + component_model: ComponentModel[Any], device: torch.device | str ) -> dict[str, float]: """Create a dictionary of gradient norms for the parameters of a component model.""" diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index 7c35e2473..dfbf95a51 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -39,7 +39,9 @@ def forward(self, x: Tensor) -> Tensor: return x -def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any, Any]: +def make_one_layer_component_model( + weight: Float[Tensor, "d_out d_in"], +) -> ComponentModel[Any]: """Create a ComponentModel with a single linear layer for testing. Args: @@ -68,7 +70,7 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo def make_two_layer_component_model( weight1: Float[Tensor, " d_hidden d_in"], weight2: Float[Tensor, " d_out d_hidden"] -) -> ComponentModel[Any, Any]: +) -> ComponentModel[Any]: """Create a ComponentModel with two linear layers for testing. Args: diff --git a/tests/metrics/test_faithfulness_loss.py b/tests/metrics/test_faithfulness_loss.py index a2635c22d..a8b89e259 100644 --- a/tests/metrics/test_faithfulness_loss.py +++ b/tests/metrics/test_faithfulness_loss.py @@ -7,7 +7,7 @@ from tests.metrics.fixtures import make_one_layer_component_model -def zero_out_components(model: ComponentModel[Any, Any]) -> None: +def zero_out_components(model: ComponentModel[Any]) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 8bd948b44..e75776966 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -31,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor: return self.fc(x) -def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any, Any]: +def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any]: d_out, d_in = weight.shape target = TinyLinearModel(d_in=d_in, d_out=d_out) with torch.no_grad(): @@ -50,7 +50,7 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel return comp_model -def _zero_components_for_test(model: ComponentModel[Any, Any]) -> None: +def _zero_components_for_test(model: ComponentModel[Any]) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() From 1f51a37df956199624a23c8bb80e10421367f5f8 Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 16:09:54 +0000 Subject: [PATCH 15/16] Remove generics and simplify output_extract --- spd/app/backend/compute.py | 12 +++---- spd/app/backend/optim_cis.py | 8 ++--- spd/app/backend/state.py | 3 +- spd/clustering/activations.py | 2 +- spd/configs.py | 22 +++---------- spd/dataset_attributions/harvest.py | 4 +-- spd/dataset_attributions/harvester.py | 2 +- spd/eval.py | 20 ++++++------ spd/experiments/lm/gpt2_config.yaml | 4 +-- .../lm/pile_llama_simple_mlp-2L.yaml | 4 +-- .../lm/pile_llama_simple_mlp-4L.yaml | 4 +-- spd/experiments/lm/ss_gpt2_config.yaml | 4 +-- spd/experiments/lm/ss_gpt2_simple-1L.yaml | 4 +-- spd/experiments/lm/ss_gpt2_simple-2L.yaml | 4 +-- spd/experiments/lm/ss_gpt2_simple_config.yaml | 4 +-- .../lm/ss_gpt2_simple_noln_config.yaml | 4 +-- spd/experiments/lm/ss_llama_simple-1L.yaml | 4 +-- spd/experiments/lm/ss_llama_simple-2L.yaml | 4 +-- .../lm/ss_llama_simple_config.yaml | 4 +-- .../lm/ss_llama_simple_mlp-1L.yaml | 4 +-- .../lm/ss_llama_simple_mlp-2L-wide.yaml | 4 +-- .../lm/ss_llama_simple_mlp-2L.yaml | 4 +-- spd/experiments/lm/ss_llama_simple_mlp.yaml | 4 +-- spd/experiments/lm/ts_config.yaml | 4 +-- spd/harvest/harvest.py | 3 +- spd/losses.py | 8 +++-- spd/metrics/base.py | 4 +-- spd/metrics/ce_and_kl_losses.py | 4 +-- spd/metrics/ci_histograms.py | 4 +-- spd/metrics/ci_l0.py | 4 +-- spd/metrics/ci_masked_recon_layerwise_loss.py | 18 +++++------ spd/metrics/ci_masked_recon_loss.py | 18 +++++------ spd/metrics/ci_masked_recon_subset_loss.py | 18 +++++------ spd/metrics/ci_mean_per_component.py | 4 +-- spd/metrics/component_activation_density.py | 4 +-- spd/metrics/faithfulness_loss.py | 4 +-- spd/metrics/identity_ci_error.py | 4 +-- spd/metrics/importance_minimality_loss.py | 4 +-- spd/metrics/permuted_ci_plots.py | 4 +-- .../pgd_masked_recon_layerwise_loss.py | 18 +++++------ spd/metrics/pgd_masked_recon_loss.py | 12 +++---- spd/metrics/pgd_masked_recon_subset_loss.py | 12 +++---- spd/metrics/pgd_utils.py | 31 +++++++++---------- .../stochastic_hidden_acts_recon_loss.py | 18 +++++------ .../stochastic_recon_layerwise_loss.py | 18 +++++------ spd/metrics/stochastic_recon_loss.py | 18 +++++------ .../stochastic_recon_subset_ce_and_kl.py | 4 +-- spd/metrics/stochastic_recon_subset_loss.py | 18 +++++------ spd/metrics/unmasked_recon_loss.py | 18 +++++------ spd/metrics/uv_plots.py | 4 +-- spd/models/batch_and_loss_fns.py | 21 ++++++++----- spd/models/component_model.py | 18 +++++------ spd/plotting.py | 5 ++- spd/run_spd.py | 16 +++++----- spd/scripts/compare_models/compare_models.py | 4 +-- spd/utils/logging_utils.py | 2 +- tests/app/test_server_api.py | 3 +- tests/metrics/fixtures.py | 6 ++-- tests/metrics/test_faithfulness_loss.py | 4 +-- tests/test_distributed.py | 2 +- tests/test_gpt2.py | 3 +- tests/test_spd_losses.py | 6 ++-- 62 files changed, 229 insertions(+), 274 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index b67d5ff71..0c82f105b 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -127,7 +127,7 @@ def is_kv_to_o_pair(in_layer: str, out_layer: str) -> bool: def get_sources_by_target( - model: ComponentModel[Tensor], + model: ComponentModel, device: str, sampling: SamplingType, ) -> dict[str, list[str]]: @@ -306,7 +306,7 @@ def _compute_edges_for_target( def compute_edges_from_ci( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Float[Tensor, "1 seq"], ci_lower_leaky: dict[str, Float[Tensor, "1 seq C"]], pre_weight_acts: dict[str, Float[Tensor, "1 seq d_in"]], @@ -492,7 +492,7 @@ def filter_ci_to_included_nodes( def compute_prompt_attributions( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], output_prob_threshold: float, @@ -542,7 +542,7 @@ def compute_prompt_attributions( def compute_prompt_attributions_optimized( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Float[Tensor, "1 seq"], sources_by_target: dict[str, list[str]], optim_config: OptimCIConfig, @@ -626,7 +626,7 @@ class CIOnlyResult: def compute_ci_only( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Float[Tensor, "1 seq"], sampling: SamplingType, ) -> CIOnlyResult: @@ -791,7 +791,7 @@ class InterventionResult: def compute_intervention_forward( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Float[Tensor, "1 seq"], active_nodes: list[tuple[str, int, int]], # [(layer, seq_pos, component_idx)] top_k: int, diff --git a/spd/app/backend/optim_cis.py b/spd/app/backend/optim_cis.py index 79c109620..798c88176 100644 --- a/spd/app/backend/optim_cis.py +++ b/spd/app/backend/optim_cis.py @@ -73,7 +73,7 @@ class OptimizableCIParams: ci_pre_sigmoid: dict[str, list[Tensor]] # layer_name -> list of [alive_at_pos] values alive_info: AliveComponentInfo - def create_ci_outputs(self, model: ComponentModel[Tensor], device: str) -> CIOutputs: + def create_ci_outputs(self, model: ComponentModel, device: str) -> CIOutputs: """Expand sparse pre-sigmoid values to full CI tensors and create CIOutputs.""" pre_sigmoid: dict[str, Tensor] = {} @@ -140,7 +140,7 @@ def create_optimizable_ci_params( def compute_label_prob( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Tensor, ci_lower_leaky: dict[str, Tensor], label_token: int, @@ -167,7 +167,7 @@ def compute_l0_stats( def compute_final_token_ce_kl( - model: ComponentModel[Tensor], + model: ComponentModel, batch: Tensor, target_out: Tensor, ci: dict[str, Tensor], @@ -272,7 +272,7 @@ class OptimCIConfig: def optimize_ci_values( - model: ComponentModel[Tensor], + model: ComponentModel, tokens: Tensor, config: OptimCIConfig, device: str, diff --git a/spd/app/backend/state.py b/spd/app/backend/state.py index bc898eab1..47dacfe51 100644 --- a/spd/app/backend/state.py +++ b/spd/app/backend/state.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, field from typing import Any -from torch import Tensor from transformers.tokenization_utils_base import PreTrainedTokenizerBase from spd.app.backend.database import PromptAttrDB, Run @@ -111,7 +110,7 @@ class RunState: """Runtime state for a loaded run (model, tokenizer, etc.)""" run: Run - model: ComponentModel[Tensor] + model: ComponentModel tokenizer: PreTrainedTokenizerBase sources_by_target: dict[str, list[str]] config: Config diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index a933e070e..cd6a2b742 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -17,7 +17,7 @@ def component_activations( - model: ComponentModel[Tensor], + model: ComponentModel, device: torch.device | str, batch: Int[Tensor, "batch_size n_ctx"], ) -> dict[str, ActivationsTensor]: diff --git a/spd/configs.py b/spd/configs.py index 8d362caf3..20eecbee8 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -148,19 +148,6 @@ class LMTaskConfig(BaseConfig): ) -class IndexOutputExtract(BaseConfig): - type: Literal["index"] = "index" - index: int - - -class AttrOutputExtract(BaseConfig): - type: Literal["attr"] = "attr" - attr: str - - -OutputExtractConfig = IndexOutputExtract | AttrOutputExtract - - class ModulePatternInfoConfig(BaseConfig): """Configuration for a module pattern with its number of components. @@ -574,9 +561,10 @@ def microbatch_size(self) -> PositiveInt: default=None, description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'", ) - output_extract: Annotated[OutputExtractConfig, Field(discriminator="type")] | None = Field( + output_extract: int | str | None = Field( default=None, - description="How to extract tensor from model output. None = raw output. Note that you can ignore this field if you plan to create your own `run_batch` function to pass to run_spd.optimize().", + description="How to extract tensor from model output. None = raw output, int = index into " + "output tuple, str = attribute name.", ) tokenizer_name: str | None = Field( default=None, @@ -668,9 +656,9 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, case None: pass case "idx_0": - config_dict["output_extract"] = {"type": "index", "index": 0} + config_dict["output_extract"] = 0 case "logits": - config_dict["output_extract"] = {"type": "attr", "attr": "logits"} + config_dict["output_extract"] = "logits" case _: raise ValueError(f"Unknown pretrained_model_output_attr: {old_val!r}") diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index fa95fcbe5..32e7aaf7c 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -41,7 +41,7 @@ class DatasetAttributionConfig: ci_threshold: float -def _build_component_layer_keys(model: ComponentModel[Tensor]) -> list[str]: +def _build_component_layer_keys(model: ComponentModel) -> list[str]: """Build list of component layer keys in canonical order. Returns keys like ["h.0.attn.q_proj:0", "h.0.attn.q_proj:1", ...] for all layers. @@ -56,7 +56,7 @@ def _build_component_layer_keys(model: ComponentModel[Tensor]) -> list[str]: def _build_alive_masks( - model: ComponentModel[Tensor], + model: ComponentModel, run_id: str, ci_threshold: float, n_components: int, diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index 98a98ea77..bd46c3c6c 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -46,7 +46,7 @@ class AttributionHarvester: def __init__( self, - model: ComponentModel[Tensor], + model: ComponentModel, sources_by_target: dict[str, list[str]], n_components: int, vocab_size: int, diff --git a/spd/eval.py b/spd/eval.py index 8bd622e21..eb2c946d6 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -118,13 +118,13 @@ def avg_eval_metrics_across_ranks(metrics: MetricOutType, device: str) -> DistMe return {**metrics, **avg_metrics} -def init_metric[BatchT]( +def init_metric( cfg: MetricConfigType, - model: ComponentModel[BatchT], + model: ComponentModel, run_config: Config, device: str, reconstruction_loss: ReconstructionLoss, -) -> Metric[BatchT]: +) -> Metric: match cfg: case ImportanceMinimalityLossConfig(): metric = ImportanceMinimalityLoss( @@ -288,10 +288,10 @@ def init_metric[BatchT]( return metric -def evaluate[BatchT]( +def evaluate( eval_metric_configs: list[MetricConfigType], - model: ComponentModel[BatchT], - eval_iterator: Iterator[BatchT], + model: ComponentModel, + eval_iterator: Iterator[Any], device: str, run_config: Config, slow_step: bool, @@ -301,7 +301,7 @@ def evaluate[BatchT]( ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" - metrics: list[Metric[BatchT]] = [] + metrics: list[Metric] = [] for cfg in eval_metric_configs: metric = init_metric( cfg=cfg, @@ -350,12 +350,12 @@ def evaluate[BatchT]( return outputs -def evaluate_multibatch_pgd[BatchT]( +def evaluate_multibatch_pgd( multibatch_pgd_eval_configs: list[ PGDMultiBatchReconLossConfig | PGDMultiBatchReconSubsetLossConfig ], - model: ComponentModel[BatchT], - create_data_iter: CreateDataIter[BatchT], + model: ComponentModel, + create_data_iter: CreateDataIter, config: Config, device: str, reconstruction_loss: ReconstructionLoss, diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index a58bd4d1b..a2414f99c 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -63,9 +63,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: openai-community/gpt2 -output_extract: - type: attr - attr: logits +output_extract: logits tokenizer_name: openai-community/gpt2 # --- Task Specific --- diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index 81529ec37..efb62bc9c 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -114,9 +114,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-bd02d372 -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index 787f10998..8a8f7f632 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -118,9 +118,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index 0ac2214e1..d1b8977ca 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -63,9 +63,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: SimpleStories/test-SimpleStories-gpt2-1.25M -output_extract: - type: attr - attr: logits +output_extract: logits tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index aab53d5c7..8ee2e0e0f 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -91,9 +91,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/3qhd7rnb # 100k steps. 4019 tokenizer -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index 1b2dd89fa..f80c49936 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -93,9 +93,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/wr1su18m # 100k steps. 4019 tokenizer -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index 9032cc46b..36e6a581d 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -96,9 +96,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/rvu66183 # 100k steps. 4019 tokenizer -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index 212e9c961..56ec9e14e 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -93,9 +93,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/xi36b9az # No ln -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # We'll load this from wandb in future # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index c7e3f1711..59a609a4e 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -91,9 +91,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tfacbi70 # 100k steps -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index c1807a7e1..0c321c289 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -93,9 +93,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tb8373uo # 100k steps -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 71ea7a034..d90cf9868 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -93,9 +93,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_name: wandb:goodfire/spd/runs/erq48r3w # 100k steps 4019 tok -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index e7cabb6a4..62f7003cd 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -85,9 +85,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gvbmdt9w -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index 24406f216..e7a6929ea 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -93,9 +93,7 @@ ci_alive_threshold: 0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: buffer_size: 1000 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index e0933f4d1..2262cbeaf 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -91,9 +91,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/7pt957pf # 100k steps -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index 1cb007f87..271ed0813 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -116,9 +116,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/9de1zu65 # 100k steps -output_extract: - type: index - index: 0 +output_extract: 0 tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index 1e252c946..f6f39aeb8 100644 --- a/spd/experiments/lm/ts_config.yaml +++ b/spd/experiments/lm/ts_config.yaml @@ -64,9 +64,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.AutoModelForCausalLM pretrained_model_name: roneneldan/TinyStories-1M -output_extract: - type: attr - attr: logits +output_extract: logits tokenizer_name: EleutherAI/gpt-neo-125M # --- Task Specific --- diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index 9b10e9077..e2afd6eab 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -16,7 +16,6 @@ import time from dataclasses import asdict, dataclass from pathlib import Path -from typing import Any import torch import tqdm @@ -38,7 +37,7 @@ from spd.utils.general_utils import bf16_autocast -def _compute_u_norms(model: ComponentModel[Any]) -> dict[str, Float[Tensor, " C"]]: +def _compute_u_norms(model: ComponentModel) -> dict[str, Float[Tensor, " C"]]: """Compute ||U[c,:]|| for each component c in each layer. Component activations (v_i^T @ a) have a scale invariance: scaling V by α and U by 1/α diff --git a/spd/losses.py b/spd/losses.py index 6531279c2..5d0aa9525 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -1,3 +1,5 @@ +from typing import Any + import torch from jaxtyping import Float from torch import Tensor @@ -39,10 +41,10 @@ from spd.utils.general_utils import get_obj_device -def compute_total_loss[BatchT]( +def compute_total_loss( loss_metric_configs: list[LossMetricConfigType], - model: ComponentModel[BatchT], - batch: BatchT, + model: ComponentModel, + batch: Any, ci: CIOutputs, target_out: Tensor, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/base.py b/spd/metrics/base.py index 36c3de024..2e9a0fd4d 100644 --- a/spd/metrics/base.py +++ b/spd/metrics/base.py @@ -12,7 +12,7 @@ from spd.models.component_model import CIOutputs -class Metric[BatchT](Protocol): +class Metric(Protocol): """Interface for metrics that can be used in training and/or evaluation.""" slow: ClassVar[bool] = False @@ -21,7 +21,7 @@ class Metric[BatchT](Protocol): def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, diff --git a/spd/metrics/ce_and_kl_losses.py b/spd/metrics/ce_and_kl_losses.py index 83c6e8766..d93dcbc86 100644 --- a/spd/metrics/ce_and_kl_losses.py +++ b/spd/metrics/ce_and_kl_losses.py @@ -17,7 +17,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class CEandKLLosses(Metric[Any]): +class CEandKLLosses(Metric): """CE and KL losses for different masking strategies. NOTE: Assumes all batches and sequences are the same size. @@ -47,7 +47,7 @@ class CEandKLLosses(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, device: str, sampling: SamplingType, rounding_threshold: float, diff --git a/spd/metrics/ci_histograms.py b/spd/metrics/ci_histograms.py index 518788572..fcf6fb2ac 100644 --- a/spd/metrics/ci_histograms.py +++ b/spd/metrics/ci_histograms.py @@ -12,13 +12,13 @@ from spd.utils.distributed_utils import gather_all_tensors -class CIHistograms(Metric[Any]): +class CIHistograms(Metric): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, n_batches_accum: int | None = None, ): self.n_batches_accum = n_batches_accum diff --git a/spd/metrics/ci_l0.py b/spd/metrics/ci_l0.py index 84290e96c..9a2047ff8 100644 --- a/spd/metrics/ci_l0.py +++ b/spd/metrics/ci_l0.py @@ -12,7 +12,7 @@ from spd.utils.distributed_utils import all_reduce -class CI_L0(Metric[Any]): +class CI_L0(Metric): """L0 metric for CI values. NOTE: Assumes all batches and sequences are the same size. @@ -22,7 +22,7 @@ class CI_L0(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, device: str, ci_alive_threshold: float, groups: dict[str, list[str]] | None = None, diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index bfb5103e4..2862eb9e7 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -13,9 +13,9 @@ from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_layerwise_loss_update[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def _ci_masked_recon_layerwise_loss_update( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], reconstruction_loss: ReconstructionLoss, @@ -37,9 +37,9 @@ def _ci_masked_recon_layerwise_loss_compute( return sum_loss / sum_n_examples -def ci_masked_recon_layerwise_loss[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def ci_masked_recon_layerwise_loss( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], reconstruction_loss: ReconstructionLoss, @@ -54,14 +54,14 @@ def ci_masked_recon_layerwise_loss[BatchT]( return _ci_masked_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class CIMaskedReconLayerwiseLoss[BatchT](Metric[BatchT]): +class CIMaskedReconLayerwiseLoss(Metric): """Recon loss when masking with CI values directly one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, reconstruction_loss: ReconstructionLoss, ) -> None: @@ -74,7 +74,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, **_: Any, diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index 836460289..2cb871064 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -12,9 +12,9 @@ from spd.utils.distributed_utils import all_reduce -def _ci_masked_recon_loss_update[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def _ci_masked_recon_loss_update( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], reconstruction_loss: ReconstructionLoss, @@ -30,9 +30,9 @@ def _ci_masked_recon_loss_compute( return sum_loss / n_examples -def ci_masked_recon_loss[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def ci_masked_recon_loss( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], reconstruction_loss: ReconstructionLoss, @@ -47,14 +47,14 @@ def ci_masked_recon_loss[BatchT]( return _ci_masked_recon_loss_compute(sum_loss, n_examples) -class CIMaskedReconLoss[BatchT](Metric[BatchT]): +class CIMaskedReconLoss(Metric): """Recon loss when masking with CI values directly on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, reconstruction_loss: ReconstructionLoss, ) -> None: @@ -67,7 +67,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, **_: Any, diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index 3e494a36c..785ae15bd 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -15,9 +15,9 @@ from spd.utils.general_utils import get_obj_device -def _ci_masked_recon_subset_loss_update[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def _ci_masked_recon_subset_loss_update( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], router: Router, @@ -42,9 +42,9 @@ def _ci_masked_recon_subset_loss_compute( return sum_loss / n_examples -def ci_masked_recon_subset_loss[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def ci_masked_recon_subset_loss( + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, @@ -61,14 +61,14 @@ def ci_masked_recon_subset_loss[BatchT]( return _ci_masked_recon_subset_loss_compute(sum_loss, n_examples) -class CIMaskedReconSubsetLoss[BatchT](Metric[BatchT]): +class CIMaskedReconSubsetLoss(Metric): """Recon loss when masking with raw CI values and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, routing: SubsetRoutingType, reconstruction_loss: ReconstructionLoss, @@ -83,7 +83,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, **_: Any, diff --git a/spd/metrics/ci_mean_per_component.py b/spd/metrics/ci_mean_per_component.py index 7fd9e0c7c..88800e7a4 100644 --- a/spd/metrics/ci_mean_per_component.py +++ b/spd/metrics/ci_mean_per_component.py @@ -11,11 +11,11 @@ from spd.utils.distributed_utils import all_reduce -class CIMeanPerComponent(Metric[Any]): +class CIMeanPerComponent(Metric): slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__(self, model: ComponentModel[Any], device: str) -> None: + def __init__(self, model: ComponentModel, device: str) -> None: self.components = model.components self.component_ci_sums: dict[str, Tensor] = { module_name: torch.zeros(model.module_to_c[module_name], device=device) diff --git a/spd/metrics/component_activation_density.py b/spd/metrics/component_activation_density.py index 7ff1a150d..5f10a86cc 100644 --- a/spd/metrics/component_activation_density.py +++ b/spd/metrics/component_activation_density.py @@ -13,13 +13,13 @@ from spd.utils.distributed_utils import all_reduce -class ComponentActivationDensity(Metric[Any]): +class ComponentActivationDensity(Metric): """Activation density for each component.""" slow: ClassVar[bool] = True metric_section: ClassVar[str] = "figures" - def __init__(self, model: ComponentModel[Any], device: str, ci_alive_threshold: float) -> None: + def __init__(self, model: ComponentModel, device: str, ci_alive_threshold: float) -> None: self.model = model self.ci_alive_threshold = ci_alive_threshold diff --git a/spd/metrics/faithfulness_loss.py b/spd/metrics/faithfulness_loss.py index 751a2e84c..d2b02e0d4 100644 --- a/spd/metrics/faithfulness_loss.py +++ b/spd/metrics/faithfulness_loss.py @@ -35,12 +35,12 @@ def faithfulness_loss(weight_deltas: dict[str, Float[Tensor, "d_out d_in"]]) -> return _faithfulness_loss_compute(sum_loss, total_params) -class FaithfulnessLoss(Metric[Any]): +class FaithfulnessLoss(Metric): """MSE between the target weights and the sum of the components.""" metric_section: ClassVar[str] = "loss" - def __init__(self, model: ComponentModel[Any], device: str) -> None: + def __init__(self, model: ComponentModel, device: str) -> None: self.model = model self.sum_loss = torch.tensor(0.0, device=device) self.total_params = torch.tensor(0, device=device) diff --git a/spd/metrics/identity_ci_error.py b/spd/metrics/identity_ci_error.py index ca1210ee5..7c3dcf881 100644 --- a/spd/metrics/identity_ci_error.py +++ b/spd/metrics/identity_ci_error.py @@ -9,7 +9,7 @@ from spd.utils.target_ci_solutions import compute_target_metrics, make_target_ci_solution -class IdentityCIError(Metric[Any]): +class IdentityCIError(Metric): """Error between the CI values and an Identity or Dense CI pattern.""" slow: ClassVar[bool] = True @@ -19,7 +19,7 @@ class IdentityCIError(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, sampling: SamplingType, identity_ci: list[dict[str, str | int]] | None = None, dense_ci: list[dict[str, str | int]] | None = None, diff --git a/spd/metrics/importance_minimality_loss.py b/spd/metrics/importance_minimality_loss.py index 5fc607797..5bd6ca31f 100644 --- a/spd/metrics/importance_minimality_loss.py +++ b/spd/metrics/importance_minimality_loss.py @@ -144,7 +144,7 @@ def importance_minimality_loss( ) -class ImportanceMinimalityLoss(Metric[Any]): +class ImportanceMinimalityLoss(Metric): """L_p loss on the sum of CI values. NOTE: We don't normalize over the number of layers because a change in the number of layers @@ -165,7 +165,7 @@ class ImportanceMinimalityLoss(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, device: str, pnorm: float, beta: float, diff --git a/spd/metrics/permuted_ci_plots.py b/spd/metrics/permuted_ci_plots.py index d859b4119..77d713499 100644 --- a/spd/metrics/permuted_ci_plots.py +++ b/spd/metrics/permuted_ci_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals -class PermutedCIPlots(Metric[Any]): +class PermutedCIPlots(Metric): slow: ClassVar[bool] = True input_magnitude: ClassVar[float] = 0.75 @@ -17,7 +17,7 @@ class PermutedCIPlots(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index 9a878fe96..97c45f3bd 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -14,10 +14,10 @@ from spd.utils.distributed_utils import all_reduce -def _pgd_recon_layerwise_loss_update[BatchT]( +def _pgd_recon_layerwise_loss_update( *, - model: ComponentModel[BatchT], - batch: BatchT, + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -43,10 +43,10 @@ def _pgd_recon_layerwise_loss_update[BatchT]( return sum_loss, n_examples -def pgd_recon_layerwise_loss[BatchT]( +def pgd_recon_layerwise_loss( *, - model: ComponentModel[BatchT], - batch: BatchT, + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -65,7 +65,7 @@ def pgd_recon_layerwise_loss[BatchT]( return sum_loss / n_examples -class PGDReconLayerwiseLoss[BatchT](Metric[BatchT]): +class PGDReconLayerwiseLoss(Metric): """Recon loss when masking with adversarially-optimized values and routing to one layer at a time.""" @@ -73,7 +73,7 @@ class PGDReconLayerwiseLoss[BatchT](Metric[BatchT]): def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, pgd_config: PGDConfig, device: str, use_delta_component: bool, @@ -90,7 +90,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index a64a7b4f7..4ab242393 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -14,10 +14,10 @@ from spd.utils.distributed_utils import all_reduce -def pgd_recon_loss[BatchT]( +def pgd_recon_loss( *, - model: ComponentModel[BatchT], - batch: BatchT, + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -37,7 +37,7 @@ def pgd_recon_loss[BatchT]( return sum_loss / n_examples -class PGDReconLoss[BatchT](Metric[BatchT]): +class PGDReconLoss(Metric): """Recon loss when masking with adversarially-optimized values and routing to all component layers.""" @@ -45,7 +45,7 @@ class PGDReconLoss[BatchT](Metric[BatchT]): def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, pgd_config: PGDConfig, use_delta_component: bool, @@ -62,7 +62,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index 1fdd10180..c904c14d7 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -15,10 +15,10 @@ from spd.utils.general_utils import get_obj_device -def pgd_recon_subset_loss[BatchT]( +def pgd_recon_subset_loss( *, - model: ComponentModel[BatchT], - batch: BatchT, + model: ComponentModel, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -39,7 +39,7 @@ def pgd_recon_subset_loss[BatchT]( return sum_loss / n_examples -class PGDReconSubsetLoss[BatchT](Metric[BatchT]): +class PGDReconSubsetLoss(Metric): """Recon loss when masking with adversarially-optimized values and routing to subsets of component layers.""" @@ -47,7 +47,7 @@ class PGDReconSubsetLoss[BatchT](Metric[BatchT]): def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, use_delta_component: bool, pgd_config: PGDConfig, @@ -67,7 +67,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index d7eb3e873..132c1a12f 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -1,6 +1,6 @@ -from collections.abc import Iterator +from collections.abc import Callable, Iterator from functools import partial -from typing import Protocol +from typing import Any import torch from jaxtyping import Float @@ -17,9 +17,9 @@ from spd.utils.general_utils import get_obj_device -def pgd_masked_recon_loss_update[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def pgd_masked_recon_loss_update( + model: ComponentModel, + batch: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, target_out: Tensor, @@ -80,15 +80,14 @@ def pgd_masked_recon_loss_update[BatchT]( return fwd_pass() -class CreateDataIter[BatchT](Protocol): - def __call__(self) -> Iterator[BatchT]: ... +CreateDataIter = Callable[[], Iterator[Any]] -def calc_multibatch_pgd_masked_recon_loss[BatchT]( +def calc_multibatch_pgd_masked_recon_loss( pgd_config: PGDMultiBatchConfig, - model: ComponentModel[BatchT], + model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - create_data_iter: CreateDataIter[BatchT], + create_data_iter: CreateDataIter, router: Router, sampling: SamplingType, use_delta_component: bool, @@ -156,9 +155,9 @@ def calc_multibatch_pgd_masked_recon_loss[BatchT]( return final_loss / final_sum_n_examples -def _forward_with_adv_sources[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def _forward_with_adv_sources( + model: ComponentModel, + batch: Any, adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -191,12 +190,12 @@ def _forward_with_adv_sources[BatchT]( return sum_loss, n_examples -def _multibatch_pgd_fwd_bwd[BatchT]( +def _multibatch_pgd_fwd_bwd( adv_sources: dict[str, Float[Tensor, "*ones mask_c"]], pgd_config: PGDMultiBatchConfig, - model: ComponentModel[BatchT], + model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - data_iter: Iterator[BatchT], + data_iter: Iterator[Any], device: torch.device | str, router: Router, sampling: SamplingType, diff --git a/spd/metrics/stochastic_hidden_acts_recon_loss.py b/spd/metrics/stochastic_hidden_acts_recon_loss.py index f2fad6f33..2e97e2dda 100644 --- a/spd/metrics/stochastic_hidden_acts_recon_loss.py +++ b/spd/metrics/stochastic_hidden_acts_recon_loss.py @@ -14,11 +14,11 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_hidden_acts_recon_loss_update[BatchT]( - model: ComponentModel[BatchT], +def _stochastic_hidden_acts_recon_loss_update( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -59,11 +59,11 @@ def _stochastic_hidden_acts_recon_loss_compute( return sum_mse / n_examples -def stochastic_hidden_acts_recon_loss[BatchT]( - model: ComponentModel[BatchT], +def stochastic_hidden_acts_recon_loss( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -80,14 +80,14 @@ def stochastic_hidden_acts_recon_loss[BatchT]( return _stochastic_hidden_acts_recon_loss_compute(sum_mse, n_examples) -class StochasticHiddenActsReconLoss[BatchT](Metric[BatchT]): +class StochasticHiddenActsReconLoss(Metric): """Reconstruction loss between target and stochastic hidden activations when sampling with stochastic masks.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, sampling: SamplingType, use_delta_component: bool, @@ -104,7 +104,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, pre_weight_acts: dict[str, Float[Tensor, "..."]], ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index de01b58a0..d90675e6e 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -15,11 +15,11 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_layerwise_loss_update[BatchT]( - model: ComponentModel[BatchT], +def _stochastic_recon_layerwise_loss_update( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -55,11 +55,11 @@ def _stochastic_recon_layerwise_loss_compute( return sum_loss / sum_n_examples -def stochastic_recon_layerwise_loss[BatchT]( - model: ComponentModel[BatchT], +def stochastic_recon_layerwise_loss( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -78,14 +78,14 @@ def stochastic_recon_layerwise_loss[BatchT]( return _stochastic_recon_layerwise_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLayerwiseLoss[BatchT](Metric[BatchT]): +class StochasticReconLayerwiseLoss(Metric): """Recon loss when sampling with stochastic masks one layer at a time.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, sampling: SamplingType, use_delta_component: bool, @@ -104,7 +104,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 59650c322..793fbed32 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -15,11 +15,11 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_loss_update[BatchT]( - model: ComponentModel[BatchT], +def _stochastic_recon_loss_update( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -51,11 +51,11 @@ def _stochastic_recon_loss_compute( return sum_loss / sum_n_examples -def stochastic_recon_loss[BatchT]( - model: ComponentModel[BatchT], +def stochastic_recon_loss( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -74,14 +74,14 @@ def stochastic_recon_loss[BatchT]( return _stochastic_recon_loss_compute(sum_loss, sum_n_examples) -class StochasticReconLoss[BatchT](Metric[BatchT]): +class StochasticReconLoss(Metric): """Recon loss when sampling with stochastic masks on all component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, sampling: SamplingType, use_delta_component: bool, @@ -100,7 +100,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/stochastic_recon_subset_ce_and_kl.py b/spd/metrics/stochastic_recon_subset_ce_and_kl.py index cdf52d183..b1a98b2f2 100644 --- a/spd/metrics/stochastic_recon_subset_ce_and_kl.py +++ b/spd/metrics/stochastic_recon_subset_ce_and_kl.py @@ -19,7 +19,7 @@ from spd.utils.general_utils import calc_kl_divergence_lm -class StochasticReconSubsetCEAndKL(Metric[Any]): +class StochasticReconSubsetCEAndKL(Metric): """Compute reconstruction loss for specific subsets of components. NOTE: Assumes all batches and sequences are the same size. @@ -29,7 +29,7 @@ class StochasticReconSubsetCEAndKL(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, device: str, sampling: SamplingType, use_delta_component: bool, diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index 0fc074fa9..85293b17b 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -15,11 +15,11 @@ from spd.utils.general_utils import get_obj_device -def _stochastic_recon_subset_loss_update[BatchT]( - model: ComponentModel[BatchT], +def _stochastic_recon_subset_loss_update( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -55,11 +55,11 @@ def _stochastic_recon_subset_loss_compute( return sum_loss / n_examples -def stochastic_recon_subset_loss[BatchT]( - model: ComponentModel[BatchT], +def stochastic_recon_subset_loss( + model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - batch: BatchT, + batch: Any, target_out: Tensor, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, @@ -80,14 +80,14 @@ def stochastic_recon_subset_loss[BatchT]( return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) -class StochasticReconSubsetLoss[BatchT](Metric[BatchT]): +class StochasticReconSubsetLoss(Metric): """Recon loss when sampling with stochastic masks and routing to subsets of component layers.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, sampling: SamplingType, use_delta_component: bool, @@ -108,7 +108,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index 027022274..e113d5e36 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -13,9 +13,9 @@ from spd.utils.general_utils import get_obj_device -def _unmasked_recon_loss_update[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def _unmasked_recon_loss_update( + model: ComponentModel, + batch: Any, target_out: Tensor, reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: @@ -36,9 +36,9 @@ def _unmasked_recon_loss_compute( return sum_loss / n_examples -def unmasked_recon_loss[BatchT]( - model: ComponentModel[BatchT], - batch: BatchT, +def unmasked_recon_loss( + model: ComponentModel, + batch: Any, target_out: Tensor, reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: @@ -51,14 +51,14 @@ def unmasked_recon_loss[BatchT]( return _unmasked_recon_loss_compute(sum_loss, n_examples) -class UnmaskedReconLoss[BatchT](Metric[BatchT]): +class UnmaskedReconLoss(Metric): """Recon loss using the unmasked components and without the delta component.""" metric_section: ClassVar[str] = "loss" def __init__( self, - model: ComponentModel[BatchT], + model: ComponentModel, device: str, reconstruction_loss: ReconstructionLoss, ) -> None: @@ -71,7 +71,7 @@ def __init__( def update( self, *, - batch: BatchT, + batch: Any, target_out: Tensor, **_: Any, ) -> None: diff --git a/spd/metrics/uv_plots.py b/spd/metrics/uv_plots.py index 5b2b4e775..0d29c6401 100644 --- a/spd/metrics/uv_plots.py +++ b/spd/metrics/uv_plots.py @@ -9,7 +9,7 @@ from spd.plotting import plot_causal_importance_vals, plot_UV_matrices -class UVPlots(Metric[Any]): +class UVPlots(Metric): metric_section: ClassVar[str] = "figures" slow: ClassVar[bool] = True @@ -17,7 +17,7 @@ class UVPlots(Metric[Any]): def __init__( self, - model: ComponentModel[Any], + model: ComponentModel, sampling: SamplingType, identity_patterns: list[str] | None = None, dense_patterns: list[str] | None = None, diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py index 69ed3fb7e..c603e638a 100644 --- a/spd/models/batch_and_loss_fns.py +++ b/spd/models/batch_and_loss_fns.py @@ -10,14 +10,13 @@ from jaxtyping import Float from torch import Tensor, nn -from spd.configs import AttrOutputExtract, IndexOutputExtract, OutputExtractConfig from spd.utils.general_utils import runtime_cast -class RunBatch[BatchT](Protocol): +class RunBatch(Protocol): """Protocol for running a batch through a model and returning the output.""" - def __call__(self, model: nn.Module, batch: BatchT) -> Tensor: ... + def __call__(self, model: nn.Module, batch: Any) -> Tensor: ... class ReconstructionLoss(Protocol): @@ -30,23 +29,29 @@ def run_batch_passthrough(model: nn.Module, batch: Any) -> Tensor: return runtime_cast(Tensor, model(batch)) -def make_run_batch(output_extract: OutputExtractConfig | None) -> RunBatch[Any]: - """creates a RunBatch function for a given configuration. +def make_run_batch(output_extract: int | str | None) -> RunBatch: + """Creates a RunBatch function for a given configuration. - Note that if you plan to override the RunBatch functionality, you can simply pass + NOTE: If you plan to override the RunBatch functionality, you can simply pass a custom RunBatch function into optimize and do not need to use this function at all. + + Args: + output_extract: How to extract the tensor from model output. + None: passthrough (model output is the tensor) + int: index into model output tuple (e.g. 0 for first element) + str: attribute name on model output (e.g. "logits") """ match output_extract: case None: return run_batch_passthrough - case IndexOutputExtract(index=idx): + case int(idx): def _run_index(model: nn.Module, batch: Any) -> Tensor: return model(batch)[idx] return _run_index - case AttrOutputExtract(attr=attr): + case str(attr): def _run_attr(model: nn.Module, batch: Any) -> Tensor: return getattr(model(batch), attr) diff --git a/spd/models/component_model.py b/spd/models/component_model.py index ccd373ba7..fd5b80f56 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -53,7 +53,7 @@ class CIOutputs: pre_sigmoid: dict[str, Tensor] -class ComponentModel[BatchT](nn.Module): +class ComponentModel(nn.Module): """Wrapper around an arbitrary pytorch model for running SPD. The underlying *base model* can be any subclass of `nn.Module` (e.g. @@ -75,14 +75,14 @@ class ComponentModel[BatchT](nn.Module): def __init__( self, target_model: nn.Module, - run_batch: RunBatch[BatchT], + run_batch: RunBatch, module_path_info: list[ModulePathInfo], ci_fn_type: CiFnType, ci_fn_hidden_dims: list[int], sigmoid_type: SigmoidType, ): super().__init__() - self._run_batch: RunBatch[BatchT] = run_batch + self._run_batch: RunBatch = run_batch for name, param in target_model.named_parameters(): assert not param.requires_grad, ( @@ -121,7 +121,7 @@ def __init__( self.upper_leaky_fn = SIGMOID_TYPES[sigmoid_type] @classmethod - def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any]": + def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": """Load a trained ComponentModel from a run info object.""" config = run_info.config @@ -162,7 +162,7 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel[Any]": return comp_model @classmethod - def from_pretrained(cls, path: ModelPath) -> "ComponentModel[Any]": + def from_pretrained(cls, path: ModelPath) -> "ComponentModel": """Load a trained ComponentModel from a wandb or local path.""" run_info = SPDRunInfo.from_path(path) return cls.from_run_info(run_info) @@ -285,7 +285,7 @@ def _create_ci_fns( @overload def __call__( self, - batch: BatchT, + batch: Any, cache_type: Literal["component_acts"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, ) -> OutputWithCache: ... @@ -293,7 +293,7 @@ def __call__( @overload def __call__( self, - batch: BatchT, + batch: Any, cache_type: Literal["input"], mask_infos: dict[str, ComponentsMaskInfo] | None = None, ) -> OutputWithCache: ... @@ -301,7 +301,7 @@ def __call__( @overload def __call__( self, - batch: BatchT, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["none"] = "none", ) -> Tensor: ... @@ -313,7 +313,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: @override def forward( self, - batch: BatchT, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["component_acts", "input", "none"] = "none", ) -> Tensor | OutputWithCache: diff --git a/spd/plotting.py b/spd/plotting.py index dffdff417..81c9c1d5d 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -1,7 +1,6 @@ import fnmatch import io from collections.abc import Callable -from typing import Any import numpy as np import torch @@ -183,7 +182,7 @@ def plot_mean_component_cis_both_scales( def get_single_feature_causal_importances( - model: ComponentModel[Any], + model: ComponentModel, batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, @@ -217,7 +216,7 @@ def get_single_feature_causal_importances( def plot_causal_importance_vals( - model: ComponentModel[Any], + model: ComponentModel, batch_shape: tuple[int, ...], input_magnitude: float, sampling: SamplingType, diff --git a/spd/run_spd.py b/spd/run_spd.py index e7af079c6..cda6c29b6 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -51,7 +51,7 @@ def run_faithfulness_warmup( - component_model: ComponentModel[Any], + component_model: ComponentModel, component_params: list[torch.nn.Parameter], config: Config, ) -> None: @@ -109,13 +109,13 @@ def get_unique_metric_configs( return eval_metric_configs -def optimize[BatchT]( +def optimize( target_model: nn.Module, config: Config, device: str, - train_loader: DataLoader[BatchT], - eval_loader: DataLoader[BatchT], - run_batch: RunBatch[BatchT], + train_loader: DataLoader[Any], + eval_loader: DataLoader[Any], + run_batch: RunBatch, reconstruction_loss: ReconstructionLoss, out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, @@ -125,7 +125,7 @@ def optimize[BatchT]( train_iterator = loop_dataloader(train_loader) eval_iterator = loop_dataloader(eval_loader) - def create_pgd_data_iter() -> Iterator[BatchT]: + def create_pgd_data_iter() -> Iterator[Any]: assert hasattr(train_loader, "generator") and train_loader.generator is not None train_loader.generator.manual_seed(config.seed) return iter(train_loader) @@ -158,7 +158,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: dist_state = get_distributed_state() wrapped_model: nn.Module = model - component_model: ComponentModel[BatchT] + component_model: ComponentModel if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank @@ -171,7 +171,7 @@ def create_pgd_data_iter() -> Iterator[BatchT]: # For CPU, don't pass device_ids or output_device wrapped_model = torch.nn.parallel.DistributedDataParallel(model) # Access the underlying module for component operations - component_model = cast(ComponentModel[BatchT], wrapped_model.module) # type: ignore[attr-defined] + component_model = cast(ComponentModel, wrapped_model.module) # type: ignore[attr-defined] else: component_model = model assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 0acefeb53..3e93da8ef 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -79,7 +79,7 @@ def __init__(self, config: CompareModelsConfig): config.reference_model_path ) - def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel[Any], Config]: + def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Config]: """Load model and config using the standard pattern from existing codebase.""" run_info = SPDRunInfo.from_path(model_path) # TODO(oli): this should actually be generic (one of the only instances of this I think) @@ -234,7 +234,7 @@ def _create_ih_data_loader(self) -> Iterator[Any]: ) def compute_activation_densities( - self, model: ComponentModel[Any], eval_iterator: Iterator[Any], n_steps: int + self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int ) -> dict[str, Float[Tensor, " C"]]: """Compute activation densities using same logic as ComponentActivationDensity.""" diff --git a/spd/utils/logging_utils.py b/spd/utils/logging_utils.py index 3555cadc7..fd39afeeb 100644 --- a/spd/utils/logging_utils.py +++ b/spd/utils/logging_utils.py @@ -40,7 +40,7 @@ def local_log(data: dict[str, Any], step: int, out_dir: Path) -> None: def get_grad_norms_dict( - component_model: ComponentModel[Any], device: torch.device | str + component_model: ComponentModel, device: torch.device | str ) -> dict[str, float]: """Create a dictionary of gradient norms for the parameters of a component model.""" diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index 70f3d62c2..72ace2cc6 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -22,7 +22,6 @@ from spd.app.backend.state import HarvestCache, RunState, StateManager from spd.configs import ( Config, - IndexOutputExtract, LMTaskConfig, ModulePatternInfoConfig, ScheduleConfig, @@ -98,7 +97,7 @@ def app_with_state(): ModulePatternInfoConfig(module_pattern=p, C=C) for p in target_module_patterns ], pretrained_model_class="spd.pretrain.models.gpt2_simple.GPT2Simple", - output_extract=IndexOutputExtract(index=0), + output_extract=0, tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", lr_schedule=ScheduleConfig(start_val=1e-3), steps=1, diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index dfbf95a51..d2caeaccd 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -1,6 +1,6 @@ """Shared test fixtures for loss function tests.""" -from typing import Any, override +from typing import override import torch import torch.nn as nn @@ -41,7 +41,7 @@ def forward(self, x: Tensor) -> Tensor: def make_one_layer_component_model( weight: Float[Tensor, "d_out d_in"], -) -> ComponentModel[Any]: +) -> ComponentModel: """Create a ComponentModel with a single linear layer for testing. Args: @@ -70,7 +70,7 @@ def make_one_layer_component_model( def make_two_layer_component_model( weight1: Float[Tensor, " d_hidden d_in"], weight2: Float[Tensor, " d_out d_hidden"] -) -> ComponentModel[Any]: +) -> ComponentModel: """Create a ComponentModel with two linear layers for testing. Args: diff --git a/tests/metrics/test_faithfulness_loss.py b/tests/metrics/test_faithfulness_loss.py index a8b89e259..b6036f72d 100644 --- a/tests/metrics/test_faithfulness_loss.py +++ b/tests/metrics/test_faithfulness_loss.py @@ -1,5 +1,3 @@ -from typing import Any - import torch from spd.metrics import faithfulness_loss @@ -7,7 +5,7 @@ from tests.metrics.fixtures import make_one_layer_component_model -def zero_out_components(model: ComponentModel[Any]) -> None: +def zero_out_components(model: ComponentModel) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 216602cff..e96c8ae1d 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -53,7 +53,7 @@ # --- Pretrained model info --- "pretrained_model_class": "transformers.LlamaForCausalLM", "pretrained_model_name": "SimpleStories/SimpleStories-1.25M", - "output_extract": {"type": "attr", "attr": "logits"}, + "output_extract": "logits", "tokenizer_name": "SimpleStories/SimpleStories-1.25M", # --- Task Specific --- "task_config": { diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index 0784a0e5d..62c14ce12 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -4,7 +4,6 @@ from transformers import PreTrainedModel from spd.configs import ( - AttrOutputExtract, CI_L0Config, Config, FaithfulnessLossConfig, @@ -79,7 +78,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="transformers.GPT2LMHeadModel", pretrained_model_path=None, pretrained_model_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - output_extract=AttrOutputExtract(attr="logits"), + output_extract="logits", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", # Task Specific task_config=LMTaskConfig( diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index e75776966..833d4671a 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -1,4 +1,4 @@ -from typing import Any, override +from typing import override import torch import torch.nn as nn @@ -31,7 +31,7 @@ def forward(self, x: Tensor) -> Tensor: return self.fc(x) -def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel[Any]: +def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel: d_out, d_in = weight.shape target = TinyLinearModel(d_in=d_in, d_out=d_out) with torch.no_grad(): @@ -50,7 +50,7 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel return comp_model -def _zero_components_for_test(model: ComponentModel[Any]) -> None: +def _zero_components_for_test(model: ComponentModel) -> None: with torch.no_grad(): for cm in model.components.values(): cm.V.zero_() From 3042ce0830c05416af88ead18f3d37329d778dfd Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Wed, 11 Feb 2026 16:36:30 +0000 Subject: [PATCH 16/16] Revert various changes that were made --- spd/app/backend/compute.py | 12 ++++----- .../backend/routers/dataset_attributions.py | 4 +-- spd/data.py | 1 + spd/dataset_attributions/harvester.py | 6 ++--- spd/experiments/ih/ih_decomposition.py | 4 +-- spd/experiments/ih/model.py | 7 ++---- spd/experiments/resid_mlp/models.py | 6 +---- .../resid_mlp/resid_mlp_decomposition.py | 4 +-- spd/experiments/tms/models.py | 6 +---- spd/experiments/tms/tms_decomposition.py | 4 +-- spd/metrics/pgd_utils.py | 2 +- spd/models/batch_and_loss_fns.py | 5 ++++ spd/models/component_model.py | 22 +++++++++++++--- tests/test_component_model.py | 25 +------------------ tests/test_distributed.py | 13 ++-------- tests/test_ih_transformer.py | 4 +-- tests/test_resid_mlp.py | 4 +-- tests/test_tms.py | 4 +-- 18 files changed, 55 insertions(+), 78 deletions(-) diff --git a/spd/app/backend/compute.py b/spd/app/backend/compute.py index 0c82f105b..6ef65dc99 100644 --- a/spd/app/backend/compute.py +++ b/spd/app/backend/compute.py @@ -7,7 +7,7 @@ from collections import defaultdict from collections.abc import Callable from dataclasses import dataclass -from typing import Any, cast, override +from typing import Any, override import torch from jaxtyping import Bool, Float @@ -166,9 +166,8 @@ def wte_hook( wte_cache["wte_post_detach"] = output return output - wte = cast(Any, model.target_model).wte - assert isinstance(wte, nn.Module), "wte is not a module" - wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) with torch.enable_grad(), bf16_autocast(): comp_output_with_cache: OutputWithCache = model( @@ -343,9 +342,8 @@ def compute_edges_from_ci( # Setup wte hook and run forward pass for gradient computation wte_hook, wte_cache = _setup_wte_hook() - wte = cast(Any, model.target_model).wte - assert isinstance(wte, nn.Module), "wte is not a module" - wte_handle = wte.register_forward_hook(wte_hook, with_kwargs=True) + assert isinstance(model.target_model.wte, nn.Module), "wte is not a module" + wte_handle = model.target_model.wte.register_forward_hook(wte_hook, with_kwargs=True) weight_deltas = model.calc_weight_deltas() weight_deltas_and_masks = { diff --git a/spd/app/backend/routers/dataset_attributions.py b/spd/app/backend/routers/dataset_attributions.py index e32f2e372..fa38f5146 100644 --- a/spd/app/backend/routers/dataset_attributions.py +++ b/spd/app/backend/routers/dataset_attributions.py @@ -4,7 +4,7 @@ over the full training dataset. """ -from typing import Annotated, Any, Literal, cast +from typing import Annotated, Literal from fastapi import APIRouter, HTTPException, Query from jaxtyping import Float @@ -85,7 +85,7 @@ def _require_target(storage: DatasetAttributionStorage, component_key: str) -> N def _get_w_unembed(loaded: DepLoadedRun) -> Float[Tensor, "d_model vocab"]: """Get the unembedding matrix from the loaded model.""" - lm_head = cast(Any, loaded.model.target_model).lm_head + lm_head = loaded.model.target_model.lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" return lm_head.weight.T.detach() diff --git a/spd/data.py b/spd/data.py index a968a928f..39479f0b0 100644 --- a/spd/data.py +++ b/spd/data.py @@ -320,6 +320,7 @@ def train_loader_and_tokenizer( batch_size=batch_size, buffer_size=task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) assert isinstance(tokenizer, PreTrainedTokenizerBase) diff --git a/spd/dataset_attributions/harvester.py b/spd/dataset_attributions/harvester.py index bd46c3c6c..bea3b37d2 100644 --- a/spd/dataset_attributions/harvester.py +++ b/spd/dataset_attributions/harvester.py @@ -10,7 +10,7 @@ Output attributions computed on-the-fly at query time via w_unembed """ -from typing import Any, cast +from typing import Any import torch from jaxtyping import Bool, Float, Int @@ -75,7 +75,7 @@ def __init__( # For output targets: store attributions to output residual dimensions assert hasattr(model.target_model, "lm_head"), "Model must have lm_head" - lm_head = cast(Any, model.target_model).lm_head + lm_head = model.target_model.lm_head assert isinstance(lm_head, nn.Linear), f"lm_head must be nn.Linear, got {type(lm_head)}" self.d_model = lm_head.in_features self.out_residual_accumulator = torch.zeros(self.n_sources, self.d_model, device=device) @@ -144,7 +144,7 @@ def pre_unembed_hook(_mod: nn.Module, args: tuple[Any, ...], _kwargs: Any) -> No pre_unembed.clear() pre_unembed.append(args[0]) - wte = cast(Any, self.model.target_model).wte + wte = self.model.target_model.wte assert isinstance(wte, nn.Module) h1 = wte.register_forward_hook(wte_hook, with_kwargs=True) h2 = self.lm_head.register_forward_pre_hook(pre_unembed_hook, with_kwargs=True) diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 121747466..2ca5f8f1d 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,7 +7,7 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device @@ -97,7 +97,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/ih/model.py b/spd/experiments/ih/model.py index 143f4eefe..0d14b5866 100644 --- a/spd/experiments/ih/model.py +++ b/spd/experiments/ih/model.py @@ -1,6 +1,6 @@ import math from dataclasses import dataclass -from typing import Any, override +from typing import override import torch from jaxtyping import Float @@ -210,10 +210,7 @@ def __init__(self, cfg: InductionModelConfig): self.unembed = nn.Linear(cfg.d_model, adjusted_vocab_size, bias=False) @override - def forward( - self, batch: tuple[Float[Tensor, "B S"], ...] | Float[Tensor, "B S"], **_: Any - ) -> Float[Tensor, "B S V"]: - tokens = batch[0] if isinstance(batch, tuple) else batch + def forward(self, tokens: Float[Tensor, "B S"]) -> Float[Tensor, "B S V"]: x = self.token_embed(tokens) for block in self.blocks: diff --git a/spd/experiments/resid_mlp/models.py b/spd/experiments/resid_mlp/models.py index df0605a1b..108d0b520 100644 --- a/spd/experiments/resid_mlp/models.py +++ b/spd/experiments/resid_mlp/models.py @@ -18,9 +18,6 @@ from spd.spd_types import ModelPath from spd.utils.module_utils import init_param_ -ResidMLPBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] -ResidMLPOutput = Float[Tensor, "... n_features"] - @dataclass class ResidMLPTargetRunInfo(RunInfo[ResidMLPTrainConfig]): @@ -92,10 +89,9 @@ def __init__(self, config: ResidMLPModelConfig): @override def forward( self, - batch: ResidMLPBatch | Float[Tensor, "... n_features"], + x: Float[Tensor, "... n_features"], return_residual: bool = False, ) -> Float[Tensor, "... n_features"] | Float[Tensor, "... d_embed"]: - x = batch[0] if isinstance(batch, tuple) else batch residual = einops.einsum(x, self.W_E, "... n_features, n_features d_embed -> ... d_embed") for layer in self.layers: out = layer(residual) diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 09214c396..a61d8904a 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,7 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -109,7 +109,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/tms/models.py b/spd/experiments/tms/models.py index 95327b51d..f8643e225 100644 --- a/spd/experiments/tms/models.py +++ b/spd/experiments/tms/models.py @@ -10,9 +10,6 @@ from spd.interfaces import LoadableModule, RunInfo from spd.spd_types import ModelPath -TMSBatch = tuple[Float[Tensor, "... n_features"], Float[Tensor, "... n_features"]] -TMSOutput = Float[Tensor, "... n_features"] - @dataclass class TMSTargetRunInfo(RunInfo[TMSTrainConfig]): @@ -56,9 +53,8 @@ def to(self, *args: Any, **kwargs: Any) -> Self: @override def forward( - self, batch: TMSBatch | Float[Tensor, "... n_features"], **_: Any + self, x: Float[Tensor, "... n_features"], **_: Any ) -> Float[Tensor, "... n_features"]: - x = batch[0] if isinstance(batch, tuple) else batch hidden = self.linear1(x) if self.hidden_layers is not None: for layer in self.hidden_layers: diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index ee4258629..38dc47a6b 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,7 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger -from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -105,7 +105,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index 132c1a12f..d26fe6291 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -254,7 +254,7 @@ def _multibatch_pgd_fwd_bwd( # important: take gradient wrt the UNEXPANDED adv_sources, not the expanded ones grads = torch.autograd.grad(batch_sum_loss, list(adv_sources.values())) for k, g in zip(adv_sources.keys(), grads, strict=True): - pgd_step_accum_sum_grads[k] += all_reduce(g, op=ReduceOp.SUM).detach() + pgd_step_accum_sum_grads[k] += all_reduce(g, op=ReduceOp.AVG).detach() del target_model_output, ci diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py index c603e638a..6ec940167 100644 --- a/spd/models/batch_and_loss_fns.py +++ b/spd/models/batch_and_loss_fns.py @@ -29,6 +29,11 @@ def run_batch_passthrough(model: nn.Module, batch: Any) -> Tensor: return runtime_cast(Tensor, model(batch)) +def run_batch_first_element(model: nn.Module, batch: Any) -> Tensor: + """Run model on the first element of a batch tuple (e.g. (input, labels) -> model(input)).""" + return runtime_cast(Tensor, model(batch[0])) + + def make_run_batch(output_extract: int | str | None) -> RunBatch: """Creates a RunBatch function for a given configuration. diff --git a/spd/models/component_model.py b/spd/models/component_model.py index fd5b80f56..5376aae65 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -127,10 +127,26 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": model_class = resolve_class(config.pretrained_model_class) if config.pretrained_model_name is not None: - assert hasattr(model_class, "from_pretrained") - target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] + assert hasattr(model_class, "from_pretrained"), ( + f"Model class {model_class} should have a `from_pretrained` method" + ) + # Handle spd.pretrain models: patch missing model_type in old pretrain runs + if config.pretrained_model_class.startswith("spd.pretrain.models."): + from spd.pretrain.run_info import PretrainRunInfo + + pretrain_run_info = PretrainRunInfo.from_path(config.pretrained_model_name) + if "model_type" not in pretrain_run_info.model_config_dict: + pretrain_run_info.model_config_dict["model_type"] = ( + config.pretrained_model_class.split(".")[-1] + ) + target_model = model_class.from_run_info(pretrain_run_info) # pyright: ignore[reportAttributeAccessIssue] + else: + target_model = model_class.from_pretrained(config.pretrained_model_name) # pyright: ignore[reportAttributeAccessIssue] else: - assert issubclass(model_class, LoadableModule) + assert issubclass(model_class, LoadableModule), ( + f"Model class {model_class} should be a subclass of LoadableModule which " + "defines a `from_pretrained` method" + ) assert config.pretrained_model_path is not None target_model = model_class.from_pretrained(config.pretrained_model_path) diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 9868749f6..3d483edd8 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -182,32 +182,9 @@ def test_from_run_info(): save_file(config.model_dump(mode="json"), comp_model_dir / "final_config.yaml") cm_run_info = SPDRunInfo.from_path(comp_model_dir / "model.pth") + cm_loaded = ComponentModel.from_run_info(cm_run_info) assert config == cm_run_info.config - - # Manually reconstruct component model and load state dict - assert cm_run_info.config.pretrained_model_path is not None - loaded_target = SimpleTestModel.from_pretrained(cm_run_info.config.pretrained_model_path) - loaded_target.eval() - loaded_target.requires_grad_(False) - if cm_run_info.config.identity_module_info is not None: - insert_identity_operations_( - loaded_target, - identity_module_info=cm_run_info.config.identity_module_info, - ) - loaded_module_path_info = expand_module_patterns( - loaded_target, cm_run_info.config.all_module_info - ) - cm_loaded = ComponentModel( - target_model=loaded_target, - run_batch=run_batch_passthrough, - module_path_info=loaded_module_path_info, - ci_fn_type=cm_run_info.config.ci_fn_type, - ci_fn_hidden_dims=cm_run_info.config.ci_fn_hidden_dims, - sigmoid_type=cm_run_info.config.sigmoid_type, - ) - cm_loaded.load_state_dict(torch.load(cm_run_info.checkpoint_path)) - for k, v in cm_loaded.state_dict().items(): torch.testing.assert_close(v, cm.state_dict()[k]) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index e96c8ae1d..f1a76ae8e 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -160,10 +160,6 @@ def _run_experiment( new_env = os.environ.copy() new_env["CUDA_VISIBLE_DEVICES"] = "" new_env["SPD_OUT_DIR"] = str(spd_out_dir) - # Force single-threaded execution so that within-rank float32 operations - # are deterministic across different machines/CI environments. - new_env["OMP_NUM_THREADS"] = "1" - new_env["MKL_NUM_THREADS"] = "1" result = subprocess.run(cmd, env=new_env, capture_output=True, text=True, timeout=300) @@ -230,16 +226,11 @@ def _compare_saved_models( self, dp1_out_dir: Path, dp2_out_dir: Path, - atol: float = 2e-4, - rtol: float = 1e-3, + atol: float = 1e-6, + rtol: float = 1e-5, ) -> None: """Compare saved model parameters between dp=1 and dp=2 runs. - Tolerances are relatively loose because CI-masked reconstruction losses use hard - masking: tiny allreduce rounding differences can push a CI value across the mask - threshold, causing a different gradient path that compounds over training steps. - Empirically, across many seeds, max parameter diffs stay below ~1.5e-4. - Args: dp1_out_dir: Output directory for dp=1 run dp2_out_dir: Output directory for dp=2 run diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 4de05d05f..0392a4354 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -17,7 +17,7 @@ from spd.experiments.ih.configs import InductionModelConfig from spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_kl, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -132,7 +132,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index 4c502aabc..8db454a87 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -13,7 +13,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -128,7 +128,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_tms.py b/tests/test_tms.py index e717afcb8..f34a8b693 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -18,7 +18,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ -from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_passthrough +from spd.models.batch_and_loss_fns import recon_loss_mse, run_batch_first_element from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -136,7 +136,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - run_batch=run_batch_passthrough, + run_batch=run_batch_first_element, reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights,