Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions spd/app/backend/routers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,6 @@

from spd.app.backend.dependencies import DepLoadedRun, DepStateManager
from spd.app.backend.utils import log_errors
from spd.utils.distributed_utils import get_device

# TODO: Re-enable these endpoints when dependencies are available:
# - extract_active_from_ci from database
# - PromptSearchQuery, PromptSearchResponse from schemas
# - DatasetConfig, LMTaskConfig from configs
# - create_data_loader, extract_batch_data from data
# - logger from utils

DEVICE = get_device()

# =============================================================================
# Schemas
Expand Down
27 changes: 20 additions & 7 deletions spd/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,11 +744,6 @@ def all_module_info(self) -> list[ModulePatternInfoConfig]:
),
)
)
output_loss_type: Literal["mse", "kl"] = Field(
...,
description="Metric used to measure recon error between model outputs and targets",
)

# --- Training ---
lr_schedule: ScheduleConfig = Field(..., description="Learning rate schedule configuration")
steps: NonNegativeInt = Field(..., description="Total number of optimisation steps")
Expand Down Expand Up @@ -849,9 +844,11 @@ def microbatch_size(self) -> PositiveInt:
default=None,
description="hf model identifier. E.g. 'SimpleStories/SimpleStories-1.25M'",
)
pretrained_model_output_attr: str | None = Field(
extract_tensor_output: str | None = Field(
default=None,
description="Name of the attribute on the forward output that contains logits or activations",
description="Declarative accessor path for extracting tensor from model output. "
"None = raw output is the tensor. Examples: '.logits' for attribute access, "
"'[0]' for index access.",
)
tokenizer_name: str | None = Field(
default=None,
Expand Down Expand Up @@ -890,6 +887,7 @@ def microbatch_size(self) -> PositiveInt:
"lr_exponential_halflife",
"out_dir",
"n_examples_until_dead",
"output_loss_type",
]
RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = {
"grad_clip_norm": "grad_clip_norm_components",
Expand Down Expand Up @@ -934,6 +932,21 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str,
"simple_stories_train.models.", "spd.pretrain.models.", 1
)

# Migrate old pretrained_model_output_attr to extract_tensor_output
if "pretrained_model_output_attr" in config_dict:
old_val = config_dict.pop("pretrained_model_output_attr")
match old_val:
case None:
pass
case "idx_0":
config_dict["extract_tensor_output"] = "[0]"
case str(attr):
config_dict["extract_tensor_output"] = f".{attr}"
case _:
raise AssertionError(
f"Unexpected pretrained_model_output_attr value: {old_val}"
)

if "eval_batch_size" not in config_dict:
config_dict["eval_batch_size"] = config_dict["batch_size"]
if "train_log_freq" not in config_dict:
Expand Down
16 changes: 13 additions & 3 deletions spd/data.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections.abc import Callable, Generator
from typing import Any

import numpy as np
import torch
from datasets import Dataset, IterableDataset, load_dataset
from jaxtyping import Int
from numpy.typing import NDArray
from torch import Tensor
from torch.utils.data import DataLoader, DistributedSampler
Expand Down Expand Up @@ -152,7 +154,8 @@ def create_data_loader(
dist_state: DistributedState | None = None,
global_seed: int = 0,
to_lower: bool = True,
) -> tuple[DataLoader[Any], PreTrainedTokenizer]:
collate_fn: Callable[..., Any] | None = None,
) -> tuple[DataLoader[Int[Tensor, "batch seq"]], PreTrainedTokenizer]:
"""Create a DataLoader for the given dataset.

Uses PyTorch's DistributedSampler to ensure each rank gets the correct
Expand Down Expand Up @@ -255,7 +258,7 @@ def create_data_loader(
generator = torch.Generator(device="cpu")
generator.manual_seed(seed)

loader = DataLoader[Dataset | IterableDataset](
loader = DataLoader[Int[Tensor, "batch seq"]](
torch_dataset, # pyright: ignore[reportArgumentType]
batch_size=batch_size,
sampler=sampler,
Expand All @@ -264,11 +267,17 @@ def create_data_loader(
),
drop_last=True,
generator=generator,
collate_fn=collate_fn,
)
return loader, tokenizer


def loop_dataloader[T](dl: DataLoader[T]):
def lm_collate_fn(batch: list[dict[str, Tensor]]) -> Tensor:
    """Stack the ``input_ids`` tensors from a batch of HuggingFace dataset dicts.

    Each element of ``batch`` is a mapping that must contain an ``input_ids``
    tensor; the per-example tensors are stacked along a new leading batch dim.
    """
    input_ids = [example["input_ids"] for example in batch]
    return torch.stack(input_ids)


def loop_dataloader[T](dl: DataLoader[T]) -> Generator[T]:
"""Loop over a dataloader, resetting the iterator when it is exhausted.

Ensures that each epoch gets different data, even when using a distributed sampler.
Expand Down Expand Up @@ -311,6 +320,7 @@ def train_loader_and_tokenizer(
batch_size=batch_size,
buffer_size=task_config.buffer_size,
global_seed=config.seed,
collate_fn=lm_collate_fn,
)

assert isinstance(tokenizer, PreTrainedTokenizerBase)
Expand Down
3 changes: 1 addition & 2 deletions spd/dataset_attributions/harvest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from spd.models.component_model import ComponentModel, SPDRunInfo
from spd.topology import TransformerTopology
from spd.utils.distributed_utils import get_device
from spd.utils.general_utils import extract_batch_data
from spd.utils.wandb_utils import parse_wandb_run_path


Expand Down Expand Up @@ -201,7 +200,7 @@ def harvest_attributions(
# Skip batches not assigned to this rank
if world_size is not None and batch_idx % world_size != rank:
continue
batch = extract_batch_data(batch_data).to(device)
batch = batch_data.to(device)
harvester.process_batch(batch)

logger.info(
Expand Down
90 changes: 48 additions & 42 deletions spd/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Iterator
from typing import Any

from jaxtyping import Float, Int
from jaxtyping import Float
from PIL import Image
from torch import Tensor
from torch.types import Number
Expand Down Expand Up @@ -39,35 +39,38 @@
UnmaskedReconLossConfig,
UVPlotsConfig,
)
from spd.metrics import UnmaskedReconLoss
from spd.metrics import (
CI_L0,
CEandKLLosses,
CIHistograms,
CIMaskedReconLayerwiseLoss,
CIMaskedReconLoss,
CIMaskedReconSubsetLoss,
CIMeanPerComponent,
ComponentActivationDensity,
FaithfulnessLoss,
IdentityCIError,
ImportanceMinimalityLoss,
PermutedCIPlots,
PGDReconLayerwiseLoss,
PGDReconLoss,
PGDReconSubsetLoss,
StochasticHiddenActsReconLoss,
StochasticReconLayerwiseLoss,
StochasticReconLoss,
StochasticReconSubsetCEAndKL,
StochasticReconSubsetLoss,
UnmaskedReconLoss,
UVPlots,
)
from spd.metrics.base import Metric
from spd.metrics.ce_and_kl_losses import CEandKLLosses
from spd.metrics.ci_histograms import CIHistograms
from spd.metrics.ci_l0 import CI_L0
from spd.metrics.ci_masked_recon_layerwise_loss import CIMaskedReconLayerwiseLoss
from spd.metrics.ci_masked_recon_loss import CIMaskedReconLoss
from spd.metrics.ci_masked_recon_subset_loss import CIMaskedReconSubsetLoss
from spd.metrics.ci_mean_per_component import CIMeanPerComponent
from spd.metrics.component_activation_density import ComponentActivationDensity
from spd.metrics.faithfulness_loss import FaithfulnessLoss
from spd.metrics.identity_ci_error import IdentityCIError
from spd.metrics.importance_minimality_loss import ImportanceMinimalityLoss
from spd.metrics.permuted_ci_plots import PermutedCIPlots
from spd.metrics.pgd_masked_recon_layerwise_loss import PGDReconLayerwiseLoss
from spd.metrics.pgd_masked_recon_loss import PGDReconLoss
from spd.metrics.pgd_masked_recon_subset_loss import PGDReconSubsetLoss
from spd.metrics.pgd_utils import CreateDataIter, calc_multibatch_pgd_masked_recon_loss
from spd.metrics.stochastic_hidden_acts_recon_loss import StochasticHiddenActsReconLoss
from spd.metrics.stochastic_recon_layerwise_loss import StochasticReconLayerwiseLoss
from spd.metrics.stochastic_recon_loss import StochasticReconLoss
from spd.metrics.stochastic_recon_subset_ce_and_kl import StochasticReconSubsetCEAndKL
from spd.metrics.stochastic_recon_subset_loss import StochasticReconSubsetLoss
from spd.metrics.uv_plots import UVPlots
from spd.models.batch_and_loss_fns import ReconstructionLoss, recon_loss_kl
from spd.models.component_model import ComponentModel, OutputWithCache
from spd.persistent_pgd import PersistentPGDReconLoss, PersistentPGDReconSubsetLoss, PPGDSources
from spd.routing import AllLayersRouter, get_subset_router
from spd.utils.distributed_utils import avg_metrics_across_ranks, is_distributed
from spd.utils.general_utils import dict_safe_update_, extract_batch_data
from spd.utils.general_utils import dict_safe_update_

MetricOutType = dict[str, str | Number | Image.Image | CustomChart]
DistMetricOutType = dict[str, str | float | Image.Image | CustomChart]
Expand Down Expand Up @@ -127,6 +130,7 @@ def init_metric(
],
run_config: Config,
device: str,
reconstruction_loss: ReconstructionLoss,
) -> Metric:
match cfg:
case ImportanceMinimalityLossConfig():
Expand Down Expand Up @@ -164,16 +168,16 @@ def init_metric(
metric = CIMaskedReconSubsetLoss(
model=model,
device=device,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
routing=cfg.routing,
)
case CIMaskedReconLayerwiseLossConfig():
metric = CIMaskedReconLayerwiseLoss(
model=model, device=device, output_loss_type=run_config.output_loss_type
model=model, device=device, reconstruction_loss=reconstruction_loss
)
case CIMaskedReconLossConfig():
metric = CIMaskedReconLoss(
model=model, device=device, output_loss_type=run_config.output_loss_type
model=model, device=device, reconstruction_loss=reconstruction_loss
)
case CIMeanPerComponentConfig():
metric = CIMeanPerComponent(model=model, device=device)
Expand Down Expand Up @@ -202,7 +206,7 @@ def init_metric(
sampling=run_config.sampling,
use_delta_component=run_config.use_delta_component,
n_mask_samples=run_config.n_mask_samples,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
)
case StochasticReconLossConfig():
metric = StochasticReconLoss(
Expand All @@ -211,7 +215,7 @@ def init_metric(
sampling=run_config.sampling,
use_delta_component=run_config.use_delta_component,
n_mask_samples=run_config.n_mask_samples,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
)
case StochasticReconSubsetLossConfig():
metric = StochasticReconSubsetLoss(
Expand All @@ -220,23 +224,23 @@ def init_metric(
sampling=run_config.sampling,
use_delta_component=run_config.use_delta_component,
n_mask_samples=run_config.n_mask_samples,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
routing=cfg.routing,
)
case PGDReconLossConfig():
metric = PGDReconLoss(
model=model,
device=device,
use_delta_component=run_config.use_delta_component,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
pgd_config=cfg,
)
case PGDReconSubsetLossConfig():
metric = PGDReconSubsetLoss(
model=model,
device=device,
use_delta_component=run_config.use_delta_component,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
pgd_config=cfg,
routing=cfg.routing,
)
Expand All @@ -245,7 +249,7 @@ def init_metric(
model=model,
device=device,
use_delta_component=run_config.use_delta_component,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
pgd_config=cfg,
)
case StochasticReconSubsetCEAndKLConfig():
Expand Down Expand Up @@ -277,23 +281,25 @@ def init_metric(
metric = UnmaskedReconLoss(
model=model,
device=device,
output_loss_type=run_config.output_loss_type,
reconstruction_loss=reconstruction_loss,
)

case PersistentPGDReconLossConfig():
ppgd_output_loss_type = "kl" if reconstruction_loss is recon_loss_kl else "mse"
metric = PersistentPGDReconLoss(
model=model,
device=device,
use_delta_component=run_config.use_delta_component,
output_loss_type=run_config.output_loss_type,
output_loss_type=ppgd_output_loss_type,
ppgd_sources=ppgd_sourcess[cfg],
)
case PersistentPGDReconSubsetLossConfig():
ppgd_output_loss_type = "kl" if reconstruction_loss is recon_loss_kl else "mse"
metric = PersistentPGDReconSubsetLoss(
model=model,
device=device,
use_delta_component=run_config.use_delta_component,
output_loss_type=run_config.output_loss_type,
output_loss_type=ppgd_output_loss_type,
ppgd_sources=ppgd_sourcess[cfg],
routing=cfg.routing,
)
Expand All @@ -307,7 +313,7 @@ def init_metric(
def evaluate(
eval_metric_configs: list[MetricConfigType],
model: ComponentModel,
eval_iterator: Iterator[Int[Tensor, "..."] | tuple[Float[Tensor, "..."], Float[Tensor, "..."]]],
eval_iterator: Iterator[Any],
ppgd_sourcess: dict[
PersistentPGDReconLossConfig | PersistentPGDReconSubsetLossConfig,
dict[str, Float[Tensor, " source_c"]],
Expand All @@ -317,6 +323,7 @@ def evaluate(
slow_step: bool,
n_eval_steps: int,
current_frac_of_training: float,
reconstruction_loss: ReconstructionLoss,
) -> MetricOutType:
"""Run evaluation and return a mapping of metric names to values/images."""

Expand All @@ -328,6 +335,7 @@ def evaluate(
ppgd_sourcess=ppgd_sourcess,
run_config=run_config,
device=device,
reconstruction_loss=reconstruction_loss,
)
if metric.slow and not slow_step:
continue
Expand All @@ -337,8 +345,7 @@ def evaluate(
weight_deltas = model.calc_weight_deltas()

for _ in range(n_eval_steps):
batch_raw = next(eval_iterator)
batch = extract_batch_data(batch_raw).to(device)
batch: Any = next(eval_iterator)

target_output: OutputWithCache = model(batch, cache_type="input")
ci = model.calc_causal_importances(
Expand Down Expand Up @@ -377,8 +384,8 @@ def evaluate_multibatch_pgd(
model: ComponentModel,
create_data_iter: CreateDataIter,
config: Config,
batch_dims: tuple[int, ...],
device: str,
reconstruction_loss: ReconstructionLoss,
) -> dict[str, float]:
"""Calculate multibatch PGD metrics."""
weight_deltas = model.calc_weight_deltas() if config.use_delta_component else None
Expand All @@ -400,11 +407,10 @@ def evaluate_multibatch_pgd(
model=model,
weight_deltas=weight_deltas,
create_data_iter=create_data_iter,
output_loss_type=config.output_loss_type,
router=router,
sampling=config.sampling,
use_delta_component=config.use_delta_component,
batch_dims=batch_dims,
device=device,
reconstruction_loss=reconstruction_loss,
).item()
return metrics
1 change: 0 additions & 1 deletion spd/experiments/ih/ih_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ ci_recon_layerwise_coeff: null
stochastic_recon_layerwise_coeff: 1
importance_minimality_coeff: 1e-2
pnorm: 0.1
output_loss_type: kl
ci_config:
mode: layerwise
fn_type: vector_mlp
Expand Down
Loading
Loading