From 255053a1d60cb62d862193adab93ddc69187bc97 Mon Sep 17 00:00:00 2001 From: Oli Clive-Griffin Date: Wed, 11 Feb 2026 18:47:26 +0000 Subject: [PATCH 1/2] Remove BatchT/OutputT generics from ComponentModel, introduce ReconstructionLoss protocol De-genericize ComponentModel by removing BatchT and OutputT type parameters. Model batch input and output types are now Any. Introduces a ReconstructionLoss protocol with concrete implementations (recon_loss_mse, recon_loss_kl) that callers pass explicitly instead of the old output_loss_type string config. Key changes: - Remove BatchT/OutputT generics from ComponentModel and all downstream code - Add ReconstructionLoss protocol in spd/models/batch_and_loss_fns.py - Rename pretrained_model_output_attr -> extract_tensor_output with regex-based parsing - Remove output_loss_type config field (added to deprecated keys with migration) - Remove extract_batch_data utility (callers handle tuple extraction inline) - Add lm_collate_fn for LM data loading - Update all 24 YAML configs, 11 metric files, 4 experiment scripts, and all tests Co-Authored-By: Claude Opus 4.6 --- spd/app/backend/routers/prompts.py | 10 --- spd/configs.py | 27 ++++-- spd/data.py | 16 +++- spd/dataset_attributions/harvest.py | 3 +- spd/eval.py | 89 ++++++++++--------- spd/experiments/ih/ih_config.yaml | 1 - spd/experiments/ih/ih_decomposition.py | 3 +- spd/experiments/lm/gpt2_config.yaml | 3 +- spd/experiments/lm/lm_decomposition.py | 7 +- .../lm/pile_llama_simple_mlp-2L.yaml | 3 +- .../lm/pile_llama_simple_mlp-4L.yaml | 3 +- spd/experiments/lm/ss_gpt2_config.yaml | 3 +- spd/experiments/lm/ss_gpt2_simple-1L.yaml | 3 +- spd/experiments/lm/ss_gpt2_simple-2L.yaml | 3 +- spd/experiments/lm/ss_gpt2_simple_config.yaml | 3 +- .../lm/ss_gpt2_simple_noln_config.yaml | 3 +- spd/experiments/lm/ss_llama_simple-1L.yaml | 3 +- spd/experiments/lm/ss_llama_simple-2L.yaml | 3 +- .../lm/ss_llama_simple_config.yaml | 4 +- .../lm/ss_llama_simple_mlp-1L.yaml | 3 +- .../lm/ss_llama_simple_mlp-2L-wide.yaml | 3 +- .../lm/ss_llama_simple_mlp-2L.yaml | 3 +- spd/experiments/lm/ss_llama_simple_mlp.yaml | 3 +- spd/experiments/lm/ts_config.yaml | 3 +- .../resid_mlp/resid_mlp1_config.yaml | 1 - .../resid_mlp/resid_mlp2_config.yaml | 1 - .../resid_mlp/resid_mlp3_config.yaml | 1 - .../resid_mlp/resid_mlp_decomposition.py | 5 +- spd/experiments/tms/tms_40-10-id_config.yaml | 1 - spd/experiments/tms/tms_40-10_config.yaml | 1 - spd/experiments/tms/tms_5-2-id_config.yaml | 1 - spd/experiments/tms/tms_5-2_config.yaml | 1 - spd/experiments/tms/tms_decomposition.py | 3 +- spd/harvest/harvest.py | 6 +- spd/losses.py | 31 +++---- spd/metrics/ci_masked_recon_layerwise_loss.py | 32 +++---- spd/metrics/ci_masked_recon_loss.py | 33 ++++--- spd/metrics/ci_masked_recon_subset_loss.py | 33 ++++--- .../pgd_masked_recon_layerwise_loss.py | 29 +++--- spd/metrics/pgd_masked_recon_loss.py | 23 ++--- spd/metrics/pgd_masked_recon_subset_loss.py | 23 ++--- spd/metrics/pgd_utils.py | 57 +++++------- .../stochastic_recon_layerwise_loss.py | 34 +++---- spd/metrics/stochastic_recon_loss.py | 38 ++++---- spd/metrics/stochastic_recon_subset_loss.py | 34 +++---- spd/metrics/unmasked_recon_loss.py | 33 ++++--- spd/models/batch_and_loss_fns.py | 39 ++++++++ spd/models/component_model.py | 77 +++++++--------- spd/persistent_pgd.py | 14 ++- spd/run_spd.py | 73 ++++++++------- spd/scripts/compare_models/compare_models.py | 4 +- spd/utils/general_utils.py | 62 +------------ tests/app/test_server_api.py | 5 +- tests/metrics/fixtures.py | 4 +- .../test_ci_masked_recon_layerwise_loss.py | 19 +++- tests/metrics/test_ci_masked_recon_loss.py | 19 +++- .../test_ci_masked_recon_subset_loss.py | 3 +- .../test_stochastic_recon_layerwise_loss.py | 7 +- tests/metrics/test_stochastic_recon_loss.py | 3 +- .../test_stochastic_recon_subset_loss.py | 3 +- tests/test_component_model.py | 48 +++++----- tests/test_distributed.py | 3 +- tests/test_gpt2.py | 10 ++- tests/test_ih_transformer.py | 6 +- tests/test_resid_mlp.py | 6 +- tests/test_spd_losses.py | 59 +++++++----- tests/test_tms.py | 6 +- 67 files changed, 551 insertions(+), 545 deletions(-) create mode 100644 spd/models/batch_and_loss_fns.py diff --git a/spd/app/backend/routers/prompts.py b/spd/app/backend/routers/prompts.py index 7f06bcf68..2cf0da197 100644 --- a/spd/app/backend/routers/prompts.py +++ b/spd/app/backend/routers/prompts.py @@ -6,16 +6,6 @@ from spd.app.backend.dependencies import DepLoadedRun, DepStateManager from spd.app.backend.utils import log_errors -from spd.utils.distributed_utils import get_device - -# TODO: Re-enable these endpoints when dependencies are available: -# - extract_active_from_ci from database -# - PromptSearchQuery, PromptSearchResponse from schemas -# - DatasetConfig, LMTaskConfig from configs -# - create_data_loader, extract_batch_data from data -# - logger from utils - -DEVICE = get_device() # ============================================================================= # Schemas diff --git a/spd/configs.py b/spd/configs.py index 680a506e0..615e46306 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -744,11 +744,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]: ), ) ) - output_loss_type: Literal["mse", "kl"] = Field( - ..., - description="Metric used to measure recon error between model outputs and targets", - ) - # --- Training --- lr_schedule: ScheduleConfig = Field(..., description="Learning rate schedule configuration") steps: NonNegativeInt = Field(..., description="Total number of optimisation steps") @@ -849,9 +844,11 @@ def microbatch_size(self) -> PositiveInt: default=None, description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'", ) - pretrained_model_output_attr: str | None = Field( + extract_tensor_output: str | None = Field( default=None, - description="Name of the attribute on the forward output that contains logits or activations", + description="Declarative accessor path for extracting tensor from model output. " + "None = raw output is the tensor. Examples: '.logits' for attribute access, " + "'[0]' for index access.", ) tokenizer_name: str | None = Field( default=None, @@ -890,6 +887,7 @@ def microbatch_size(self) -> PositiveInt: "lr_exponential_halflife", "out_dir", "n_examples_until_dead", + "output_loss_type", ] RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = { "grad_clip_norm": "grad_clip_norm_components", @@ -934,6 +932,21 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, "simple_stories_train.models.", "spd.pretrain.models.", 1 ) + # Migrate old pretrained_model_output_attr to extract_tensor_output + if "pretrained_model_output_attr" in config_dict: + old_val = config_dict.pop("pretrained_model_output_attr") + match old_val: + case None: + pass + case "idx_0": + config_dict["extract_tensor_output"] = "[0]" + case str(attr): + config_dict["extract_tensor_output"] = f".{attr}" + case _: + raise AssertionError( + f"Unexpected pretrained_model_output_attr value: {old_val}" + ) + if "eval_batch_size" not in config_dict: config_dict["eval_batch_size"] = config_dict["batch_size"] if "train_log_freq" not in config_dict: diff --git a/spd/data.py b/spd/data.py index b1ed33a62..39479f0b0 100644 --- a/spd/data.py +++ b/spd/data.py @@ -1,8 +1,10 @@ +from collections.abc import Callable, Generator from typing import Any import numpy as np import torch from datasets import Dataset, IterableDataset, load_dataset +from jaxtyping import Int from numpy.typing import NDArray from torch import Tensor from torch.utils.data import DataLoader, DistributedSampler @@ -152,7 +154,8 @@ def create_data_loader( dist_state: DistributedState | None = None, global_seed: int = 0, to_lower: bool = True, -) -> tuple[DataLoader[Any], PreTrainedTokenizer]: + collate_fn: Callable[..., Any] | None = None, +) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]: """Create a DataLoader for the given dataset. Uses PyTorch's DistributedSampler to ensure each rank gets the correct @@ -255,7 +258,7 @@ def create_data_loader( generator = torch.Generator(device="cpu") generator.manual_seed(seed) - loader = DataLoader[Dataset | IterableDataset]( + loader = DataLoader[Int[Tensor, "batch seq"]]( torch_dataset, # pyright: ignore[reportArgumentType] batch_size=batch_size, sampler=sampler, @@ -264,11 +267,17 @@ def create_data_loader( ), drop_last=True, generator=generator, + collate_fn=collate_fn, ) return loader, tokenizer -def loop_dataloader[T](dl: DataLoader[T]): +def lm_collate_fn(batch: list[dict[str, Tensor]]) -> Tensor: + """Collate function that extracts input_ids tensors from HuggingFace dataset dicts.""" + return torch.stack([item["input_ids"] for item in batch]) + + +def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]: """Loop over a dataloader, resetting the iterator when it is exhausted. Ensures that each epoch gets different data, even when using a distributed sampler. @@ -311,6 +320,7 @@ def train_loader_and_tokenizer( batch_size=batch_size, buffer_size=task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) assert isinstance(tokenizer, PreTrainedTokenizerBase) diff --git a/spd/dataset_attributions/harvest.py b/spd/dataset_attributions/harvest.py index ff18af5b9..4c79d2406 100644 --- a/spd/dataset_attributions/harvest.py +++ b/spd/dataset_attributions/harvest.py @@ -30,7 +30,6 @@ from spd.models.component_model import ComponentModel, SPDRunInfo from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data from spd.utils.wandb_utils import parse_wandb_run_path @@ -201,7 +200,7 @@ def harvest_attributions( # Skip batches not assigned to this rank if world_size is not None and batch_idx % world_size != rank: continue - batch = extract_batch_data(batch_data).to(device) + batch = batch_data.to(device) harvester.process_batch(batch) logger.info( diff --git a/spd/eval.py b/spd/eval.py index 332e4adbf..c0b472658 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -3,7 +3,7 @@ from collections.abc import Iterator from typing import Any -from jaxtyping import Float, Int +from jaxtyping import Float from PIL import Image from torch import Tensor from torch.types import Number @@ -39,35 +39,38 @@ UnmaskedReconLossConfig, UVPlotsConfig, ) -from spd.metrics import UnmaskedReconLoss +from spd.metrics import ( + CI_L0, + CEandKLLosses, + CIHistograms, + CIMaskedReconLayerwiseLoss, + CIMaskedReconLoss, + CIMaskedReconSubsetLoss, + CIMeanPerComponent, + ComponentActivationDensity, + FaithfulnessLoss, + IdentityCIError, + ImportanceMinimalityLoss, + PermutedCIPlots, + PGDReconLayerwiseLoss, + PGDReconLoss, + PGDReconSubsetLoss, + StochasticHiddenActsReconLoss, + StochasticReconLayerwiseLoss, + StochasticReconLoss, + StochasticReconSubsetCEAndKL, + StochasticReconSubsetLoss, + UnmaskedReconLoss, + UVPlots, +) from spd.metrics.base import Metric -from spd.metrics.ce_and_kl_losses import CEandKLLosses -from spd.metrics.ci_histograms import CIHistograms -from spd.metrics.ci_l0 import CI_L0 -from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss -from spd.metrics.ci_masked_recon_loss import CIMaskedReconLoss -from spd.metrics.ci_masked_recon_subset_loss import CIMaskedReconSubsetLoss -from spd.metrics.ci_mean_per_component import CIMeanPerComponent -from spd.metrics.component_activation_density import ComponentActivationDensity -from spd.metrics.faithfulness_loss import FaithfulnessLoss -from spd.metrics.identity_ci_error import IdentityCIError -from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss -from spd.metrics.permuted_ci_plots import PermutedCIPlots -from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss -from spd.metrics.pgd_masked_recon_loss import PGDReconLoss -from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss -from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss -from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss -from spd.metrics.stochastic_recon_loss import StochasticReconLoss -from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL -from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss -from spd.metrics.uv_plots import UVPlots +from spd.models.batch_and_loss_fns import ReconstructionLoss, recon_loss_kl from spd.models.component_model import ComponentModel, OutputWithCache from spd.persistent_pgd import PersistentPGDReconLoss, PersistentPGDReconSubsetLoss, PPGDSources from spd.routing import AllLayersRouter, get_subset_router from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed -from spd.utils.general_utils import dict_safe_update_, extract_batch_data +from spd.utils.general_utils import dict_safe_update_ MetricOutType = dict[str, str | Number | Image.Image | CustomChart] DistMetricOutType = dict[str, str | float | Image.Image | CustomChart] @@ -127,6 +130,7 @@ def init_metric( ], run_config: Config, device: str, + reconstruction_loss: ReconstructionLoss, ) -> Metric: match cfg: case ImportanceMinimalityLossConfig(): @@ -164,16 +168,16 @@ def init_metric( metric = CIMaskedReconSubsetLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, routing=cfg.routing, ) case CIMaskedReconLayerwiseLossConfig(): metric = CIMaskedReconLayerwiseLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, device=device, reconstruction_loss=reconstruction_loss ) case CIMaskedReconLossConfig(): metric = CIMaskedReconLoss( - model=model, device=device, output_loss_type=run_config.output_loss_type + model=model, device=device, reconstruction_loss=reconstruction_loss ) case CIMeanPerComponentConfig(): metric = CIMeanPerComponent(model=model, device=device) @@ -202,7 +206,7 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case StochasticReconLossConfig(): metric = StochasticReconLoss( @@ -211,7 +215,7 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case StochasticReconSubsetLossConfig(): metric = StochasticReconSubsetLoss( @@ -220,7 +224,7 @@ def init_metric( sampling=run_config.sampling, use_delta_component=run_config.use_delta_component, n_mask_samples=run_config.n_mask_samples, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, routing=cfg.routing, ) case PGDReconLossConfig(): @@ -228,7 +232,7 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, pgd_config=cfg, ) case PGDReconSubsetLossConfig(): @@ -236,7 +240,7 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, pgd_config=cfg, routing=cfg.routing, ) @@ -245,7 +249,7 @@ def init_metric( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, pgd_config=cfg, ) case StochasticReconSubsetCEAndKLConfig(): @@ -277,23 +281,25 @@ def init_metric( metric = UnmaskedReconLoss( model=model, device=device, - output_loss_type=run_config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) case PersistentPGDReconLossConfig(): + ppgd_output_loss_type = "kl" if reconstruction_loss is recon_loss_kl else "mse" metric = PersistentPGDReconLoss( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, + output_loss_type=ppgd_output_loss_type, ppgd_sources=ppgd_sourcess[cfg], ) case PersistentPGDReconSubsetLossConfig(): + ppgd_output_loss_type = "kl" if reconstruction_loss is recon_loss_kl else "mse" metric = PersistentPGDReconSubsetLoss( model=model, device=device, use_delta_component=run_config.use_delta_component, - output_loss_type=run_config.output_loss_type, + output_loss_type=ppgd_output_loss_type, ppgd_sources=ppgd_sourcess[cfg], routing=cfg.routing, ) @@ -307,7 +313,7 @@ def init_metric( def evaluate( eval_metric_configs: list[MetricConfigType], model: ComponentModel, - eval_iterator: Iterator[Int[Tensor, "..."] | tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + eval_iterator: Iterator[Any], ppgd_sourcess: dict[ PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, dict[str, Float[Tensor, " source_c"]], @@ -317,6 +323,7 @@ def evaluate( slow_step: bool, n_eval_steps: int, current_frac_of_training: float, + reconstruction_loss: ReconstructionLoss, ) -> MetricOutType: """Run evaluation and return a mapping of metric names to values/images.""" @@ -328,6 +335,7 @@ def evaluate( ppgd_sourcess=ppgd_sourcess, run_config=run_config, device=device, + reconstruction_loss=reconstruction_loss, ) if metric.slow and not slow_step: continue @@ -338,7 +346,7 @@ def evaluate( for _ in range(n_eval_steps): batch_raw = next(eval_iterator) - batch = extract_batch_data(batch_raw).to(device) + batch = batch_raw[0] if isinstance(batch_raw, tuple) else batch_raw target_output: OutputWithCache = model(batch, cache_type="input") ci = model.calc_causal_importances( @@ -377,8 +385,8 @@ def evaluate_multibatch_pgd( model: ComponentModel, create_data_iter: CreateDataIter, config: Config, - batch_dims: tuple[int, ...], device: str, + reconstruction_loss: ReconstructionLoss, ) -> dict[str, float]: """Calculate multibatch PGD metrics.""" weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None @@ -400,11 +408,10 @@ def evaluate_multibatch_pgd( model=model, weight_deltas=weight_deltas, create_data_iter=create_data_iter, - output_loss_type=config.output_loss_type, router=router, sampling=config.sampling, use_delta_component=config.use_delta_component, - batch_dims=batch_dims, device=device, + reconstruction_loss=reconstruction_loss, ).item() return metrics diff --git a/spd/experiments/ih/ih_config.yaml b/spd/experiments/ih/ih_config.yaml index 4ce329ff2..42f3bcda8 100644 --- a/spd/experiments/ih/ih_config.yaml +++ b/spd/experiments/ih/ih_config.yaml @@ -33,7 +33,6 @@ ci_recon_layerwise_coeff: null stochastic_recon_layerwise_coeff: 1 importance_minimality_coeff: 1e-2 pnorm: 0.1 -output_loss_type: kl ci_config: mode: layerwise fn_type: vector_mlp diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 1b0b268fc..847e28773 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -7,6 +7,7 @@ from spd.configs import Config, IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo, InductionTransformer from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.distributed_utils import get_device @@ -99,7 +100,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/gpt2_config.yaml b/spd/experiments/lm/gpt2_config.yaml index 653d78833..4ffa77d13 100644 --- a/spd/experiments/lm/gpt2_config.yaml +++ b/spd/experiments/lm/gpt2_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 2 @@ -66,7 +65,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: openai-community/gpt2 -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: openai-community/gpt2 # --- Task Specific --- diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index e67a4cfd2..105e9e091 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -13,8 +13,9 @@ PersistentPGDReconSubsetLossConfig, RepeatAcrossBatchScope, ) -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.pretrain.run_info import PretrainRunInfo from spd.run_spd import optimize from spd.utils.distributed_utils import ( @@ -156,6 +157,7 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed, dist_state=dist_state, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -185,6 +187,7 @@ def main( buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, dist_state=dist_state, + collate_fn=lm_collate_fn, ) if is_main_process(): @@ -196,7 +199,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_kl, out_dir=out_dir, ) diff --git a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml index fc179119f..bcfcd4a74 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-2L.yaml @@ -67,7 +67,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -115,7 +114,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-bd02d372 -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml index c3c3a106a..9af10d55f 100644 --- a/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml +++ b/spd/experiments/lm/pile_llama_simple_mlp-4L.yaml @@ -67,7 +67,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 1.0e-04 warmup_pct: 0.0 @@ -119,7 +118,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/t-32d1bb3b -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: EleutherAI/gpt-neox-20b task_config: task_name: lm diff --git a/spd/experiments/lm/ss_gpt2_config.yaml b/spd/experiments/lm/ss_gpt2_config.yaml index 70a63fa10..b5d51d9a6 100644 --- a/spd/experiments/lm/ss_gpt2_config.yaml +++ b/spd/experiments/lm/ss_gpt2_config.yaml @@ -27,7 +27,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 16 @@ -66,7 +65,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.GPT2LMHeadModel pretrained_model_name: SimpleStories/test-SimpleStories-gpt2-1.25M -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-1L.yaml b/spd/experiments/lm/ss_gpt2_simple-1L.yaml index be7192058..d970340cb 100644 --- a/spd/experiments/lm/ss_gpt2_simple-1L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-1L.yaml @@ -49,7 +49,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -94,7 +93,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/3qhd7rnb # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple-2L.yaml b/spd/experiments/lm/ss_gpt2_simple-2L.yaml index ff7901e02..aa3a702bb 100644 --- a/spd/experiments/lm/ss_gpt2_simple-2L.yaml +++ b/spd/experiments/lm/ss_gpt2_simple-2L.yaml @@ -49,7 +49,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl lr_schedule: start_val: 0.0002 fn_type: cosine @@ -96,7 +95,7 @@ ci_alive_threshold: 0.0 # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/wr1su18m # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index 827d1b0c0..b98de54a6 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -40,7 +40,6 @@ loss_metric_configs: routing: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl # --- Training --- batch_size: 256 @@ -99,7 +98,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/rvu66183 # 100k steps. 4019 tokenizer -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml index 0afa4bd50..010ee04cf 100644 --- a/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_noln_config.yaml @@ -38,7 +38,6 @@ loss_metric_configs: coeff: 2.0 - classname: "StochasticReconLoss" coeff: 0.2 -output_loss_type: kl # --- Training --- batch_size: 48 @@ -96,7 +95,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.gpt2_simple.GPT2Simple pretrained_model_name: wandb:goodfire/spd/runs/xi36b9az # No ln -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # We'll load this from wandb in future # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple-1L.yaml b/spd/experiments/lm/ss_llama_simple-1L.yaml index 8d18cfaab..d20e87547 100644 --- a/spd/experiments/lm/ss_llama_simple-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple-1L.yaml @@ -49,7 +49,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -93,7 +92,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tfacbi70 # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple-2L.yaml b/spd/experiments/lm/ss_llama_simple-2L.yaml index 3bbbac7a3..5ddf73090 100644 --- a/spd/experiments/lm/ss_llama_simple-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple-2L.yaml @@ -49,7 +49,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -95,7 +94,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/tb8373uo # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_config.yaml b/spd/experiments/lm/ss_llama_simple_config.yaml index 8b0d65cae..c458b41ed 100644 --- a/spd/experiments/lm/ss_llama_simple_config.yaml +++ b/spd/experiments/lm/ss_llama_simple_config.yaml @@ -44,8 +44,6 @@ loss_metric_configs: type: uniform_k_subset coeff: 1.0 -output_loss_type: kl - # --- Training --- batch_size: 256 eval_batch_size: 64 @@ -96,7 +94,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: spd.pretrain.models.llama_simple.LlamaSimple pretrained_model_name: wandb:goodfire/spd/runs/erq48r3w # 100k steps 4019 tok -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M # 4019 tok: TODO: Load from wandb instead # --- Task Specific --- diff --git a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index a692dc578..25e487464 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -43,7 +43,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 64 gradient_accumulation_steps: 1 @@ -87,7 +86,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gvbmdt9w -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml index a80265153..60b31484e 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L-wide.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - classname: FaithfulnessLoss coeff: 1000000 -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -95,7 +94,7 @@ ci_alive_threshold: 0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/gf6rbga0 -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: buffer_size: 1000 diff --git a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml index d94e6e098..6335f282d 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-2L.yaml @@ -47,7 +47,6 @@ loss_metric_configs: type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 200000 batch_size: 64 gradient_accumulation_steps: 1 @@ -93,7 +92,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/7pt957pf # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ss_llama_simple_mlp.yaml b/spd/experiments/lm/ss_llama_simple_mlp.yaml index a4cc29525..f4ae02c3c 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp.yaml @@ -43,7 +43,6 @@ loss_metric_configs: classname: PGDReconSubsetLoss - coeff: 100000.0 classname: FaithfulnessLoss -output_loss_type: kl steps: 400000 batch_size: 128 gradient_accumulation_steps: 1 @@ -118,7 +117,7 @@ ci_alive_threshold: 0.0 pretrained_model_class: spd.pretrain.models.llama_simple_mlp.LlamaSimpleMLP pretrained_model_path: null pretrained_model_name: wandb:goodfire/spd/runs/9de1zu65 # 100k steps -pretrained_model_output_attr: idx_0 +extract_tensor_output: "[0]" tokenizer_name: SimpleStories/test-SimpleStories-gpt2-1.25M task_config: task_name: lm diff --git a/spd/experiments/lm/ts_config.yaml b/spd/experiments/lm/ts_config.yaml index f60c4e589..22f955302 100644 --- a/spd/experiments/lm/ts_config.yaml +++ b/spd/experiments/lm/ts_config.yaml @@ -30,7 +30,6 @@ loss_metric_configs: p_anneal_end_frac: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 2.0 -output_loss_type: kl # --- Training --- batch_size: 4 @@ -67,7 +66,7 @@ eval_metric_configs: # --- Pretrained model info --- pretrained_model_class: transformers.AutoModelForCausalLM pretrained_model_name: roneneldan/TinyStories-1M -pretrained_model_output_attr: logits +extract_tensor_output: ".logits" tokenizer_name: EleutherAI/gpt-neo-125M # --- Task Specific --- diff --git a/spd/experiments/resid_mlp/resid_mlp1_config.yaml b/spd/experiments/resid_mlp/resid_mlp1_config.yaml index 6178f579b..c6be5820a 100644 --- a/spd/experiments/resid_mlp/resid_mlp1_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp1_config.yaml @@ -30,7 +30,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp2_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_config.yaml index dcc8abeba..94d929ca9 100644 --- a/spd/experiments/resid_mlp/resid_mlp2_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp2_config.yaml @@ -35,7 +35,6 @@ loss_metric_configs: mask_scope: shared_across_batch - classname: "FaithfulnessLoss" coeff: 0.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp3_config.yaml b/spd/experiments/resid_mlp/resid_mlp3_config.yaml index 1961a44d3..7b954147b 100644 --- a/spd/experiments/resid_mlp/resid_mlp3_config.yaml +++ b/spd/experiments/resid_mlp/resid_mlp3_config.yaml @@ -29,7 +29,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 2048 diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index 75e423099..e24f6c63d 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -13,6 +13,7 @@ ) from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.distributed_utils import get_device @@ -100,15 +101,13 @@ def main( dataset, batch_size=config.eval_batch_size, shuffle=False ) - # TODO: Below not needed when TMS supports config.n_eval_steps - assert config.n_eval_steps is not None, "n_eval_steps must be set" optimize( target_model=target_model, config=config, device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, ) diff --git a/spd/experiments/tms/tms_40-10-id_config.yaml b/spd/experiments/tms/tms_40-10-id_config.yaml index 0fbdfd07a..16792b4e6 100644 --- a/spd/experiments/tms/tms_40-10-id_config.yaml +++ b/spd/experiments/tms/tms_40-10-id_config.yaml @@ -30,7 +30,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_40-10_config.yaml b/spd/experiments/tms/tms_40-10_config.yaml index 2a7cffb55..079b249bd 100644 --- a/spd/experiments/tms/tms_40-10_config.yaml +++ b/spd/experiments/tms/tms_40-10_config.yaml @@ -29,7 +29,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: "mse" # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2-id_config.yaml b/spd/experiments/tms/tms_5-2-id_config.yaml index cc654532e..07ef3a5a9 100644 --- a/spd/experiments/tms/tms_5-2-id_config.yaml +++ b/spd/experiments/tms/tms_5-2-id_config.yaml @@ -30,7 +30,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 07bc9056a..e22723b23 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -28,7 +28,6 @@ loss_metric_configs: coeff: 1.0 - classname: "StochasticReconLayerwiseLoss" coeff: 1.0 -output_loss_type: mse # --- Training --- batch_size: 4096 diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 18c437a68..2471c1830 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -13,6 +13,7 @@ from spd.configs import Config, TMSTaskConfig from spd.experiments.tms.models import TMSModel, TMSTargetRunInfo from spd.log import logger +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.distributed_utils import get_device @@ -104,7 +105,7 @@ def main( device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_mse, out_dir=out_dir, tied_weights=tied_weights, ) diff --git a/spd/harvest/harvest.py b/spd/harvest/harvest.py index a72c214da..737b83045 100644 --- a/spd/harvest/harvest.py +++ b/spd/harvest/harvest.py @@ -31,7 +31,7 @@ from spd.models.component_model import ComponentModel, SPDRunInfo from spd.topology import TransformerTopology from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import bf16_autocast, extract_batch_data +from spd.utils.general_utils import bf16_autocast def _compute_u_norms(model: ComponentModel) -> dict[str, Float[Tensor, " C"]]: @@ -184,7 +184,7 @@ def harvest_activation_contexts( for batch_idx in tqdm.tqdm(batch_range, desc="Harvesting", disable=rank is not None): try: - batch_data = extract_batch_data(next(train_iter)) + batch = next(train_iter) except StopIteration: logger.info(f"Dataset exhausted at batch {batch_idx}. Processing complete.") break @@ -193,7 +193,7 @@ def harvest_activation_contexts( if world_size is not None and batch_idx % world_size != rank: continue - batch = batch_data.to(device) + batch = batch.to(device) with torch.no_grad(), bf16_autocast(): out = model(batch, cache_type="input") probs = torch.softmax(out.output, dim=-1) diff --git a/spd/losses.py b/spd/losses.py index 5ea1fd4a2..4383ca36d 100644 --- a/spd/losses.py +++ b/spd/losses.py @@ -1,6 +1,6 @@ -from typing import Literal +from typing import Any -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from spd.configs import ( @@ -37,6 +37,7 @@ stochastic_recon_subset_loss, unmasked_recon_loss, ) +from spd.models.batch_and_loss_fns import ReconstructionLoss, recon_loss_kl from spd.models.component_model import CIOutputs, ComponentModel from spd.persistent_pgd import PPGDSources, persistent_pgd_recon_loss @@ -44,7 +45,7 @@ def compute_losses( loss_metric_configs: list[LossMetricConfigType], model: ComponentModel, - batch: Int[Tensor, "..."], + batch: Any, ci: CIOutputs, target_out: Tensor, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], @@ -56,7 +57,7 @@ def compute_losses( ppgd_sourcess: dict[ PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PPGDSources ], - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> dict[LossMetricConfigType, Float[Tensor, ""]]: """Compute losses for each config and return a dict mapping config to loss tensor.""" losses: dict[LossMetricConfigType, Float[Tensor, ""]] = {} @@ -80,14 +81,14 @@ def compute_losses( case UnmaskedReconLossConfig(): loss = unmasked_recon_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ) case CIMaskedReconSubsetLossConfig(): loss = ci_masked_recon_subset_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -96,7 +97,7 @@ def compute_losses( case CIMaskedReconLayerwiseLossConfig(): loss = ci_masked_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -104,7 +105,7 @@ def compute_losses( case CIMaskedReconLossConfig(): loss = ci_masked_recon_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -114,7 +115,7 @@ def compute_losses( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -125,7 +126,7 @@ def compute_losses( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -136,7 +137,7 @@ def compute_losses( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -146,7 +147,7 @@ def compute_losses( case PGDReconLossConfig(): loss = pgd_recon_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -156,7 +157,7 @@ def compute_losses( case PGDReconSubsetLossConfig(): loss = pgd_recon_subset_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -167,7 +168,7 @@ def compute_losses( case PGDReconLayerwiseLossConfig(): loss = pgd_recon_layerwise_loss( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, @@ -193,7 +194,7 @@ def compute_losses( ci=ci.lower_leaky, weight_deltas=weight_deltas if use_delta_component else None, target_out=target_out, - output_loss_type=output_loss_type, + output_loss_type="kl" if reconstruction_loss is recon_loss_kl else "mse", ) losses[cfg] = loss diff --git a/spd/metrics/ci_masked_recon_layerwise_loss.py b/spd/metrics/ci_masked_recon_layerwise_loss.py index b7ff12be9..e39078b4c 100644 --- a/spd/metrics/ci_masked_recon_layerwise_loss.py +++ b/spd/metrics/ci_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,17 +6,17 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm def _ci_masked_recon_layerwise_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], ) -> tuple[Float[Tensor, ""], int]: sum_loss = torch.tensor(0.0, device=batch.device) @@ -24,8 +24,8 @@ def _ci_masked_recon_layerwise_loss_update( mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) for module_name, mask_info in mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) + n_examples += count sum_loss += loss return sum_loss, n_examples @@ -38,14 +38,14 @@ def _ci_masked_recon_layerwise_loss_compute( def ci_masked_recon_layerwise_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci, @@ -59,10 +59,10 @@ class CIMaskedReconLayerwiseLoss(Metric): metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, model: ComponentModel, device: str, reconstruction_loss: ReconstructionLoss ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -70,14 +70,14 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_layerwise_loss_update( model=self.model, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/ci_masked_recon_loss.py b/spd/metrics/ci_masked_recon_loss.py index a11c11469..f0783c7b9 100644 --- a/spd/metrics/ci_masked_recon_loss.py +++ b/spd/metrics/ci_masked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,24 +6,23 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm def _ci_masked_recon_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], ) -> tuple[Float[Tensor, ""], int]: mask_infos = make_mask_infos(ci, weight_deltas_and_masks=None) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) + return loss, count def _ci_masked_recon_loss_compute( @@ -34,14 +33,14 @@ def _ci_masked_recon_loss_compute( def ci_masked_recon_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_loss_update( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci, @@ -55,10 +54,10 @@ class CIMaskedReconLoss(Metric): metric_section: ClassVar[str] = "loss" def __init__( - self, model: ComponentModel, device: str, output_loss_type: Literal["mse", "kl"] + self, model: ComponentModel, device: str, reconstruction_loss: ReconstructionLoss ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -66,14 +65,14 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/ci_masked_recon_subset_loss.py b/spd/metrics/ci_masked_recon_subset_loss.py index 0a2e83441..461dfefb2 100644 --- a/spd/metrics/ci_masked_recon_subset_loss.py +++ b/spd/metrics/ci_masked_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,18 +7,18 @@ from spd.configs import SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.models.components import make_mask_infos from spd.routing import Router, get_subset_router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm def _ci_masked_recon_subset_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], router: Router, ) -> tuple[Float[Tensor, ""], int]: @@ -32,9 +32,8 @@ def _ci_masked_recon_subset_loss_update( weight_deltas_and_masks=None, ) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - return loss, out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) + return loss, count def _ci_masked_recon_subset_loss_compute( @@ -45,15 +44,15 @@ def _ci_masked_recon_subset_loss_compute( def ci_masked_recon_subset_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], routing: SubsetRoutingType, ) -> Float[Tensor, ""]: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=model, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch=batch, target_out=target_out, ci=ci, @@ -71,11 +70,11 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, routing: SubsetRoutingType, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.router = get_subset_router(routing, device) self.sum_loss = torch.tensor(0.0, device=device) @@ -85,14 +84,14 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, **_: Any, ) -> None: sum_loss, n_examples = _ci_masked_recon_subset_loss_update( model=self.model, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, batch=batch, target_out=target_out, ci=ci.lower_leaky, diff --git a/spd/metrics/pgd_masked_recon_layerwise_loss.py b/spd/metrics/pgd_masked_recon_layerwise_loss.py index 787cad8a2..17a861357 100644 --- a/spd/metrics/pgd_masked_recon_layerwise_loss.py +++ b/spd/metrics/pgd_masked_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -8,6 +8,7 @@ from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import LayerRouter from spd.utils.distributed_utils import all_reduce @@ -16,9 +17,9 @@ def _pgd_recon_layerwise_loss_update( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Any, + reconstruction_loss: ReconstructionLoss, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -33,7 +34,7 @@ def _pgd_recon_layerwise_loss_update( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, router=LayerRouter(device=device, layer_name=layer), pgd_config=pgd_config, ) @@ -45,9 +46,9 @@ def _pgd_recon_layerwise_loss_update( def pgd_recon_layerwise_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Any, + reconstruction_loss: ReconstructionLoss, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -56,7 +57,7 @@ def pgd_recon_layerwise_loss( model=model, batch=batch, target_out=target_out, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, ci=ci, weight_deltas=weight_deltas, pgd_config=pgd_config, @@ -73,14 +74,14 @@ class PGDReconLayerwiseLoss(Metric): def __init__( self, model: ComponentModel, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, pgd_config: PGDConfig, device: str, use_delta_component: bool, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.use_delta_component: bool = use_delta_component self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -89,8 +90,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -99,7 +100,7 @@ def update( model=self.model, batch=batch, target_out=target_out, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, pgd_config=self.pgd_config, diff --git a/spd/metrics/pgd_masked_recon_loss.py b/spd/metrics/pgd_masked_recon_loss.py index 7d35e149f..b01218ef4 100644 --- a/spd/metrics/pgd_masked_recon_loss.py +++ b/spd/metrics/pgd_masked_recon_loss.py @@ -1,13 +1,14 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.distributed_utils import all_reduce @@ -16,9 +17,9 @@ def pgd_recon_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Any, + reconstruction_loss: ReconstructionLoss, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -29,7 +30,7 @@ def pgd_recon_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, router=AllLayersRouter(), pgd_config=pgd_config, ) @@ -46,13 +47,13 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, pgd_config: PGDConfig, use_delta_component: bool, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.use_delta_component: bool = use_delta_component self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -61,8 +62,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, **_: Any, @@ -73,7 +74,7 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, router=AllLayersRouter(), pgd_config=self.pgd_config, ) diff --git a/spd/metrics/pgd_masked_recon_subset_loss.py b/spd/metrics/pgd_masked_recon_subset_loss.py index a9e8a7eac..2856fbb54 100644 --- a/spd/metrics/pgd_masked_recon_subset_loss.py +++ b/spd/metrics/pgd_masked_recon_subset_loss.py @@ -1,13 +1,14 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig, SubsetRoutingType from spd.metrics.base import Metric from spd.metrics.pgd_utils import pgd_masked_recon_loss_update +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import get_subset_router from spd.utils.distributed_utils import all_reduce @@ -16,9 +17,9 @@ def pgd_recon_subset_loss( *, model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + batch: Any, + target_out: Any, + reconstruction_loss: ReconstructionLoss, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, pgd_config: PGDConfig, @@ -30,7 +31,7 @@ def pgd_recon_subset_loss( ci=ci, weight_deltas=weight_deltas, target_out=target_out, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, router=get_subset_router(routing, batch.device), pgd_config=pgd_config, ) @@ -47,14 +48,14 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, use_delta_component: bool, pgd_config: PGDConfig, routing: SubsetRoutingType, ) -> None: self.model = model self.pgd_config: PGDConfig = pgd_config - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.use_delta_component: bool = use_delta_component self.router = get_subset_router(routing, device) @@ -65,8 +66,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -77,7 +78,7 @@ def update( ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, target_out=target_out, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, router=self.router, pgd_config=self.pgd_config, ) diff --git a/spd/metrics/pgd_utils.py b/spd/metrics/pgd_utils.py index 13aea9892..4d4d170ad 100644 --- a/spd/metrics/pgd_utils.py +++ b/spd/metrics/pgd_utils.py @@ -1,28 +1,28 @@ from collections.abc import Callable, Iterator from functools import partial -from typing import Literal +from typing import Any import torch -from jaxtyping import Float, Int +from jaxtyping import Float from torch import Tensor from torch.distributed import ReduceOp from spd.configs import PGDConfig, PGDInitStrategy, PGDMultiBatchConfig, SamplingType from spd.log import logger +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.models.components import RoutingMasks, make_mask_infos from spd.routing import Router from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, extract_batch_data def pgd_masked_recon_loss_update( model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: Any, + reconstruction_loss: ReconstructionLoss, router: Router, pgd_config: PGDConfig, ) -> tuple[Float[Tensor, ""], int]: @@ -57,7 +57,7 @@ def pgd_masked_recon_loss_update( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_out, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch_dims=batch_dims, ) @@ -79,10 +79,7 @@ def pgd_masked_recon_loss_update( return fwd_pass() -CreateDataIter = Callable[ - [], - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], -] +CreateDataIter = Callable[[], Iterator[Any]] def calc_multibatch_pgd_masked_recon_loss( @@ -90,11 +87,10 @@ def calc_multibatch_pgd_masked_recon_loss( model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, create_data_iter: CreateDataIter, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, router: Router, sampling: SamplingType, use_delta_component: bool, - batch_dims: tuple[int, ...], device: str, ) -> Float[Tensor, ""]: """PGD masked reconstruction loss with gradient accumulation over multiple batches. @@ -108,14 +104,15 @@ def calc_multibatch_pgd_masked_recon_loss( create_data_iter: Function to create an iterator over batches. This function should return an iterator which behaves identically each time. Specifically in terms of data ordering and shuffling. - output_loss_type: Loss type for reconstruction ("mse" or "kl") + reconstruction_loss: Reconstruction loss function router: Router to use for routing masks sampling: Sampling mode for causal importance calculation use_delta_component: Whether to include weight delta component - batch_dims: Dimensions of batch (e.g., (batch_size,) or (batch_size, seq_len)) Returns: Final reconstruction loss after PGD optimization """ + first_batch = next(create_data_iter()) + batch_dims = tuple(first_batch.shape[:-1]) if first_batch.dim() > 1 else (first_batch.shape[0],) singleton_batch_dims = [1 for _ in batch_dims] adv_sources: dict[str, Float[Tensor, "*ones mask_c"]] = {} @@ -134,10 +131,9 @@ def calc_multibatch_pgd_masked_recon_loss( model=model, weight_deltas=weight_deltas, device=device, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, sampling=sampling, router=router, - batch_dims=batch_dims, ) for _ in range(pgd_config.n_steps): @@ -155,13 +151,13 @@ def calc_multibatch_pgd_masked_recon_loss( def _forward_with_adv_sources( model: ComponentModel, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], + batch: Any, adv_sources: dict[str, Float[Tensor, "*batch_dim_or_ones mask_c"]], ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing_masks: RoutingMasks, - target_out: Float[Tensor, "... vocab"], - output_loss_type: Literal["mse", "kl"], + target_out: Any, + reconstruction_loss: ReconstructionLoss, batch_dims: tuple[int, ...], ): expanded_adv_sources = {k: v.expand(*batch_dims, -1) for k, v in adv_sources.items()} @@ -183,13 +179,9 @@ def _forward_with_adv_sources( ) out = model(batch, mask_infos=mask_infos) - sum_loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples = ( - target_out.shape.numel() if output_loss_type == "mse" else target_out.shape[:-1].numel() - ) + loss, count = reconstruction_loss(out, target_out) - return sum_loss, n_examples + return loss, count def _multibatch_pgd_fwd_bwd( @@ -197,13 +189,11 @@ def _multibatch_pgd_fwd_bwd( pgd_config: PGDMultiBatchConfig, model: ComponentModel, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, - data_iter: Iterator[Int[Tensor, "..."]] - | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], + data_iter: Iterator[Any], device: torch.device | str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, router: Router, sampling: SamplingType, - batch_dims: tuple[int, ...], ) -> tuple[Float[Tensor, ""], int, dict[str, Float[Tensor, "*ones mask_c"]]]: """Perform a forward and backward pass over multiple batches with gradient accumulation. @@ -218,11 +208,10 @@ def _multibatch_pgd_fwd_bwd( for microbatch_idx in range(pgd_config.gradient_accumulation_steps): try: - microbatch_item = next(data_iter) + microbatch = next(data_iter) except StopIteration: logger.warning(f"Dataloader exhausted after {microbatch_idx} batches, ending PGD step.") break - microbatch = extract_batch_data(microbatch_item).to(device) # NOTE: technically this is duplicated work across PGD steps, but that's the price we pay to # enable accumulating gradients over more microbatches than we'd be able to fit CI values in @@ -233,6 +222,8 @@ def _multibatch_pgd_fwd_bwd( sampling=sampling, ).lower_leaky + batch_dims = next(iter(ci.values())).shape[:-1] + # It's important that we call this every microbatch to ensure stochastic routing masks are # sampled independently for each example. routing_masks = router.get_masks( @@ -247,7 +238,7 @@ def _multibatch_pgd_fwd_bwd( weight_deltas=weight_deltas, routing_masks=routing_masks, target_out=target_model_output.output, - output_loss_type=output_loss_type, + reconstruction_loss=reconstruction_loss, batch_dims=batch_dims, ) diff --git a/spd/metrics/stochastic_recon_layerwise_loss.py b/spd/metrics/stochastic_recon_layerwise_loss.py index b14d57fe3..d3327bd95 100644 --- a/spd/metrics/stochastic_recon_layerwise_loss.py +++ b/spd/metrics/stochastic_recon_layerwise_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,22 +7,23 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_layerwise_loss_update( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -42,9 +43,8 @@ def _stochastic_recon_layerwise_loss_update( for stochastic_mask_infos in stochastic_mask_infos_list: for module_name, mask_info in stochastic_mask_infos.items(): out = model(batch, mask_infos={module_name: mask_info}) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - - n_examples += out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) + n_examples += count sum_loss += loss return sum_loss, n_examples @@ -59,21 +59,21 @@ def stochastic_recon_layerwise_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_layerwise_loss_update( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_layerwise_loss_compute(sum_loss, n_examples) @@ -90,13 +90,13 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -104,8 +104,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -114,11 +114,11 @@ def update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/stochastic_recon_loss.py b/spd/metrics/stochastic_recon_loss.py index 46cb0ad61..04bdd3643 100644 --- a/spd/metrics/stochastic_recon_loss.py +++ b/spd/metrics/stochastic_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,22 +7,23 @@ from spd.configs import SamplingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import AllLayersRouter from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_loss_update( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -40,10 +41,9 @@ def _stochastic_recon_loss_update( ] for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) sum_loss += loss + n_examples += count return sum_loss, n_examples @@ -57,21 +57,21 @@ def stochastic_recon_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_loss_update( model, sampling, n_mask_samples, - output_loss_type, batch, target_out, ci, weight_deltas, + reconstruction_loss, ) return _stochastic_recon_loss_compute(sum_loss, n_examples) @@ -88,13 +88,13 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model self.sampling: SamplingType = sampling - self.use_delta_component: bool = use_delta_component - self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.use_delta_component = use_delta_component + self.n_mask_samples = n_mask_samples + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -102,8 +102,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -112,11 +112,11 @@ def update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/stochastic_recon_subset_loss.py b/spd/metrics/stochastic_recon_subset_loss.py index 62573a889..2d6a88076 100644 --- a/spd/metrics/stochastic_recon_subset_loss.py +++ b/spd/metrics/stochastic_recon_subset_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -7,23 +7,24 @@ from spd.configs import SamplingType, SubsetRoutingType from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import CIOutputs, ComponentModel from spd.routing import Router, get_subset_router from spd.utils.component_utils import calc_stochastic_component_mask_info from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm, get_obj_device +from spd.utils.general_utils import get_obj_device def _stochastic_recon_subset_loss_update( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, router: Router, + reconstruction_loss: ReconstructionLoss, ) -> tuple[Float[Tensor, ""], int]: assert ci, "Empty ci" device = get_obj_device(ci) @@ -42,9 +43,8 @@ def _stochastic_recon_subset_loss_update( for stoch_mask_infos in stoch_mask_infos_list: out = model(batch, mask_infos=stoch_mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples += out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + loss, count = reconstruction_loss(out, target_out) + n_examples += count sum_loss += loss return sum_loss, n_examples @@ -60,23 +60,23 @@ def stochastic_recon_subset_loss( model: ComponentModel, sampling: SamplingType, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: dict[str, Float[Tensor, "... C"]], weight_deltas: dict[str, Float[Tensor, "d_out d_in"]] | None, routing: SubsetRoutingType, + reconstruction_loss: ReconstructionLoss, ) -> Float[Tensor, ""]: sum_loss, n_examples = _stochastic_recon_subset_loss_update( model=model, sampling=sampling, n_mask_samples=n_mask_samples, - output_loss_type=output_loss_type, batch=batch, target_out=target_out, ci=ci, weight_deltas=weight_deltas, router=get_subset_router(routing, batch.device), + reconstruction_loss=reconstruction_loss, ) return _stochastic_recon_subset_loss_compute(sum_loss, n_examples) @@ -93,14 +93,14 @@ def __init__( sampling: SamplingType, use_delta_component: bool, n_mask_samples: int, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, routing: SubsetRoutingType, ) -> None: self.model = model self.sampling: SamplingType = sampling self.use_delta_component: bool = use_delta_component self.n_mask_samples: int = n_mask_samples - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.router = get_subset_router(routing, device) self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -109,8 +109,8 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, ci: CIOutputs, weight_deltas: dict[str, Float[Tensor, "d_out d_in"]], **_: Any, @@ -119,12 +119,12 @@ def update( model=self.model, sampling=self.sampling, n_mask_samples=self.n_mask_samples, - output_loss_type=self.output_loss_type, batch=batch, target_out=target_out, ci=ci.lower_leaky, weight_deltas=weight_deltas if self.use_delta_component else None, router=self.router, + reconstruction_loss=self.reconstruction_loss, ) self.sum_loss += sum_loss self.n_examples += n_examples diff --git a/spd/metrics/unmasked_recon_loss.py b/spd/metrics/unmasked_recon_loss.py index 01cf67fe0..72cfeac57 100644 --- a/spd/metrics/unmasked_recon_loss.py +++ b/spd/metrics/unmasked_recon_loss.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Literal, override +from typing import Any, ClassVar, override import torch from jaxtyping import Float, Int @@ -6,17 +6,17 @@ from torch.distributed import ReduceOp from spd.metrics.base import Metric +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel from spd.models.components import make_mask_infos from spd.utils.distributed_utils import all_reduce -from spd.utils.general_utils import calc_sum_recon_loss_lm def _unmasked_recon_loss_update( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ) -> tuple[Float[Tensor, ""], int]: all_ones_mask_infos = make_mask_infos( # (C,) will broadcast to (B, S, C) @@ -26,9 +26,8 @@ def _unmasked_recon_loss_update( } ) out = model(batch, mask_infos=all_ones_mask_infos) - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=output_loss_type) - n_examples = out.shape.numel() if output_loss_type == "mse" else out.shape[:-1].numel() - return loss, n_examples + loss, count = reconstruction_loss(out, target_out) + return loss, count def _unmasked_recon_loss_compute( @@ -39,13 +38,13 @@ def _unmasked_recon_loss_compute( def unmasked_recon_loss( model: ComponentModel, - output_loss_type: Literal["mse", "kl"], - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + reconstruction_loss: ReconstructionLoss, + batch: Any, + target_out: Any, ) -> Float[Tensor, ""]: sum_loss, n_examples = _unmasked_recon_loss_update( model, - output_loss_type, + reconstruction_loss, batch, target_out, ) @@ -61,10 +60,10 @@ def __init__( self, model: ComponentModel, device: str, - output_loss_type: Literal["mse", "kl"], + reconstruction_loss: ReconstructionLoss, ) -> None: self.model = model - self.output_loss_type: Literal["mse", "kl"] = output_loss_type + self.reconstruction_loss = reconstruction_loss self.sum_loss = torch.tensor(0.0, device=device) self.n_examples = torch.tensor(0, device=device) @@ -72,13 +71,13 @@ def __init__( def update( self, *, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], + batch: Any, + target_out: Any, **_: Any, ) -> None: sum_loss, n_examples = _unmasked_recon_loss_update( model=self.model, - output_loss_type=self.output_loss_type, + reconstruction_loss=self.reconstruction_loss, batch=batch, target_out=target_out, ) diff --git a/spd/models/batch_and_loss_fns.py b/spd/models/batch_and_loss_fns.py new file mode 100644 index 000000000..f15e7e76d --- /dev/null +++ b/spd/models/batch_and_loss_fns.py @@ -0,0 +1,39 @@ +"""Reconstruction loss functions for different model types. + +These functions parameterize training for different target model architectures. +""" + +from typing import Any, Protocol + +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + + +class ReconstructionLoss(Protocol): + """Protocol for computing reconstruction loss between predictions and targets.""" + + def __call__(self, pred: Any, target: Any) -> tuple[Float[Tensor, ""], int]: ... + + +def recon_loss_mse( + pred: Float[Tensor, "... d"], + target: Float[Tensor, "... d"], +) -> tuple[Float[Tensor, ""], int]: + """MSE reconstruction loss. Returns (sum_of_squared_errors, n_elements).""" + assert pred.shape == target.shape + squared_errors = (pred - target) ** 2 + return squared_errors.sum(), pred.numel() + + +def recon_loss_kl( + pred: Float[Tensor, "... vocab"], + target: Float[Tensor, "... vocab"], +) -> tuple[Float[Tensor, ""], int]: + """KL divergence reconstruction loss for logits. Returns (sum_of_kl, n_positions).""" + assert pred.shape == target.shape + log_q = torch.log_softmax(pred, dim=-1) # log Q + p = torch.softmax(target, dim=-1) # P + kl_per_position = F.kl_div(log_q, p, reduction="none").sum(dim=-1) # P · (log P − log Q) + return kl_per_position.sum(), pred[..., 0].numel() diff --git a/spd/models/component_model.py b/spd/models/component_model.py index 3ea9384e4..9647d3e3e 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -1,4 +1,5 @@ import fnmatch +import re from collections.abc import Callable, Generator, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -104,7 +105,7 @@ def __init__( module_path_info: list[ModulePathInfo], ci_config: CiConfig, sigmoid_type: SigmoidType, - pretrained_model_output_attr: str | None, + extract_tensor_output: str | None = None, ): super().__init__() @@ -115,7 +116,7 @@ def __init__( ) self.target_model = target_model - self.pretrained_model_output_attr = pretrained_model_output_attr + self.extract_tensor_output = extract_tensor_output self.module_to_c = {info.module_path: info.C for info in module_path_info} self.target_module_paths = list(self.module_to_c.keys()) @@ -374,67 +375,57 @@ def _create_global_ci_fn( attn_config=ci_config.transition_attn_config, ) - def _extract_output(self, raw_output: Any) -> Tensor: + def _extract_output(self, raw_output: Any) -> Any: """Extract the desired output from the model's raw output. - If pretrained_model_output_attr is None, returns the raw output directly. - If pretrained_model_output_attr starts with "idx_", returns the index specified by the - second part of the string. E.g. "idx_0" returns the first element of the raw output. - Otherwise, returns the specified attribute from the raw output. - - Args: - raw_output: The raw output from the model. - - Returns: - The extracted output. + Uses the declarative accessor path in extract_tensor_output: + - None: returns the raw output directly + - ".logits": attribute access (getattr) + - "[0]": index access + - ".output[0]": chained attribute + index access """ - if self.pretrained_model_output_attr is None: - out = raw_output - elif self.pretrained_model_output_attr.startswith("idx_"): - idx_val = int(self.pretrained_model_output_attr.split("_")[1]) - assert isinstance(raw_output, Sequence), ( - f"raw_output must be a sequence, not {type(raw_output)}" - ) - assert idx_val < len(raw_output), ( - f"Index {idx_val} out of range for raw_output of length {len(raw_output)}" - ) - out = raw_output[idx_val] - else: - out = getattr(raw_output, self.pretrained_model_output_attr) - - assert isinstance(out, Tensor), f"Expected tensor output, got {type(out)}" - return out + if self.extract_tensor_output is None: + return raw_output + + result = raw_output + for step in re.findall(r"\.\w+|\[\d+\]", self.extract_tensor_output): + if step.startswith("."): + result = getattr(result, step[1:]) + elif step.startswith("["): + idx = int(step[1:-1]) + assert isinstance(result, Sequence), ( + f"Expected sequence for index access, got {type(result)}" + ) + result = result[idx] + return result @overload def __call__( self, - *args: Any, - mask_infos: dict[str, ComponentsMaskInfo] | None = None, + batch: Any, cache_type: Literal["component_acts", "input"], - **kwargs: Any, + mask_infos: dict[str, ComponentsMaskInfo] | None = None, ) -> OutputWithCache: ... @overload def __call__( self, - *args: Any, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["none"] = "none", - **kwargs: Any, - ) -> Tensor: ... + ) -> Any: ... @override - def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: + def __call__(self, *args: Any, **kwargs: Any) -> Any | OutputWithCache: return super().__call__(*args, **kwargs) @override def forward( self, - *args: Any, + batch: Any, mask_infos: dict[str, ComponentsMaskInfo] | None = None, cache_type: Literal["component_acts", "input", "none"] = "none", - **kwargs: Any, - ) -> Tensor | OutputWithCache: + ) -> Any | OutputWithCache: """Forward pass with optional component replacement and/or input caching. This method handles the following 4 cases: @@ -458,8 +449,7 @@ def forward( model output tensor. """ if mask_infos is None and cache_type == "none": - # No hooks needed. Do a regular forward pass of the target model. - return self._extract_output(self.target_model(*args, **kwargs)) + return self._extract_output(self.target_model(batch)) cache: dict[str, Tensor] = {} hooks: dict[str, Callable[..., Any]] = {} @@ -480,9 +470,8 @@ def forward( ) with self._attach_forward_hooks(hooks): - raw_out = self.target_model(*args, **kwargs) + out = self._extract_output(self.target_model(batch)) - out = self._extract_output(raw_out) match cache_type: case "input" | "component_acts": return OutputWithCache(output=out, cache=cache) @@ -611,8 +600,8 @@ def from_run_info(cls, run_info: RunInfo[Config]) -> "ComponentModel": target_model=target_model, module_path_info=module_path_info, ci_config=config.ci_config, + extract_tensor_output=config.extract_tensor_output, sigmoid_type=config.sigmoid_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, ) comp_model_weights = torch.load( diff --git a/spd/persistent_pgd.py b/spd/persistent_pgd.py index fa1ba4100..84e6d1f89 100644 --- a/spd/persistent_pgd.py +++ b/spd/persistent_pgd.py @@ -12,6 +12,7 @@ from typing import Any, ClassVar, Literal, override import torch +import torch.nn.functional as F from jaxtyping import Float, Int from torch import Tensor from torch.distributed import ReduceOp @@ -33,7 +34,6 @@ from spd.models.components import ComponentsMaskInfo, RoutingMasks, make_mask_infos from spd.routing import AllLayersRouter, Router, get_subset_router from spd.utils.distributed_utils import all_reduce, call_on_rank0_then_broadcast -from spd.utils.general_utils import calc_sum_recon_loss_lm PPGDSources = dict[str, Float[Tensor, " source_c"]] @@ -279,9 +279,15 @@ def _persistent_pgd_recon_subset_loss_update( mask_infos = get_mask_infos(model, ci, weight_deltas, ppgd_sources, router) out = model(batch, mask_infos=mask_infos) - loss_type = output_loss_type - loss = calc_sum_recon_loss_lm(pred=out, target=target_out, loss_type=loss_type) - n_examples = out.shape.numel() if loss_type == "mse" else out.shape[:-1].numel() + match output_loss_type: + case "mse": + loss = ((out - target_out) ** 2).sum() + n_examples = out.numel() + case "kl": + log_q = torch.log_softmax(out, dim=-1) + p = torch.softmax(target_out, dim=-1) + loss = F.kl_div(log_q, p, reduction="none").sum() + n_examples = out[..., 0].numel() return loss, n_examples diff --git a/spd/run_spd.py b/spd/run_spd.py index 1a272ae20..3a937b9fe 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -4,16 +4,14 @@ from collections import defaultdict from collections.abc import Iterator from pathlib import Path -from typing import cast +from typing import Any, cast import torch import torch.nn as nn import torch.nn.parallel -import torch.optim as optim import wandb -from jaxtyping import Float, Int from PIL import Image -from torch import Tensor +from torch import optim from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from tqdm import tqdm @@ -34,6 +32,7 @@ from spd.log import logger from spd.losses import compute_losses from spd.metrics import faithfulness_loss +from spd.models.batch_and_loss_fns import ReconstructionLoss, recon_loss_mse from spd.models.component_model import ComponentModel, OutputWithCache from spd.persistent_pgd import PersistentPGDState from spd.utils.component_utils import calc_ci_l_zero @@ -46,7 +45,6 @@ from spd.utils.general_utils import ( bf16_autocast, dict_safe_update_, - extract_batch_data, get_scheduled_value, ) from spd.utils.logging_utils import get_grad_norms_dict, local_log @@ -118,11 +116,9 @@ def optimize( target_model: nn.Module, config: Config, device: str, - train_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], - eval_loader: DataLoader[Int[Tensor, "..."]] - | DataLoader[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]], - n_eval_steps: int, + train_loader: DataLoader[Any], + eval_loader: DataLoader[Any], + reconstruction_loss: ReconstructionLoss, out_dir: Path | None, tied_weights: list[tuple[str, str]] | None = None, ) -> None: @@ -131,9 +127,7 @@ def optimize( train_iterator = loop_dataloader(train_loader) eval_iterator = loop_dataloader(eval_loader) - def create_pgd_data_iter() -> ( - Iterator[Int[Tensor, "..."]] | Iterator[tuple[Float[Tensor, "..."], Float[Tensor, "..."]]] - ): + def create_pgd_data_iter() -> Iterator[Any]: assert hasattr(train_loader, "generator") and train_loader.generator is not None train_loader.generator.manual_seed(config.seed) return iter(train_loader) @@ -155,8 +149,8 @@ def create_pgd_data_iter() -> ( target_model=target_model, module_path_info=module_path_info, ci_config=config.ci_config, + extract_tensor_output=config.extract_tensor_output, sigmoid_type=config.sigmoid_type, - pretrained_model_output_attr=config.pretrained_model_output_attr, ) model.to(device) @@ -164,6 +158,8 @@ def create_pgd_data_iter() -> ( # Wrap model with DDP if distributed dist_state = get_distributed_state() wrapped_model: nn.Module = model + + component_model: ComponentModel if dist_state is not None: if dist_state.backend == "nccl": device_id = dist_state.local_rank @@ -176,7 +172,7 @@ def create_pgd_data_iter() -> ( # For CPU, don't pass device_ids or output_device wrapped_model = torch.nn.parallel.DistributedDataParallel(model) # Access the underlying module for component operations - component_model = wrapped_model.module # type: ignore[attr-defined] + component_model = cast(ComponentModel, wrapped_model.module) # type: ignore[attr-defined] else: component_model = model assert isinstance(component_model, ComponentModel), "component_model is not a ComponentModel" @@ -236,27 +232,29 @@ def create_pgd_data_iter() -> ( if not isinstance(cfg, PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig) ] - sample_batch = extract_batch_data(next(train_iterator)) - batch_dims = ( - sample_batch.shape[:-1] - if config.output_loss_type == "mse" # if mse then input is a vector - else sample_batch.shape # else it's a batch of token ids - ) - # Initialize PersistentPGD states if needed ppgd_states: dict[ PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig, PersistentPGDState - ] = { - ppgd_cfg: PersistentPGDState( - module_to_c=model.module_to_c, - seq_len=batch_dims[-1], - device=device, - use_delta_component=config.use_delta_component, - cfg=ppgd_cfg, - batch_size=batch_dims[0], + ] = {} + if persistent_pgd_configs: + sample_batch_raw = next(train_iterator) + sample_batch = ( + sample_batch_raw[0] if isinstance(sample_batch_raw, tuple) else sample_batch_raw ) - for ppgd_cfg in persistent_pgd_configs - } + batch_dims = ( + sample_batch.shape[:-1] if reconstruction_loss is recon_loss_mse else sample_batch.shape + ) + ppgd_states = { + ppgd_cfg: PersistentPGDState( + module_to_c=model.module_to_c, + seq_len=batch_dims[-1], + device=device, + use_delta_component=config.use_delta_component, + cfg=ppgd_cfg, + batch_size=batch_dims[0], + ) + for ppgd_cfg in persistent_pgd_configs + } for step in tqdm(range(config.steps + 1), ncols=0, disable=not is_main_process()): optimizer.zero_grad() @@ -276,7 +274,8 @@ def create_pgd_data_iter() -> ( } for _ in range(config.gradient_accumulation_steps): - microbatch = extract_batch_data(next(train_iterator)).to(device, non_blocking=True) + microbatch_raw = next(train_iterator) + microbatch = microbatch_raw[0] if isinstance(microbatch_raw, tuple) else microbatch_raw with bf16_autocast(enabled=config.autocast_bf16): # NOTE: we need to call the wrapped_model at least once each step in order @@ -307,7 +306,7 @@ def create_pgd_data_iter() -> ( cfg: ppgd_states[cfg].get_effective_sources() for cfg in persistent_pgd_configs }, - output_loss_type=config.output_loss_type, + reconstruction_loss=reconstruction_loss, ) # Compute total loss and accumulate PPGD grads @@ -379,14 +378,13 @@ def create_pgd_data_iter() -> ( else step % config.slow_eval_freq == 0 ) - assert batch_dims is not None, "batch_dims is not set" multibatch_pgd_metrics = evaluate_multibatch_pgd( multibatch_pgd_eval_configs=multibatch_pgd_eval_configs, model=component_model, create_data_iter=create_pgd_data_iter, config=config, - batch_dims=batch_dims, device=device, + reconstruction_loss=reconstruction_loss, ) metrics = evaluate( @@ -400,8 +398,9 @@ def create_pgd_data_iter() -> ( device=device, run_config=config, slow_step=slow_step, - n_eval_steps=n_eval_steps, + n_eval_steps=config.n_eval_steps, current_frac_of_training=step / config.steps, + reconstruction_loss=reconstruction_loss, ) dict_safe_update_(metrics, multibatch_pgd_metrics) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 4bfee9c9a..115effc6c 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -25,7 +25,7 @@ from spd.log import logger from spd.models.component_model import ComponentModel, SPDRunInfo from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data, get_obj_device +from spd.utils.general_utils import get_obj_device from spd.utils.run_utils import save_file @@ -250,7 +250,7 @@ def compute_activation_densities( model.eval() with torch.no_grad(): for _step in range(n_steps): - batch = extract_batch_data(next(eval_iterator)) + batch = next(eval_iterator) batch = batch.to(self.device) pre_weight_acts = model(batch, cache_type="input").cache diff --git a/spd/utils/general_utils.py b/spd/utils/general_utils.py index 1fda48111..bb4e846f1 100644 --- a/spd/utils/general_utils.py +++ b/spd/utils/general_utils.py @@ -2,7 +2,7 @@ import random from collections.abc import Sequence from pathlib import Path -from typing import Any, Literal, Protocol +from typing import Any, Protocol import einops import numpy as np @@ -167,56 +167,15 @@ def resolve_class(path: str) -> type[nn.Module]: return getattr(module, class_name) -def extract_batch_data( - batch_item: dict[str, Any] | tuple[Tensor, "..."] | Tensor, - input_key: str = "input_ids", -) -> Tensor: - """Extract input data from various batch formats. - - This utility function handles different batch formats commonly used across the codebase: - 1. Dictionary format: {"input_ids": Tensor, "..."} - common in LM tasks - 2. Tuple format: (input_tensor, labels) - common in SPD optimization - 3. Direct tensor: when batch is already the input tensor - - Args: - batch_item: The batch item from a data loader - input_key: Key to use for dictionary format (default: "input_ids") - - Returns: - The input tensor extracted from the batch - """ - assert isinstance(batch_item, dict | tuple | Tensor), ( - f"Unsupported batch format: {type(batch_item)}. Must be a dictionary, tuple, or tensor." - ) - if isinstance(batch_item, dict): - # Dictionary format: extract the specified key - if input_key not in batch_item: - available_keys = list(batch_item.keys()) - raise KeyError( - f"Key '{input_key}' not found in batch. Available keys: {available_keys}" - ) - tensor = batch_item[input_key] - elif isinstance(batch_item, tuple): - # Assume input is the first element - tensor = batch_item[0] - else: - # Direct tensor format - tensor = batch_item - - return tensor - - def calc_kl_divergence_lm( pred: Float[Tensor, "... vocab"], target: Float[Tensor, "... vocab"], - reduce: bool = True, ) -> Float[Tensor, ""] | Float[Tensor, "..."]: """Calculate the KL divergence between two logits. Args: pred: The predicted logits target: The target logits - reduce: Whether to reduce the KL divergence across the batch and sequence dimensions Returns: The KL divergence @@ -226,24 +185,7 @@ def calc_kl_divergence_lm( p = torch.softmax(target, dim=-1) # P kl_raw = F.kl_div(log_q, p, reduction="none") # P · (log P − log Q) kl = kl_raw.sum(dim=-1) - if reduce: - return kl.mean() # Σ_vocab / (batch·seq) - else: - return kl - - -def calc_sum_recon_loss_lm( - pred: Float[Tensor, "... vocab"], - target: Float[Tensor, "... vocab"], - loss_type: Literal["mse", "kl"], -) -> Float[Tensor, ""]: - """Calculate the reconstruction loss for a language model without reduction.""" - match loss_type: - case "mse": - loss = ((pred - target) ** 2).sum() - case "kl": - loss = calc_kl_divergence_lm(pred=pred, target=target, reduce=False).sum() - return loss + return kl.mean() # Σ_vocab / (batch·seq) def runtime_cast[T](type_: type[T], obj: Any) -> T: diff --git a/tests/app/test_server_api.py b/tests/app/test_server_api.py index fcb06c099..50f982485 100644 --- a/tests/app/test_server_api.py +++ b/tests/app/test_server_api.py @@ -99,9 +99,8 @@ def app_with_state(): ModulePatternInfoConfig(module_pattern=p, C=C) for p in target_module_patterns ], pretrained_model_class="spd.pretrain.models.gpt2_simple.GPT2Simple", - pretrained_model_output_attr="idx_0", + extract_tensor_output="[0]", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - output_loss_type="kl", lr_schedule=ScheduleConfig(start_val=1e-3), steps=1, batch_size=1, @@ -124,7 +123,7 @@ def app_with_state(): target_model=target_model, module_path_info=module_path_info, ci_config=config.ci_config, - pretrained_model_output_attr=config.pretrained_model_output_attr, + extract_tensor_output=config.extract_tensor_output, sigmoid_type=config.sigmoid_type, ) model.eval() diff --git a/tests/metrics/fixtures.py b/tests/metrics/fixtures.py index ce594ed76..5b4b3632b 100644 --- a/tests/metrics/fixtures.py +++ b/tests/metrics/fixtures.py @@ -58,7 +58,7 @@ def make_one_layer_component_model(weight: Float[Tensor, "d_out d_in"]) -> Compo target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -94,7 +94,7 @@ def make_two_layer_component_model( ModulePathInfo(module_path="fc2", C=1), ], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) diff --git a/tests/metrics/test_ci_masked_recon_layerwise_loss.py b/tests/metrics/test_ci_masked_recon_layerwise_loss.py index 00b8092b2..90458b231 100644 --- a/tests/metrics/test_ci_masked_recon_layerwise_loss.py +++ b/tests/metrics/test_ci_masked_recon_layerwise_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_layerwise_loss, ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -43,7 +44,11 @@ def test_two_layer_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -60,10 +65,18 @@ def test_layerwise_vs_all_layer(self: object) -> None: ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) # For single layer, results should be the same diff --git a/tests/metrics/test_ci_masked_recon_loss.py b/tests/metrics/test_ci_masked_recon_loss.py index 3f1202425..38ac7ed24 100644 --- a/tests/metrics/test_ci_masked_recon_loss.py +++ b/tests/metrics/test_ci_masked_recon_loss.py @@ -1,6 +1,7 @@ import torch from spd.metrics import ci_masked_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -26,7 +27,11 @@ def test_manual_calculation(self: object) -> None: # Calculate actual loss actual_loss = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) assert torch.allclose(actual_loss, expected_loss, rtol=1e-5), ( @@ -45,10 +50,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci_full, ) loss_half = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci_half, ) # Different CI values should produce different losses diff --git a/tests/metrics/test_ci_masked_recon_subset_loss.py b/tests/metrics/test_ci_masked_recon_subset_loss.py index 4a9f870b7..39ce1099e 100644 --- a/tests/metrics/test_ci_masked_recon_subset_loss.py +++ b/tests/metrics/test_ci_masked_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import UniformKSubsetRoutingConfig from spd.metrics import ci_masked_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from tests.metrics.fixtures import make_one_layer_component_model @@ -77,7 +78,7 @@ def mock_sample_uniform_k_subset_routing_masks( for _ in range(2): actual_loss = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_stochastic_recon_layerwise_loss.py b/tests/metrics/test_stochastic_recon_layerwise_loss.py index 3862d85f8..6a2193d7e 100644 --- a/tests/metrics/test_stochastic_recon_layerwise_loss.py +++ b/tests/metrics/test_stochastic_recon_layerwise_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_layerwise_loss, stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model, make_two_layer_component_model @@ -105,7 +106,7 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -130,7 +131,7 @@ def test_layerwise_vs_full_loss(self: object) -> None: model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -140,7 +141,7 @@ def test_layerwise_vs_full_loss(self: object) -> None: model=model, sampling="continuous", n_mask_samples=5, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_stochastic_recon_loss.py b/tests/metrics/test_stochastic_recon_loss.py index 594b55a7f..377c4018d 100644 --- a/tests/metrics/test_stochastic_recon_loss.py +++ b/tests/metrics/test_stochastic_recon_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType from spd.metrics import stochastic_recon_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -78,7 +79,7 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, diff --git a/tests/metrics/test_stochastic_recon_subset_loss.py b/tests/metrics/test_stochastic_recon_subset_loss.py index 484e3d49f..09655c384 100644 --- a/tests/metrics/test_stochastic_recon_subset_loss.py +++ b/tests/metrics/test_stochastic_recon_subset_loss.py @@ -5,6 +5,7 @@ from spd.configs import SamplingType, UniformKSubsetRoutingConfig from spd.metrics import stochastic_recon_subset_loss +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.models.components import ComponentsMaskInfo, make_mask_infos from spd.routing import Router from tests.metrics.fixtures import make_one_layer_component_model @@ -92,7 +93,7 @@ def mock_calc_stochastic_component_mask_info( model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, diff --git a/tests/test_component_model.py b/tests/test_component_model.py index 6835f8a1d..e5cb27a79 100644 --- a/tests/test_component_model.py +++ b/tests/test_component_model.py @@ -96,7 +96,7 @@ def test_correct_parameters_require_grad(): ModulePathInfo(module_path="conv1d2", C=5), ], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -156,7 +156,6 @@ def test_from_run_info(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -177,7 +176,7 @@ def test_from_run_info(): target_model=target_model, module_path_info=module_path_info, ci_config=config.ci_config, - pretrained_model_output_attr=config.pretrained_model_output_attr, + extract_tensor_output=config.extract_tensor_output, sigmoid_type=config.sigmoid_type, ) @@ -283,7 +282,7 @@ def test_full_weight_delta_matches_target_behaviour(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[4]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -314,7 +313,7 @@ def test_input_cache_captures_pre_weight_input(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=2) for p in target_module_paths], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -348,7 +347,7 @@ def test_weight_deltas(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=3) for p in target_module_paths], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -382,7 +381,7 @@ def forward(self, x: Tensor) -> Tensor: target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -437,7 +436,7 @@ def forward(self, x: Tensor) -> Tensor: target_model=model, module_path_info=[ModulePathInfo(module_path="linear.pre_identity", C=C)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -486,7 +485,7 @@ def forward(self, x: Tensor) -> Tensor: target_model=model, module_path_info=[ModulePathInfo(module_path="linear", C=C)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -557,7 +556,6 @@ def test_checkpoint_ci_config_mismatch_global_to_layerwise(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -572,7 +570,7 @@ def test_checkpoint_ci_config_mismatch_global_to_layerwise(): target_model=target_model, module_path_info=module_path_info, ci_config=config_global.ci_config, - pretrained_model_output_attr=config_global.pretrained_model_output_attr, + extract_tensor_output=config_global.extract_tensor_output, sigmoid_type=config_global.sigmoid_type, ) @@ -599,7 +597,6 @@ def test_checkpoint_ci_config_mismatch_global_to_layerwise(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -657,7 +654,6 @@ def test_checkpoint_ci_config_mismatch_layerwise_to_global(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -672,7 +668,7 @@ def test_checkpoint_ci_config_mismatch_layerwise_to_global(): target_model=target_model, module_path_info=module_path_info, ci_config=config_layerwise.ci_config, - pretrained_model_output_attr=config_layerwise.pretrained_model_output_attr, + extract_tensor_output=config_layerwise.extract_tensor_output, sigmoid_type=config_layerwise.sigmoid_type, ) @@ -699,7 +695,6 @@ def test_checkpoint_ci_config_mismatch_layerwise_to_global(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -955,7 +950,7 @@ def test_component_model_with_global_ci(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -980,7 +975,7 @@ def test_component_model_global_ci_calc_causal_importances(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1024,7 +1019,7 @@ def test_component_model_global_ci_different_inputs_different_ci(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1055,7 +1050,7 @@ def test_component_model_global_ci_binomial_sampling(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1080,7 +1075,7 @@ def test_component_model_global_ci_with_embeddings(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1116,7 +1111,7 @@ def test_component_model_global_ci_gradient_flow(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1153,7 +1148,7 @@ def test_component_model_global_ci_detach_inputs_blocks_gradients(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1190,7 +1185,7 @@ def test_component_model_global_ci_masking_zeros(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1236,7 +1231,7 @@ def test_component_model_global_ci_partial_masking(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1268,7 +1263,7 @@ def test_component_model_global_ci_weight_deltas_all_ones_matches_target(): target_model=target_model, module_path_info=[ModulePathInfo(module_path=p, C=C) for p in target_module_paths], ci_config=GlobalCiConfig(fn_type="global_shared_mlp", hidden_dims=[16]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -1320,7 +1315,6 @@ def test_global_ci_save_and_load(): eval_freq=1, slow_eval_freq=1, loss_metric_configs=[ImportanceMinimalityLossConfig(coeff=1.0, pnorm=1.0, beta=0.5)], - output_loss_type="mse", train_log_freq=1, n_mask_samples=1, task_config=TMSTaskConfig( @@ -1335,7 +1329,7 @@ def test_global_ci_save_and_load(): target_model=target_model, module_path_info=module_path_info, ci_config=config.ci_config, - pretrained_model_output_attr=config.pretrained_model_output_attr, + extract_tensor_output=config.extract_tensor_output, sigmoid_type=config.sigmoid_type, ) diff --git a/tests/test_distributed.py b/tests/test_distributed.py index 2419b8647..6d6fda187 100644 --- a/tests/test_distributed.py +++ b/tests/test_distributed.py @@ -37,7 +37,6 @@ {"classname": "CIMaskedReconLayerwiseLoss", "coeff": 1.0}, {"classname": "CIMaskedReconLoss", "coeff": 1.0}, ], - "output_loss_type": "kl", # --- Training --- "batch_size": 2, "steps": 20, @@ -57,7 +56,7 @@ # --- Pretrained model info --- "pretrained_model_class": "transformers.LlamaForCausalLM", "pretrained_model_name": "SimpleStories/SimpleStories-1.25M", - "pretrained_model_output_attr": "logits", + "extract_tensor_output": ".logits", "tokenizer_name": "SimpleStories/SimpleStories-1.25M", # --- Task Specific --- "task_config": { diff --git a/tests/test_gpt2.py b/tests/test_gpt2.py index 33ce346dc..6bc17a6e7 100644 --- a/tests/test_gpt2.py +++ b/tests/test_gpt2.py @@ -15,8 +15,9 @@ StochasticReconLayerwiseLossConfig, StochasticReconLossConfig, ) -from spd.data import DatasetConfig, create_data_loader +from spd.data import DatasetConfig, create_data_loader, lm_collate_fn from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.general_utils import resolve_class, set_seed @@ -55,7 +56,6 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -78,7 +78,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="transformers.GPT2LMHeadModel", pretrained_model_path=None, pretrained_model_name="SimpleStories/test-SimpleStories-gpt2-1.25M", - pretrained_model_output_attr="logits", + extract_tensor_output=".logits", tokenizer_name="SimpleStories/test-SimpleStories-gpt2-1.25M", # Task Specific task_config=LMTaskConfig( @@ -123,6 +123,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed, + collate_fn=lm_collate_fn, ) eval_data_config = DatasetConfig( @@ -140,6 +141,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: batch_size=config.batch_size, buffer_size=config.task_config.buffer_size, global_seed=config.seed + 1, + collate_fn=lm_collate_fn, ) # Run optimize function @@ -149,7 +151,7 @@ def test_gpt_2_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 5e4f19ea9..43fa20657 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -18,6 +18,7 @@ from spd.experiments.ih.configs import InductionModelConfig from spd.experiments.ih.model import InductionTransformer from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_kl from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset from spd.utils.general_utils import set_seed @@ -71,7 +72,6 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=200), ], - output_loss_type="kl", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -95,7 +95,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.ih.model.InductionTransformer", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=IHTaskConfig( @@ -133,7 +133,7 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_kl, out_dir=tmp_path, ) diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index ead150684..bd1a06aef 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -14,6 +14,7 @@ from spd.experiments.resid_mlp.models import ResidMLP from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader from spd.utils.general_utils import set_seed @@ -62,7 +63,6 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: identity_module_info=[ ModulePatternInfoConfig(module_pattern="layers.*.mlp_in", C=10), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.01, final_val_frac=0.0 @@ -82,7 +82,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.resid_mlp.models.ResidMLP", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=ResidMLPTaskConfig( @@ -129,7 +129,7 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, ) diff --git a/tests/test_spd_losses.py b/tests/test_spd_losses.py index 3fa12cb53..49ac8f5fc 100644 --- a/tests/test_spd_losses.py +++ b/tests/test_spd_losses.py @@ -23,6 +23,7 @@ stochastic_recon_loss, stochastic_recon_subset_loss, ) +from spd.models.batch_and_loss_fns import recon_loss_kl, recon_loss_mse from spd.models.component_model import ComponentModel from spd.persistent_pgd import PersistentPGDState, persistent_pgd_recon_loss from spd.utils.module_utils import ModulePathInfo @@ -66,7 +67,7 @@ def _make_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentModel target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -85,7 +86,7 @@ def _make_seq_component_model(weight: Float[Tensor, "d_out d_in"]) -> ComponentM target_model=target, module_path_info=[ModulePathInfo(module_path="fc", C=1)], ci_config=LayerwiseCiConfig(fn_type="mlp", hidden_dims=[2]), - pretrained_model_output_attr=None, + extract_tensor_output=None, sigmoid_type="leaky_hard", ) @@ -322,7 +323,7 @@ def test_mse_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -347,7 +348,7 @@ def test_kl_loss_basic(self: object) -> None: result = ci_masked_recon_loss( model=model, - output_loss_type="kl", + reconstruction_loss=recon_loss_kl, batch=batch, target_out=target_out, ci=ci, @@ -367,10 +368,18 @@ def test_different_ci_values_produce_different_losses(self: object) -> None: ci_half = {"fc": torch.tensor([[0.5]], dtype=torch.float32)} loss_full = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_full + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci_full, ) loss_half = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci_half + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci_half, ) # Different CI values should produce different losses @@ -389,7 +398,7 @@ def test_layerwise_basic(self: object) -> None: result = ci_masked_recon_layerwise_loss( model=model, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -409,10 +418,18 @@ def test_layerwise_vs_all_layer(self: object) -> None: ci = {"fc": torch.tensor([[1.0]], dtype=torch.float32)} loss_all = ci_masked_recon_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) loss_layerwise = ci_masked_recon_layerwise_loss( - model=model, output_loss_type="mse", batch=batch, target_out=target_out, ci=ci + model=model, + reconstruction_loss=recon_loss_mse, + batch=batch, + target_out=target_out, + ci=ci, ) # For single layer, results should be the same @@ -431,7 +448,7 @@ def test_subset_basic(self: object) -> None: result = ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -454,7 +471,7 @@ def test_subset_stochastic_behavior(self: object) -> None: losses = [ ci_masked_recon_subset_loss( model=model, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -482,7 +499,7 @@ def test_continuous_sampling_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -505,7 +522,7 @@ def test_binomial_sampling_basic(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -530,7 +547,7 @@ def test_multiple_mask_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -552,7 +569,7 @@ def test_with_and_without_delta_component(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -563,7 +580,7 @@ def test_with_and_without_delta_component(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -590,7 +607,7 @@ def test_layerwise_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -614,7 +631,7 @@ def test_layerwise_multiple_samples(self: object) -> None: model=model, sampling="continuous", n_mask_samples=n_samples, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -638,7 +655,7 @@ def test_subset_stochastic_basic(self: object) -> None: model=model, sampling="continuous", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -662,7 +679,7 @@ def test_subset_with_binomial_sampling(self: object) -> None: model=model, sampling="binomial", n_mask_samples=3, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, @@ -687,7 +704,7 @@ def test_subset_stochastic_variability(self: object) -> None: model=model, sampling="continuous", n_mask_samples=2, - output_loss_type="mse", + reconstruction_loss=recon_loss_mse, batch=batch, target_out=target_out, ci=ci, diff --git a/tests/test_tms.py b/tests/test_tms.py index 496451f94..019714cdd 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -19,6 +19,7 @@ from spd.experiments.tms.models import TMSModel from spd.experiments.tms.train_tms import get_model_and_dataloader, train from spd.identity_insertion import insert_identity_operations_ +from spd.models.batch_and_loss_fns import recon_loss_mse from spd.run_spd import optimize from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset from spd.utils.general_utils import set_seed @@ -68,7 +69,6 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: StochasticReconLossConfig(coeff=1.0), FaithfulnessLossConfig(coeff=1.0), ], - output_loss_type="mse", # Training lr_schedule=ScheduleConfig( start_val=1e-3, fn_type="cosine", warmup_pct=0.0, final_val_frac=0.0 @@ -91,7 +91,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: pretrained_model_class="spd.experiments.tms.models.TMSModel", pretrained_model_path=None, pretrained_model_name=None, - pretrained_model_output_attr=None, + extract_tensor_output=None, tokenizer_name=None, # Task Specific task_config=TMSTaskConfig( @@ -137,7 +137,7 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: device=device, train_loader=train_loader, eval_loader=eval_loader, - n_eval_steps=config.n_eval_steps, + reconstruction_loss=recon_loss_mse, out_dir=tmp_path, tied_weights=tied_weights, ) From af975707c6d3e8a06e921f002717d6ff1b0eba34 Mon Sep 17 00:00:00 2001 From: Claude SPD1 Date: Wed, 11 Feb 2026 22:19:12 +0000 Subject: [PATCH 2/2] wip: Remove batch extraction logic, assert tensor properties instead --- spd/eval.py | 3 +-- spd/experiments/ih/ih_decomposition.py | 9 ++++++-- .../resid_mlp/resid_mlp_decomposition.py | 5 ++-- spd/experiments/tms/tms_decomposition.py | 5 ++-- spd/run_spd.py | 23 +++++++++---------- spd/scripts/compare_models/compare_models.py | 6 +++++ spd/utils/data_utils.py | 12 ++++++---- tests/test_ih_transformer.py | 5 ++-- tests/test_resid_mlp.py | 5 ++-- tests/test_tms.py | 5 ++-- 10 files changed, 47 insertions(+), 31 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index c0b472658..cf23bb935 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -345,8 +345,7 @@ def evaluate( weight_deltas = model.calc_weight_deltas() for _ in range(n_eval_steps): - batch_raw = next(eval_iterator) - batch = batch_raw[0] if isinstance(batch_raw, tuple) else batch_raw + batch: Any = next(eval_iterator) target_output: OutputWithCache = model(batch, cache_type="input") ci = model.calc_causal_importances( diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 847e28773..cc7a44040 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -91,8 +91,13 @@ def main( prefix_window=prefix_window, device=device, ) - train_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) - eval_loader = DatasetGeneratedDataLoader(dataset, batch_size=config.batch_size, shuffle=False) + extract_input = lambda batch: batch[0] + train_loader = DatasetGeneratedDataLoader( + dataset, batch_size=config.batch_size, shuffle=False, transform=extract_input + ) + eval_loader = DatasetGeneratedDataLoader( + dataset, batch_size=config.batch_size, shuffle=False, transform=extract_input + ) optimize( target_model=target_model, diff --git a/spd/experiments/resid_mlp/resid_mlp_decomposition.py b/spd/experiments/resid_mlp/resid_mlp_decomposition.py index e24f6c63d..4f981d6de 100644 --- a/spd/experiments/resid_mlp/resid_mlp_decomposition.py +++ b/spd/experiments/resid_mlp/resid_mlp_decomposition.py @@ -94,11 +94,12 @@ def main( synced_inputs=synced_inputs, ) + extract_input = lambda batch: batch[0] train_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) eval_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.eval_batch_size, shuffle=False + dataset, batch_size=config.eval_batch_size, shuffle=False, transform=extract_input ) optimize( diff --git a/spd/experiments/tms/tms_decomposition.py b/spd/experiments/tms/tms_decomposition.py index 2471c1830..3306b5725 100644 --- a/spd/experiments/tms/tms_decomposition.py +++ b/spd/experiments/tms/tms_decomposition.py @@ -88,11 +88,12 @@ def main( value_range=(0.0, 1.0), synced_inputs=synced_inputs, ) + extract_input = lambda batch: batch[0] train_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) eval_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.eval_batch_size, shuffle=False + dataset, batch_size=config.eval_batch_size, shuffle=False, transform=extract_input ) tied_weights = None diff --git a/spd/run_spd.py b/spd/run_spd.py index 3a937b9fe..2e781488c 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -11,7 +11,7 @@ import torch.nn.parallel import wandb from PIL import Image -from torch import optim +from torch import Tensor, optim from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from tqdm import tqdm @@ -32,7 +32,7 @@ from spd.log import logger from spd.losses import compute_losses from spd.metrics import faithfulness_loss -from spd.models.batch_and_loss_fns import ReconstructionLoss, recon_loss_mse +from spd.models.batch_and_loss_fns import ReconstructionLoss from spd.models.component_model import ComponentModel, OutputWithCache from spd.persistent_pgd import PersistentPGDState from spd.utils.component_utils import calc_ci_l_zero @@ -238,20 +238,20 @@ def create_pgd_data_iter() -> Iterator[Any]: ] = {} if persistent_pgd_configs: sample_batch_raw = next(train_iterator) - sample_batch = ( - sample_batch_raw[0] if isinstance(sample_batch_raw, tuple) else sample_batch_raw - ) - batch_dims = ( - sample_batch.shape[:-1] if reconstruction_loss is recon_loss_mse else sample_batch.shape - ) + assert isinstance(sample_batch_raw, Tensor) + assert sample_batch_raw.dtype == torch.long, "sample_batch_raw must be a long tensor" + assert sample_batch_raw.dim() == 2, "sample_batch_raw must be a 2D (batch, seq) tensor" + + batch_size, seq_len = sample_batch_raw.shape + ppgd_states = { ppgd_cfg: PersistentPGDState( module_to_c=model.module_to_c, - seq_len=batch_dims[-1], + seq_len=seq_len, device=device, use_delta_component=config.use_delta_component, cfg=ppgd_cfg, - batch_size=batch_dims[0], + batch_size=batch_size, ) for ppgd_cfg in persistent_pgd_configs } @@ -274,8 +274,7 @@ def create_pgd_data_iter() -> Iterator[Any]: } for _ in range(config.gradient_accumulation_steps): - microbatch_raw = next(train_iterator) - microbatch = microbatch_raw[0] if isinstance(microbatch_raw, tuple) else microbatch_raw + microbatch: Any = next(train_iterator) with bf16_autocast(enabled=config.autocast_bf16): # NOTE: we need to call the wrapped_model at least once each step in order diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index 115effc6c..f4d3b39ff 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -130,11 +130,13 @@ def _create_tms_data_loader(self) -> Iterator[Any]: value_range=(0.0, 1.0), synced_inputs=target_run_info.config.synced_inputs, ) + extract_input = lambda batch: batch[0] return iter( DatasetGeneratedDataLoader( dataset, batch_size=self.config.eval_batch_size, shuffle=self.config.shuffle_data, + transform=extract_input, ) ) @@ -164,11 +166,13 @@ def _create_resid_mlp_data_loader(self) -> Iterator[Any]: label_fn_seed=None, synced_inputs=target_run_info.config.synced_inputs, ) + extract_input = lambda batch: batch[0] return iter( DatasetGeneratedDataLoader( dataset, batch_size=self.config.eval_batch_size, shuffle=self.config.shuffle_data, + transform=extract_input, ) ) @@ -224,11 +228,13 @@ def _create_ih_data_loader(self) -> Iterator[Any]: or target_run_info.config.ih_model_config.seq_len - 3, device=self.device, ) + extract_input = lambda batch: batch[0] return iter( DatasetGeneratedDataLoader( dataset, batch_size=self.config.eval_batch_size, shuffle=self.config.shuffle_data, + transform=extract_input, ) ) diff --git a/spd/utils/data_utils.py b/spd/utils/data_utils.py index 90666d4f2..f7ef0052d 100644 --- a/spd/utils/data_utils.py +++ b/spd/utils/data_utils.py @@ -1,5 +1,5 @@ -from collections.abc import Iterator -from typing import Literal, override +from collections.abc import Callable, Iterator +from typing import Any, Literal, override import torch from jaxtyping import Float @@ -16,17 +16,19 @@ def __init__( batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, + transform: Callable[[Q], Any] | None = None, ): - # assert that dataset has a generate_batch method assert hasattr(dataset, "generate_batch") super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers) + self._transform = transform @override def __iter__( # pyright: ignore[reportIncompatibleMethodOverride] self, - ) -> Iterator[Q]: + ) -> Iterator[Any]: for _ in range(len(self)): - yield self.dataset.generate_batch(self.batch_size) # pyright: ignore[reportAttributeAccessIssue] + batch = self.dataset.generate_batch(self.batch_size) # pyright: ignore[reportAttributeAccessIssue] + yield self._transform(batch) if self._transform else batch class BatchedDataLoader[Q](DataLoader[Q]): diff --git a/tests/test_ih_transformer.py b/tests/test_ih_transformer.py index 43fa20657..04df3db9d 100644 --- a/tests/test_ih_transformer.py +++ b/tests/test_ih_transformer.py @@ -119,11 +119,12 @@ def test_ih_transformer_decomposition_happy_path(tmp_path: Path) -> None: prefix_window=ih_transformer_config.seq_len - 3, ) + extract_input = lambda batch: batch[0] train_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) eval_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) # Run optimize function diff --git a/tests/test_resid_mlp.py b/tests/test_resid_mlp.py index bd1a06aef..77a7ab46a 100644 --- a/tests/test_resid_mlp.py +++ b/tests/test_resid_mlp.py @@ -115,11 +115,12 @@ def test_resid_mlp_decomposition_happy_path(tmp_path: Path) -> None: synced_inputs=None, ) + extract_input = lambda batch: batch[0] train_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) eval_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.eval_batch_size, shuffle=False + dataset, batch_size=config.eval_batch_size, shuffle=False, transform=extract_input ) # Run optimize function diff --git a/tests/test_tms.py b/tests/test_tms.py index 019714cdd..d382eb2ed 100644 --- a/tests/test_tms.py +++ b/tests/test_tms.py @@ -119,11 +119,12 @@ def test_tms_decomposition_happy_path(tmp_path: Path) -> None: synced_inputs=None, ) + extract_input = lambda batch: batch[0] train_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) eval_loader = DatasetGeneratedDataLoader( - dataset, batch_size=config.microbatch_size, shuffle=False + dataset, batch_size=config.microbatch_size, shuffle=False, transform=extract_input ) tied_weights = None