From b93b9d6775878321ba771e1994f9d995b4ce08bb Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Tue, 16 Sep 2025 18:07:08 +0000 Subject: [PATCH 01/18] Geometric similarity comparison made consistent with other evals and tested --- spd/eval.py | 159 ++++++++++++++++++++++++++++++++++++++++++++++++ spd/registry.py | 6 ++ 2 files changed, 165 insertions(+) diff --git a/spd/eval.py b/spd/eval.py index 6237ef7bd..0c5ec6f30 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -19,6 +19,7 @@ from torch import Tensor from spd.configs import Config +from spd.log import logger from spd.models.component_model import ComponentModel from spd.plotting import ( get_single_feature_causal_importances, @@ -671,6 +672,163 @@ def compute(self) -> Mapping[str, float]: return results +class GeometricSimilarityComparison(StreamingEval): + SLOW = True # This involves loading another model, so it's slow + + def __init__(self, model: ComponentModel, config: Config, **kwargs: Any): + self.model = model + self.config = config + self.reference_run_path = kwargs.get("reference_run_path") + if self.reference_run_path is None: + raise ValueError("reference_run_path is required for GeometricSimilarityComparison") + self.kwargs = kwargs + self.reference_model: ComponentModel | None = None + self._computed_this_eval = False + self.device = next(iter(model.parameters())).device + self.n_tokens = 0 + self.component_activation_counts: dict[str, Float[Tensor, " C"]] = { + module_name: torch.zeros(model.C, device=self.device) + for module_name in model.components + } + + def _load_reference_model(self) -> ComponentModel: + """Load the reference model from wandb or local path""" + if self.reference_model is None: + from spd.models.component_model import ComponentModel + + assert self.reference_run_path is not None, ( + "reference_run_path should not be None at this point" + ) + self.reference_model = ComponentModel.from_pretrained(self.reference_run_path) + + if torch.cuda.is_available(): + 
self.reference_model.to("cuda") + self.reference_model.eval() + self.reference_model.requires_grad_(False) + + return self.reference_model + + def _compute_subcomponent_geometric_similarities( + self, activation_densities: dict[str, Float[Tensor, " C"]] + ) -> dict[str, float]: + """Compute mean max cosine similarity between subcomponent rank-one matrices""" + reference_model = self._load_reference_model() + similarities = {} + + # Iterate through all component layers in both models + for layer_name in self.model.components: + if layer_name not in reference_model.components: + logger.warning(f"Layer {layer_name} not found in reference model, skipping") + continue + + current_components = self.model.components[layer_name] + reference_components = reference_model.components[layer_name] + + # Verify component counts match + if current_components.C != reference_components.C: + logger.warning( + f"Component count mismatch for {layer_name}: {current_components.C} vs {reference_components.C}" + ) + continue + + # Extract U and V matrices + C = current_components.C + current_U = current_components.U # Shape: [C, d_out] + current_V = current_components.V # Shape: [d_in, C] + ref_U = reference_components.U + ref_V = reference_components.V + + # Throw away components that are not active enough in the current model + density_threshold = self.kwargs.get("density_threshold", 0.0) + C_alive = sum(activation_densities[layer_name] > density_threshold) + if C_alive == 0: + logger.warning( + f"\n WARNING:No components are active enough in {layer_name} for density threshold {density_threshold}. Geometric similarity comparison failed to run. 
\n" + ) + continue + current_V = current_V[:, activation_densities[layer_name] > density_threshold] + current_U = current_U[activation_densities[layer_name] > density_threshold] + + # Compute rank-one matrices: V @ U for each component + # Each component c produces a rank-one matrix of shape [d_in, d_out] + current_rank_one = einops.einsum( + current_V, current_U, "d_in C_alive, C_alive d_out -> C_alive d_in d_out" + ) + ref_rank_one = einops.einsum(ref_V, ref_U, "d_in C, C d_out -> C d_in d_out") + + # Flatten to vectors for cosine similarity computation + current_flat = current_rank_one.reshape(C_alive, -1) + ref_flat = ref_rank_one.reshape(C, -1) + + # Compute cosine similarities between all pairs + current_norm = F.normalize(current_flat, p=2, dim=1) + ref_norm = F.normalize(ref_flat, p=2, dim=1) + + cosine_sim_matrix = einops.einsum( + current_norm, ref_norm, "C_alive d_in_d_out, C_ref d_in_d_out -> C_alive C_ref" + ) + + # Find max cosine similarity for each current component + max_similarities = cosine_sim_matrix.max(dim=1).values + similarities[f"mean_max_cosine_sim/{layer_name}"] = max_similarities.mean().item() + similarities[f"max_cosine_sim_std/{layer_name}"] = max_similarities.std().item() + similarities[f"max_cosine_sim_min/{layer_name}"] = max_similarities.min().item() + similarities[f"max_cosine_sim_max/{layer_name}"] = max_similarities.max().item() + + # Compute a metrics across all model components for each type of metric + # First get the metric names by stripping away the layer name + metric_names = [ + "mean_max_cosine_sim", + "max_cosine_sim_std", + "max_cosine_sim_min", + "max_cosine_sim_max", + ] + + for metric_name in metric_names: + # Go through all layers and get the average of the metric + values = [ + similarities[f"{metric_name}/{layer_name}"] for layer_name in self.model.components + ] + similarities[f"{metric_name}/all_layers"] = sum(values) / len(values) + + return similarities + + @override + def watch_batch( + self, + batch: 
Int[Tensor, "..."] | Float[Tensor, "..."], + target_out: Float[Tensor, "... vocab"], + ci: dict[str, Float[Tensor, "... C"]], + ) -> None: + n_tokens = next(iter(ci.values())).shape[:-1].numel() + self.n_tokens += n_tokens + + for module_name, ci_vals in ci.items(): + active_components = ci_vals > self.config.ci_alive_threshold + n_activations_per_component = reduce(active_components, "... C -> C", "sum") + self.component_activation_counts[module_name] += n_activations_per_component + + @override + def compute(self) -> Mapping[str, float]: + """Compute the geometric similarity metrics""" + + activation_densities = { + module_name: self.component_activation_counts[module_name] / self.n_tokens + for module_name in self.model.components + } + + if self._computed_this_eval: + return {} + + try: + similarities = self._compute_subcomponent_geometric_similarities(activation_densities) + self._computed_this_eval = True + return similarities + except Exception as e: + logger.warning(f"Failed to compute geometric similarity comparison: {e}") + return {} + + EVAL_CLASSES = { cls.__name__: cls for cls in [ @@ -683,6 +841,7 @@ def compute(self) -> Mapping[str, float]: IdentityCIError, CIMeanPerComponent, SubsetReconstructionLoss, + GeometricSimilarityComparison, ] } diff --git a/spd/registry.py b/spd/registry.py index 1939a6f7b..52e665253 100644 --- a/spd/registry.py +++ b/spd/registry.py @@ -116,6 +116,12 @@ class ExperimentConfig: config_path=Path("spd/experiments/lm/ss_gpt2_simple_noln_config.yaml"), expected_runtime=330, ), + "ss_llama_single_with_comparison": ExperimentConfig( + task_name="lm", + decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), + config_path=Path("spd/experiments/lm/ss_llama_single_with_comparison_config.yaml"), + expected_runtime=60 * 94, # Same as ss_llama_single + ), # "ss_emb": ExperimentConfig( # task_name="lm", # decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), From cd5fda27aefbb9ad90701abe7a7592059e014ba7 Mon Sep 17 
00:00:00 2001 From: Lee Sharkey Date: Wed, 17 Sep 2025 11:53:45 +0000 Subject: [PATCH 02/18] Replaced mean max cosine sim with mean max ABS cosine sim --- spd/eval.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index 0c5ec6f30..675d7c8d8 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -767,21 +767,23 @@ def _compute_subcomponent_geometric_similarities( cosine_sim_matrix = einops.einsum( current_norm, ref_norm, "C_alive d_in_d_out, C_ref d_in_d_out -> C_alive C_ref" ) + # Take the abs of the cosine similarity matrix + cosine_sim_matrix = cosine_sim_matrix.abs() - # Find max cosine similarity for each current component + # Find max abs cosine similarity for each current component max_similarities = cosine_sim_matrix.max(dim=1).values - similarities[f"mean_max_cosine_sim/{layer_name}"] = max_similarities.mean().item() - similarities[f"max_cosine_sim_std/{layer_name}"] = max_similarities.std().item() - similarities[f"max_cosine_sim_min/{layer_name}"] = max_similarities.min().item() - similarities[f"max_cosine_sim_max/{layer_name}"] = max_similarities.max().item() + similarities[f"mean_max_abs_cosine_sim/{layer_name}"] = max_similarities.mean().item() + similarities[f"max_abs_cosine_sim_std/{layer_name}"] = max_similarities.std().item() + similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() + similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() # Compute a metrics across all model components for each type of metric # First get the metric names by stripping away the layer name metric_names = [ - "mean_max_cosine_sim", - "max_cosine_sim_std", - "max_cosine_sim_min", - "max_cosine_sim_max", + "mean_max_abs_cosine_sim", + "max_abs_cosine_sim_std", + "max_abs_cosine_sim_min", + "max_abs_cosine_sim_max", ] for metric_name in metric_names: From 61d340881531068e13920208f74929527f502f97 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Wed, 17 Sep 
2025 12:00:00 +0000 Subject: [PATCH 03/18] Configs for geom comparison runs --- spd/experiments/lm/ss_gpt2_simple_config.yaml | 2 +- .../lm/ss_llama_single_gpu_config.yaml | 2 +- ...s_llama_single_with_comparison_config.yaml | 124 ++++++++++++++++++ .../resid_mlp2_geom_comparison_config.yaml | 82 ++++++++++++ .../tms_5-2-id_geom_comparison_config.yaml | 81 ++++++++++++ .../tms/tms_5-2_geom_comparison_config.yaml | 75 +++++++++++ 6 files changed, 364 insertions(+), 2 deletions(-) create mode 100644 spd/experiments/lm/ss_llama_single_with_comparison_config.yaml create mode 100644 spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml create mode 100644 spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml create mode 100644 spd/experiments/tms/tms_5-2_geom_comparison_config.yaml diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index 48ed97713..9201c5d81 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -42,7 +42,7 @@ eval_freq: 1000 slow_eval_freq: 5000 slow_eval_on_first_step: true n_eval_steps: 5 -save_freq: null +save_freq: 1000 ci_alive_threshold: 0.0 n_examples_until_dead: 3_276_800 # train_log_freq * batch_size * max_seq_len = 100 * 64 * 512 eval_metrics: diff --git a/spd/experiments/lm/ss_llama_single_gpu_config.yaml b/spd/experiments/lm/ss_llama_single_gpu_config.yaml index 6410a2611..f1c40b164 100644 --- a/spd/experiments/lm/ss_llama_single_gpu_config.yaml +++ b/spd/experiments/lm/ss_llama_single_gpu_config.yaml @@ -46,7 +46,7 @@ eval_freq: 1000 slow_eval_freq: 5000 slow_eval_on_first_step: true n_eval_steps: 5 -save_freq: null +save_freq: 100 ci_alive_threshold: 0.0 n_examples_until_dead: 1368400 eval_metrics: diff --git a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml new file mode 100644 index 000000000..78e395990 --- /dev/null +++ 
b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml @@ -0,0 +1,124 @@ +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 1 +C: 4000 +n_mask_samples: 1 +gate_type: "vector_mlp" +gate_hidden_dims: [12] +sigmoid_type: "leaky_hard" +target_module_patterns: ["model.layers.*.mlp.gate_proj", "model.layers.*.mlp.down_proj", "model.layers.*.mlp.up_proj", "model.layers.*.self_attn.q_proj", "model.layers.*.self_attn.k_proj", "model.layers.*.self_attn.v_proj", "model.layers.*.self_attn.o_proj"] +sampling: "binomial" + +# --- Loss Coefficients --- +faithfulness_coeff: 10000000.0 +recon_coeff: null +stochastic_recon_coeff: 1.0 +recon_layerwise_coeff: null +stochastic_recon_layerwise_coeff: 1.0 +importance_minimality_coeff: 0.0003 +schatten_coeff: null +out_recon_coeff: null +embedding_recon_coeff: null +is_embed_unembed_recon: false +pnorm: 2.0 +p_anneal_start_frac: 0.0 +p_anneal_final_p: 0.1 +p_anneal_end_frac: 1.0 +output_loss_type: kl + +# --- Training --- +batch_size: 12 +eval_batch_size: 12 +steps: 300000 +lr: 0.0005 +lr_schedule: cosine +lr_warmup_pct: 0.0 +lr_exponential_halflife: null +gradient_accumulation_steps: 4 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 100 +slow_eval_freq: 100 +slow_eval_on_first_step: true +n_eval_steps: 5 +save_freq: 100 +ci_alive_threshold: 0.0 +n_examples_until_dead: 1368400 +eval_metrics: + - classname: "CIHistograms" + extra_init_kwargs: + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + extra_init_kwargs: {} + - classname: "CI_L0" + extra_init_kwargs: + groups: + total: ["*"] # Sum of all L0 values + layer_0: ["model.layers.0.*"] + layer_1: ["model.layers.1.*"] + layer_2: ["model.layers.2.*"] + layer_3: ["model.layers.3.*"] + - classname: "CEandKLLosses" + extra_init_kwargs: + rounding_threshold: 0.0 + - classname: "SubsetReconstructionLoss" + extra_init_kwargs: + n_mask_samples: 1 + use_all_ones_for_non_replaced: false + 
include_patterns: + layer_0_only: ["model.layers.0.*"] + layer_1_only: ["model.layers.1.*"] + layer_2_only: ["model.layers.2.*"] + layer_3_only: ["model.layers.3.*"] + mlp_only: ["*.mlp.*"] + attention_only: ["*.self_attn.*"] + exclude_patterns: + all_but_layer_0: ["model.layers.0.*"] + all_but_layer_1: ["model.layers.1.*"] + all_but_layer_2: ["model.layers.2.*"] + all_but_layer_3: ["model.layers.3.*"] + - classname: "GeometricSimilarityComparison" + extra_init_kwargs: + reference_run_path: "wandb:goodfire/spd/runs/2js1ccon" + density_threshold: 0.001 + + +# --- Pretrained model info --- +pretrained_model_class: transformers.LlamaForCausalLM +pretrained_model_name: SimpleStories/SimpleStories-1.25M +pretrained_model_path: null +pretrained_model_output_attr: logits +tokenizer_name: SimpleStories/SimpleStories-1.25M + +# --- Task Specific --- +task_config: + task_name: lm + max_seq_len: 512 + buffer_size: 1000 + dataset_name: "SimpleStories/SimpleStories" + column_name: "story" + train_data_split: "train" + eval_data_split: "test" + shuffle_each_epoch: true + is_tokenized: false + streaming: false + + +# Config details for the target model taken from https://github.com/danbraunai/simple_stories_train/blob/main/simple_stories_train/models/model_configs.py#L54 + # "1.25M": LlamaConfig( + # block_size=512, + # vocab_size=4096, + # n_layer=4, + # n_head=4, + # n_embd=128, + # n_intermediate=128 * 4 * 2 // 3 = 341, + # rotary_dim=128 // 4 = 32, + # n_ctx=512, + # n_key_value_heads=2, + # flash_attention=True, + # ), diff --git a/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml new file mode 100644 index 000000000..4d601862a --- /dev/null +++ b/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml @@ -0,0 +1,82 @@ +# ResidualMLP 2 layers with Geometric Comparison +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +C: 
400 +n_mask_samples: 1 +gate_type: "mlp" +gate_hidden_dims: [16] +sigmoid_type: "leaky_hard" +target_module_patterns: + - "layers.*.mlp_in" + - "layers.*.mlp_out" + +# --- Loss Coefficients --- +faithfulness_coeff: 1.0 +out_recon_coeff: 0.0 +recon_coeff: null +stochastic_recon_coeff: 1.0 +recon_layerwise_coeff: null +stochastic_recon_layerwise_coeff: 1.0 +importance_minimality_coeff: 1e-5 +pnorm: 2 +output_loss_type: mse + +# --- Training --- +batch_size: 2048 +eval_batch_size: 2048 +steps: 50_000 +lr: 1e-3 +lr_schedule: constant +lr_warmup_pct: 0.00 + +# --- Logging & Saving --- +train_log_freq: 50 +eval_freq: 500 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +n_examples_until_dead: 1_024_000 +eval_metrics: + - classname: "CIHistograms" + extra_init_kwargs: + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + extra_init_kwargs: + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "UVPlots" + extra_init_kwargs: + identity_patterns: ["layers.*.mlp_in"] + dense_patterns: ["layers.*.mlp_out"] + - classname: "IdentityCIError" + extra_init_kwargs: + identity_ci: + - layer_pattern: "layers.*.mlp_in" + n_features: 100 + dense_ci: + - layer_pattern: "layers.*.mlp_out" + k: 25 + - classname: "CI_L0" + - classname: "CIMeanPerComponent" + - classname: "GeometricSimilarityComparison" + extra_init_kwargs: + reference_run_path: "wandb:goodfire/spd/runs/nr085xlx" + density_threshold: 0.001 + +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" +pretrained_model_path: "wandb:goodfire/spd/runs/any9ekl9" + +# --- Task Specific --- +task_config: + task_name: resid_mlp + feature_probability: 0.01 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml b/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml new file 
mode 100644 index 000000000..6a5c93456 --- /dev/null +++ b/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml @@ -0,0 +1,81 @@ +# TMS 5-2 w/ fixed identity with Geometric Comparison +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 0 +C: 20 +n_mask_samples: 1 +gate_type: "mlp" +gate_hidden_dims: [16] +sigmoid_type: "leaky_hard" +target_module_patterns: ["linear1", "linear2", "hidden_layers.0"] + +# --- Loss Coefficients --- +faithfulness_coeff: 1.0 +recon_coeff: null +stochastic_recon_coeff: 1 +recon_layerwise_coeff: null +stochastic_recon_layerwise_coeff: 1.0 +importance_minimality_coeff: 3e-3 +pnorm: 1.0 +output_loss_type: mse + +# --- Training --- +batch_size: 4096 +eval_batch_size: 4096 +steps: 40_000 +lr: 1e-3 +lr_schedule: cosine +lr_warmup_pct: 0.0 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 1000 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +n_examples_until_dead: 4_096_000 +eval_metrics: + - classname: "CIHistograms" + extra_init_kwargs: + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + extra_init_kwargs: + identity_patterns: ["linear1", "linear2"] + dense_patterns: ["hidden_layers.0"] + - classname: "UVPlots" + extra_init_kwargs: + identity_patterns: ["linear1", "linear2"] + dense_patterns: ["hidden_layers.0"] + - classname: "IdentityCIError" + extra_init_kwargs: + identity_ci: + - layer_pattern: "linear1" + n_features: 5 + - layer_pattern: "linear2" + n_features: 5 + dense_ci: + - layer_pattern: "hidden_layers.0" + k: 2 + - classname: "CI_L0" + - classname: "CIMeanPerComponent" + - classname: "GeometricSimilarityComparison" + extra_init_kwargs: + reference_run_path: "wandb:goodfire/spd/runs/swr68dli" + density_threshold: 0.001 + +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.tms.models.TMSModel" +pretrained_model_path: 
"wandb:goodfire/spd/runs/gfgchmxj" # 1 hidden w/fixed identity + +# --- Task Specific --- +task_config: + task_name: tms + feature_probability: 0.05 + data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml b/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml new file mode 100644 index 000000000..473470db7 --- /dev/null +++ b/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml @@ -0,0 +1,75 @@ +# TMS 5-2 (Non-identity) with Geometric Comparison +# --- WandB --- +wandb_project: spd +wandb_run_name: null +wandb_run_name_prefix: "" + +# --- General --- +seed: 1 +C: 20 +n_mask_samples: 1 +gate_type: "mlp" +gate_hidden_dims: [16] +sigmoid_type: "leaky_hard" +target_module_patterns: ["linear1", "linear2"] + +# --- Loss Coefficients --- +faithfulness_coeff: 1.0 +recon_coeff: null +stochastic_recon_coeff: 1 +recon_layerwise_coeff: null +stochastic_recon_layerwise_coeff: 1.0 +importance_minimality_coeff: 3e-3 +pnorm: 1.0 +output_loss_type: mse + +# --- Training --- +batch_size: 4096 +eval_batch_size: 4096 +steps: 40_000 +lr: 1e-3 +lr_schedule: cosine +lr_warmup_pct: 0.0 + +# --- Logging & Saving --- +train_log_freq: 100 +eval_freq: 1000 +n_eval_steps: 100 +slow_eval_freq: 5_000 +slow_eval_on_first_step: true +save_freq: null +ci_alive_threshold: 0.1 +n_examples_until_dead: 4_096_000 +eval_metrics: + - classname: "CIHistograms" + extra_init_kwargs: + n_batches_accum: 5 + - classname: "ComponentActivationDensity" + - classname: "PermutedCIPlots" + extra_init_kwargs: + identity_patterns: ["linear1", "linear2"] + - classname: "UVPlots" + extra_init_kwargs: + identity_patterns: ["linear1", "linear2"] + - classname: "IdentityCIError" + extra_init_kwargs: + identity_ci: + - layer_pattern: "linear1" + n_features: 5 + - layer_pattern: "linear2" + n_features: 5 + - classname: "CI_L0" + - classname: "CIMeanPerComponent" + - classname: "GeometricSimilarityComparison" + extra_init_kwargs: + reference_run_path: 
"wandb:goodfire/spd/runs/7ngt0c8d" + density_threshold: 0.001 +# --- Pretrained model info --- +pretrained_model_class: "spd.experiments.tms.models.TMSModel" +pretrained_model_path: "wandb:goodfire/spd/runs/0hsp07o4" + +# --- Task Specific --- +task_config: + task_name: tms + feature_probability: 0.05 + data_generation_type: "at_least_zero_active" From 770a5c55150fea34d2466af5559ce68cc3f82048 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Wed, 17 Sep 2025 13:28:37 +0000 Subject: [PATCH 04/18] Minor modifications to make PR-ready --- spd/eval.py | 2 +- spd/experiments/lm/ss_gpt2_simple_config.yaml | 2 +- .../lm/ss_llama_single_gpu_config.yaml | 2 +- ...ss_llama_single_with_comparison_config.yaml | 2 +- spd/registry.py | 18 ++++++++++++++++++ 5 files changed, 22 insertions(+), 4 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index 38ac2ff34..5fcc45938 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -746,7 +746,7 @@ def compute(self) -> Mapping[str, float]: class GeometricSimilarityComparison(StreamingEval): - SLOW = True # This involves loading another model, so it's slow + SLOW = True def __init__(self, model: ComponentModel, config: Config, **kwargs: Any): self.model = model diff --git a/spd/experiments/lm/ss_gpt2_simple_config.yaml b/spd/experiments/lm/ss_gpt2_simple_config.yaml index 49d93ddf4..06738248d 100644 --- a/spd/experiments/lm/ss_gpt2_simple_config.yaml +++ b/spd/experiments/lm/ss_gpt2_simple_config.yaml @@ -43,7 +43,7 @@ eval_freq: 1000 slow_eval_freq: 5000 slow_eval_on_first_step: true n_eval_steps: 5 -save_freq: 1000 +save_freq: null ci_alive_threshold: 0.0 n_examples_until_dead: 3_276_800 # train_log_freq * batch_size * max_seq_len = 100 * 64 * 512 eval_metrics: diff --git a/spd/experiments/lm/ss_llama_single_gpu_config.yaml b/spd/experiments/lm/ss_llama_single_gpu_config.yaml index 1820f6808..3395c5e69 100644 --- a/spd/experiments/lm/ss_llama_single_gpu_config.yaml +++ b/spd/experiments/lm/ss_llama_single_gpu_config.yaml @@ -47,7 +47,7 @@ 
eval_freq: 1000 slow_eval_freq: 5000 slow_eval_on_first_step: true n_eval_steps: 5 -save_freq: 100 +save_freq: null ci_alive_threshold: 0.0 n_examples_until_dead: 1368400 eval_metrics: diff --git a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml index 78e395990..877eee951 100644 --- a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml +++ b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml @@ -46,7 +46,7 @@ eval_freq: 100 slow_eval_freq: 100 slow_eval_on_first_step: true n_eval_steps: 5 -save_freq: 100 +save_freq: null ci_alive_threshold: 0.0 n_examples_until_dead: 1368400 eval_metrics: diff --git a/spd/registry.py b/spd/registry.py index 07adbfe3e..94a28295b 100644 --- a/spd/registry.py +++ b/spd/registry.py @@ -122,6 +122,24 @@ class ExperimentConfig: config_path=Path("spd/experiments/lm/ss_llama_single_with_comparison_config.yaml"), expected_runtime=60 * 94, # Same as ss_llama_single ), + "tms_5-2_geom_comparison": ExperimentConfig( + task_name="tms", + decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), + config_path=Path("spd/experiments/tms/tms_5-2_geom_comparison_config.yaml"), + expected_runtime=4, # Same as tms_5-2 + ), + "tms_5-2-id_geom_comparison": ExperimentConfig( + task_name="tms", + decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), + config_path=Path("spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml"), + expected_runtime=4, # Same as tms_5-2-id + ), + "resid_mlp2_geom_comparison": ExperimentConfig( + task_name="resid_mlp", + decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), + config_path=Path("spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml"), + expected_runtime=5, # Same as resid_mlp2 + ), # "ss_emb": ExperimentConfig( # task_name="lm", # decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), From 364198e2166110c2fc77dada7c2769867213c2f1 Mon Sep 17 00:00:00 
2001 From: Lee Sharkey Date: Wed, 17 Sep 2025 15:18:00 +0000 Subject: [PATCH 05/18] Update seed to be consistent with other configs again --- spd/experiments/lm/ss_llama_single_with_comparison_config.yaml | 2 +- spd/experiments/tms/tms_5-2_geom_comparison_config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml index 877eee951..7d6e49afb 100644 --- a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml +++ b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml @@ -4,7 +4,7 @@ wandb_run_name: null wandb_run_name_prefix: "" # --- General --- -seed: 1 +seed: 0 C: 4000 n_mask_samples: 1 gate_type: "vector_mlp" diff --git a/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml b/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml index 473470db7..78e040791 100644 --- a/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml +++ b/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml @@ -5,7 +5,7 @@ wandb_run_name: null wandb_run_name_prefix: "" # --- General --- -seed: 1 +seed: 0 C: 20 n_mask_samples: 1 gate_type: "mlp" From 57c2c76677233bdcb19b7053a7934a38a7144b52 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Thu, 18 Sep 2025 13:26:28 +0000 Subject: [PATCH 06/18] Cleaned up some comments and other bits --- spd/eval.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/spd/eval.py b/spd/eval.py index 2a7caf2c4..f26183e8e 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -767,7 +767,6 @@ def __init__(self, model: ComponentModel, config: Config, **kwargs: Any): raise ValueError("reference_run_path is required for GeometricSimilarityComparison") self.kwargs = kwargs self.reference_model: ComponentModel | None = None - self._computed_this_eval = False self.device = next(iter(model.parameters())).device self.n_tokens = 0 self.component_activation_counts: dict[str, Float[Tensor, " C"]] = { @@ 
-903,12 +902,8 @@ def compute(self) -> Mapping[str, float]: for module_name in self.model.components } - if self._computed_this_eval: - return {} - try: similarities = self._compute_subcomponent_geometric_similarities(activation_densities) - self._computed_this_eval = True return similarities except Exception as e: logger.warning(f"Failed to compute geometric similarity comparison: {e}") From 2e7752d0f99955ecad75b56dc473c6bc68ea333d Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Thu, 18 Sep 2025 16:19:02 +0000 Subject: [PATCH 07/18] Major update of PR following review: Now implemented as script rather than eval --- README.md | 10 + spd/configs.py | 14 + spd/eval.py | 156 ------- ...s_llama_single_with_comparison_config.yaml | 124 ------ .../resid_mlp2_geom_comparison_config.yaml | 82 ---- .../tms_5-2-id_geom_comparison_config.yaml | 81 ---- .../tms/tms_5-2_geom_comparison_config.yaml | 75 ---- spd/scripts/compare_models.py | 393 ++++++++++++++++++ spd/scripts/compare_models_config.yaml | 20 + 9 files changed, 437 insertions(+), 518 deletions(-) delete mode 100644 spd/experiments/lm/ss_llama_single_with_comparison_config.yaml delete mode 100644 spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml delete mode 100644 spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml delete mode 100644 spd/experiments/tms/tms_5-2_geom_comparison_config.yaml create mode 100644 spd/scripts/compare_models.py create mode 100644 spd/scripts/compare_models_config.yaml diff --git a/README.md b/README.md index 3cdef743f..2599f9761 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,16 @@ subdirectories, along with a corresponding config file. E.g. 
python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_5-2_config.yaml ``` +### Model Comparison + +For post-hoc analysis of completed runs, use the model comparison script to compute geometric similarities between subcomponents: + +```bash +python spd/scripts/compare_models.py --config spd/scripts/compare_models_config.yaml +``` + +The comparison script supports both wandb and local model paths, and calculates mean max absolute cosine similarity metrics (among others) between learned subcomponents in different runs. See `spd/scripts/compare_models_config.yaml` for configuration options. + ## Development Suggested extensions and settings for VSCode/Cursor are provided in `.vscode/`. To use the suggested diff --git a/spd/configs.py b/spd/configs.py index bc916b41b..0bfc6d08e 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -325,6 +325,7 @@ def microbatch_size(self) -> PositiveInt: RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = { "print_freq": "eval_freq", "pretrained_model_name_hf": "pretrained_model_name", + "output_recon_loss_type": "output_loss_type", } @model_validator(mode="before") @@ -346,6 +347,19 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, config_dict["train_log_freq"] = 50 if "slow_eval_freq" not in config_dict: config_dict["slow_eval_freq"] = config_dict["eval_freq"] + + # Add backward compatibility for older model checkpoints + if "output_loss_type" not in config_dict: + config_dict["output_loss_type"] = "kl" # Default value + logger.info("Added missing output_loss_type field with default value 'kl'") + + # Remove forbidden fields from older configs (fields that are not part of current Config schema) + forbidden_fields = ["hidden_act_recon_coeff"] + for field in forbidden_fields: + if field in config_dict: + logger.info(f"Removing forbidden field: {field}") + del config_dict[field] + return config_dict @model_validator(mode="after") diff --git a/spd/eval.py b/spd/eval.py index 
f26183e8e..74a67fa23 100644 --- a/spd/eval.py +++ b/spd/eval.py @@ -20,7 +20,6 @@ from torch import Tensor from spd.configs import Config -from spd.log import logger from spd.losses import calc_faithfulness_loss, calc_weight_deltas from spd.mask_info import make_mask_infos from spd.models.component_model import ComponentModel @@ -756,160 +755,6 @@ def compute(self) -> Mapping[str, float]: return {"loss/faithfulness": loss.item()} -class GeometricSimilarityComparison(StreamingEval): - SLOW = True - - def __init__(self, model: ComponentModel, config: Config, **kwargs: Any): - self.model = model - self.config = config - self.reference_run_path = kwargs.get("reference_run_path") - if self.reference_run_path is None: - raise ValueError("reference_run_path is required for GeometricSimilarityComparison") - self.kwargs = kwargs - self.reference_model: ComponentModel | None = None - self.device = next(iter(model.parameters())).device - self.n_tokens = 0 - self.component_activation_counts: dict[str, Float[Tensor, " C"]] = { - module_name: torch.zeros(model.C, device=self.device) - for module_name in model.components - } - - def _load_reference_model(self) -> ComponentModel: - """Load the reference model from wandb or local path""" - if self.reference_model is None: - from spd.models.component_model import ComponentModel - - assert self.reference_run_path is not None, ( - "reference_run_path should not be None at this point" - ) - self.reference_model = ComponentModel.from_pretrained(self.reference_run_path) - - if torch.cuda.is_available(): - self.reference_model.to("cuda") - self.reference_model.eval() - self.reference_model.requires_grad_(False) - - return self.reference_model - - def _compute_subcomponent_geometric_similarities( - self, activation_densities: dict[str, Float[Tensor, " C"]] - ) -> dict[str, float]: - """Compute mean max cosine similarity between subcomponent rank-one matrices""" - reference_model = self._load_reference_model() - similarities = {} - - # 
Iterate through all component layers in both models - for layer_name in self.model.components: - if layer_name not in reference_model.components: - logger.warning(f"Layer {layer_name} not found in reference model, skipping") - continue - - current_components = self.model.components[layer_name] - reference_components = reference_model.components[layer_name] - - # Verify component counts match - if current_components.C != reference_components.C: - logger.warning( - f"Component count mismatch for {layer_name}: {current_components.C} vs {reference_components.C}" - ) - continue - - # Extract U and V matrices - C = current_components.C - current_U = current_components.U # Shape: [C, d_out] - current_V = current_components.V # Shape: [d_in, C] - ref_U = reference_components.U - ref_V = reference_components.V - - # Throw away components that are not active enough in the current model - density_threshold = self.kwargs.get("density_threshold", 0.0) - C_alive = sum(activation_densities[layer_name] > density_threshold) - if C_alive == 0: - logger.warning( - f"\n WARNING:No components are active enough in {layer_name} for density threshold {density_threshold}. Geometric similarity comparison failed to run. 
\n" - ) - continue - current_V = current_V[:, activation_densities[layer_name] > density_threshold] - current_U = current_U[activation_densities[layer_name] > density_threshold] - - # Compute rank-one matrices: V @ U for each component - # Each component c produces a rank-one matrix of shape [d_in, d_out] - current_rank_one = einops.einsum( - current_V, current_U, "d_in C_alive, C_alive d_out -> C_alive d_in d_out" - ) - ref_rank_one = einops.einsum(ref_V, ref_U, "d_in C, C d_out -> C d_in d_out") - - # Flatten to vectors for cosine similarity computation - current_flat = current_rank_one.reshape(C_alive, -1) - ref_flat = ref_rank_one.reshape(C, -1) - - # Compute cosine similarities between all pairs - current_norm = F.normalize(current_flat, p=2, dim=1) - ref_norm = F.normalize(ref_flat, p=2, dim=1) - - cosine_sim_matrix = einops.einsum( - current_norm, ref_norm, "C_alive d_in_d_out, C_ref d_in_d_out -> C_alive C_ref" - ) - # Take the abs of the cosine similarity matrix - cosine_sim_matrix = cosine_sim_matrix.abs() - - # Find max abs cosine similarity for each current component - max_similarities = cosine_sim_matrix.max(dim=1).values - similarities[f"mean_max_abs_cosine_sim/{layer_name}"] = max_similarities.mean().item() - similarities[f"max_abs_cosine_sim_std/{layer_name}"] = max_similarities.std().item() - similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() - similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() - - # Compute a metrics across all model components for each type of metric - # First get the metric names by stripping away the layer name - metric_names = [ - "mean_max_abs_cosine_sim", - "max_abs_cosine_sim_std", - "max_abs_cosine_sim_min", - "max_abs_cosine_sim_max", - ] - - for metric_name in metric_names: - # Go through all layers and get the average of the metric - values = [ - similarities[f"{metric_name}/{layer_name}"] for layer_name in self.model.components - ] - 
similarities[f"{metric_name}/all_layers"] = sum(values) / len(values) - - return similarities - - @override - def watch_batch( - self, - batch: Int[Tensor, "..."] | Float[Tensor, "..."], - target_out: Float[Tensor, "... vocab"], - ci: dict[str, Float[Tensor, "... C"]], - ) -> None: - n_tokens = next(iter(ci.values())).shape[:-1].numel() - self.n_tokens += n_tokens - - for module_name, ci_vals in ci.items(): - active_components = ci_vals > self.config.ci_alive_threshold - n_activations_per_component = reduce(active_components, "... C -> C", "sum") - self.component_activation_counts[module_name] += n_activations_per_component - - @override - def compute(self) -> Mapping[str, float]: - """Compute the geometric similarity metrics""" - - activation_densities = { - module_name: self.component_activation_counts[module_name] / self.n_tokens - for module_name in self.model.components - } - - try: - similarities = self._compute_subcomponent_geometric_similarities(activation_densities) - return similarities - except Exception as e: - logger.warning(f"Failed to compute geometric similarity comparison: {e}") - return {} - - EVAL_CLASSES = { cls.__name__: cls for cls in [ @@ -923,7 +768,6 @@ def compute(self) -> Mapping[str, float]: CIMeanPerComponent, SubsetReconstructionLoss, FaithfulnessLoss, - GeometricSimilarityComparison, ] } diff --git a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml b/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml deleted file mode 100644 index 7d6e49afb..000000000 --- a/spd/experiments/lm/ss_llama_single_with_comparison_config.yaml +++ /dev/null @@ -1,124 +0,0 @@ -# --- WandB --- -wandb_project: spd -wandb_run_name: null -wandb_run_name_prefix: "" - -# --- General --- -seed: 0 -C: 4000 -n_mask_samples: 1 -gate_type: "vector_mlp" -gate_hidden_dims: [12] -sigmoid_type: "leaky_hard" -target_module_patterns: ["model.layers.*.mlp.gate_proj", "model.layers.*.mlp.down_proj", "model.layers.*.mlp.up_proj", 
"model.layers.*.self_attn.q_proj", "model.layers.*.self_attn.k_proj", "model.layers.*.self_attn.v_proj", "model.layers.*.self_attn.o_proj"] -sampling: "binomial" - -# --- Loss Coefficients --- -faithfulness_coeff: 10000000.0 -recon_coeff: null -stochastic_recon_coeff: 1.0 -recon_layerwise_coeff: null -stochastic_recon_layerwise_coeff: 1.0 -importance_minimality_coeff: 0.0003 -schatten_coeff: null -out_recon_coeff: null -embedding_recon_coeff: null -is_embed_unembed_recon: false -pnorm: 2.0 -p_anneal_start_frac: 0.0 -p_anneal_final_p: 0.1 -p_anneal_end_frac: 1.0 -output_loss_type: kl - -# --- Training --- -batch_size: 12 -eval_batch_size: 12 -steps: 300000 -lr: 0.0005 -lr_schedule: cosine -lr_warmup_pct: 0.0 -lr_exponential_halflife: null -gradient_accumulation_steps: 4 - -# --- Logging & Saving --- -train_log_freq: 100 -eval_freq: 100 -slow_eval_freq: 100 -slow_eval_on_first_step: true -n_eval_steps: 5 -save_freq: null -ci_alive_threshold: 0.0 -n_examples_until_dead: 1368400 -eval_metrics: - - classname: "CIHistograms" - extra_init_kwargs: - n_batches_accum: 5 - - classname: "ComponentActivationDensity" - extra_init_kwargs: {} - - classname: "CI_L0" - extra_init_kwargs: - groups: - total: ["*"] # Sum of all L0 values - layer_0: ["model.layers.0.*"] - layer_1: ["model.layers.1.*"] - layer_2: ["model.layers.2.*"] - layer_3: ["model.layers.3.*"] - - classname: "CEandKLLosses" - extra_init_kwargs: - rounding_threshold: 0.0 - - classname: "SubsetReconstructionLoss" - extra_init_kwargs: - n_mask_samples: 1 - use_all_ones_for_non_replaced: false - include_patterns: - layer_0_only: ["model.layers.0.*"] - layer_1_only: ["model.layers.1.*"] - layer_2_only: ["model.layers.2.*"] - layer_3_only: ["model.layers.3.*"] - mlp_only: ["*.mlp.*"] - attention_only: ["*.self_attn.*"] - exclude_patterns: - all_but_layer_0: ["model.layers.0.*"] - all_but_layer_1: ["model.layers.1.*"] - all_but_layer_2: ["model.layers.2.*"] - all_but_layer_3: ["model.layers.3.*"] - - classname: 
"GeometricSimilarityComparison" - extra_init_kwargs: - reference_run_path: "wandb:goodfire/spd/runs/2js1ccon" - density_threshold: 0.001 - - -# --- Pretrained model info --- -pretrained_model_class: transformers.LlamaForCausalLM -pretrained_model_name: SimpleStories/SimpleStories-1.25M -pretrained_model_path: null -pretrained_model_output_attr: logits -tokenizer_name: SimpleStories/SimpleStories-1.25M - -# --- Task Specific --- -task_config: - task_name: lm - max_seq_len: 512 - buffer_size: 1000 - dataset_name: "SimpleStories/SimpleStories" - column_name: "story" - train_data_split: "train" - eval_data_split: "test" - shuffle_each_epoch: true - is_tokenized: false - streaming: false - - -# Config details for the target model taken from https://github.com/danbraunai/simple_stories_train/blob/main/simple_stories_train/models/model_configs.py#L54 - # "1.25M": LlamaConfig( - # block_size=512, - # vocab_size=4096, - # n_layer=4, - # n_head=4, - # n_embd=128, - # n_intermediate=128 * 4 * 2 // 3 = 341, - # rotary_dim=128 // 4 = 32, - # n_ctx=512, - # n_key_value_heads=2, - # flash_attention=True, - # ), diff --git a/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml b/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml deleted file mode 100644 index 4d601862a..000000000 --- a/spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml +++ /dev/null @@ -1,82 +0,0 @@ -# ResidualMLP 2 layers with Geometric Comparison -# --- WandB --- -wandb_project: spd -wandb_run_name: null -wandb_run_name_prefix: "" - -# --- General --- -seed: 0 -C: 400 -n_mask_samples: 1 -gate_type: "mlp" -gate_hidden_dims: [16] -sigmoid_type: "leaky_hard" -target_module_patterns: - - "layers.*.mlp_in" - - "layers.*.mlp_out" - -# --- Loss Coefficients --- -faithfulness_coeff: 1.0 -out_recon_coeff: 0.0 -recon_coeff: null -stochastic_recon_coeff: 1.0 -recon_layerwise_coeff: null -stochastic_recon_layerwise_coeff: 1.0 -importance_minimality_coeff: 1e-5 -pnorm: 2 
-output_loss_type: mse - -# --- Training --- -batch_size: 2048 -eval_batch_size: 2048 -steps: 50_000 -lr: 1e-3 -lr_schedule: constant -lr_warmup_pct: 0.00 - -# --- Logging & Saving --- -train_log_freq: 50 -eval_freq: 500 -n_eval_steps: 100 -slow_eval_freq: 5_000 -slow_eval_on_first_step: true -save_freq: null -ci_alive_threshold: 0.1 -n_examples_until_dead: 1_024_000 -eval_metrics: - - classname: "CIHistograms" - extra_init_kwargs: - n_batches_accum: 5 - - classname: "ComponentActivationDensity" - - classname: "PermutedCIPlots" - extra_init_kwargs: - identity_patterns: ["layers.*.mlp_in"] - dense_patterns: ["layers.*.mlp_out"] - - classname: "UVPlots" - extra_init_kwargs: - identity_patterns: ["layers.*.mlp_in"] - dense_patterns: ["layers.*.mlp_out"] - - classname: "IdentityCIError" - extra_init_kwargs: - identity_ci: - - layer_pattern: "layers.*.mlp_in" - n_features: 100 - dense_ci: - - layer_pattern: "layers.*.mlp_out" - k: 25 - - classname: "CI_L0" - - classname: "CIMeanPerComponent" - - classname: "GeometricSimilarityComparison" - extra_init_kwargs: - reference_run_path: "wandb:goodfire/spd/runs/nr085xlx" - density_threshold: 0.001 - -# --- Pretrained model info --- -pretrained_model_class: "spd.experiments.resid_mlp.models.ResidMLP" -pretrained_model_path: "wandb:goodfire/spd/runs/any9ekl9" - -# --- Task Specific --- -task_config: - task_name: resid_mlp - feature_probability: 0.01 - data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml b/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml deleted file mode 100644 index 6a5c93456..000000000 --- a/spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml +++ /dev/null @@ -1,81 +0,0 @@ -# TMS 5-2 w/ fixed identity with Geometric Comparison -# --- WandB --- -wandb_project: spd -wandb_run_name: null -wandb_run_name_prefix: "" - -# --- General --- -seed: 0 -C: 20 -n_mask_samples: 1 -gate_type: "mlp" -gate_hidden_dims: [16] -sigmoid_type: 
"leaky_hard" -target_module_patterns: ["linear1", "linear2", "hidden_layers.0"] - -# --- Loss Coefficients --- -faithfulness_coeff: 1.0 -recon_coeff: null -stochastic_recon_coeff: 1 -recon_layerwise_coeff: null -stochastic_recon_layerwise_coeff: 1.0 -importance_minimality_coeff: 3e-3 -pnorm: 1.0 -output_loss_type: mse - -# --- Training --- -batch_size: 4096 -eval_batch_size: 4096 -steps: 40_000 -lr: 1e-3 -lr_schedule: cosine -lr_warmup_pct: 0.0 - -# --- Logging & Saving --- -train_log_freq: 100 -eval_freq: 1000 -n_eval_steps: 100 -slow_eval_freq: 5_000 -slow_eval_on_first_step: true -save_freq: null -ci_alive_threshold: 0.1 -n_examples_until_dead: 4_096_000 -eval_metrics: - - classname: "CIHistograms" - extra_init_kwargs: - n_batches_accum: 5 - - classname: "ComponentActivationDensity" - - classname: "PermutedCIPlots" - extra_init_kwargs: - identity_patterns: ["linear1", "linear2"] - dense_patterns: ["hidden_layers.0"] - - classname: "UVPlots" - extra_init_kwargs: - identity_patterns: ["linear1", "linear2"] - dense_patterns: ["hidden_layers.0"] - - classname: "IdentityCIError" - extra_init_kwargs: - identity_ci: - - layer_pattern: "linear1" - n_features: 5 - - layer_pattern: "linear2" - n_features: 5 - dense_ci: - - layer_pattern: "hidden_layers.0" - k: 2 - - classname: "CI_L0" - - classname: "CIMeanPerComponent" - - classname: "GeometricSimilarityComparison" - extra_init_kwargs: - reference_run_path: "wandb:goodfire/spd/runs/swr68dli" - density_threshold: 0.001 - -# --- Pretrained model info --- -pretrained_model_class: "spd.experiments.tms.models.TMSModel" -pretrained_model_path: "wandb:goodfire/spd/runs/gfgchmxj" # 1 hidden w/fixed identity - -# --- Task Specific --- -task_config: - task_name: tms - feature_probability: 0.05 - data_generation_type: "at_least_zero_active" diff --git a/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml b/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml deleted file mode 100644 index 78e040791..000000000 --- 
a/spd/experiments/tms/tms_5-2_geom_comparison_config.yaml +++ /dev/null @@ -1,75 +0,0 @@ -# TMS 5-2 (Non-identity) with Geometric Comparison -# --- WandB --- -wandb_project: spd -wandb_run_name: null -wandb_run_name_prefix: "" - -# --- General --- -seed: 0 -C: 20 -n_mask_samples: 1 -gate_type: "mlp" -gate_hidden_dims: [16] -sigmoid_type: "leaky_hard" -target_module_patterns: ["linear1", "linear2"] - -# --- Loss Coefficients --- -faithfulness_coeff: 1.0 -recon_coeff: null -stochastic_recon_coeff: 1 -recon_layerwise_coeff: null -stochastic_recon_layerwise_coeff: 1.0 -importance_minimality_coeff: 3e-3 -pnorm: 1.0 -output_loss_type: mse - -# --- Training --- -batch_size: 4096 -eval_batch_size: 4096 -steps: 40_000 -lr: 1e-3 -lr_schedule: cosine -lr_warmup_pct: 0.0 - -# --- Logging & Saving --- -train_log_freq: 100 -eval_freq: 1000 -n_eval_steps: 100 -slow_eval_freq: 5_000 -slow_eval_on_first_step: true -save_freq: null -ci_alive_threshold: 0.1 -n_examples_until_dead: 4_096_000 -eval_metrics: - - classname: "CIHistograms" - extra_init_kwargs: - n_batches_accum: 5 - - classname: "ComponentActivationDensity" - - classname: "PermutedCIPlots" - extra_init_kwargs: - identity_patterns: ["linear1", "linear2"] - - classname: "UVPlots" - extra_init_kwargs: - identity_patterns: ["linear1", "linear2"] - - classname: "IdentityCIError" - extra_init_kwargs: - identity_ci: - - layer_pattern: "linear1" - n_features: 5 - - layer_pattern: "linear2" - n_features: 5 - - classname: "CI_L0" - - classname: "CIMeanPerComponent" - - classname: "GeometricSimilarityComparison" - extra_init_kwargs: - reference_run_path: "wandb:goodfire/spd/runs/7ngt0c8d" - density_threshold: 0.001 -# --- Pretrained model info --- -pretrained_model_class: "spd.experiments.tms.models.TMSModel" -pretrained_model_path: "wandb:goodfire/spd/runs/0hsp07o4" - -# --- Task Specific --- -task_config: - task_name: tms - feature_probability: 0.05 - data_generation_type: "at_least_zero_active" diff --git 
a/spd/scripts/compare_models.py b/spd/scripts/compare_models.py new file mode 100644 index 000000000..567f62067 --- /dev/null +++ b/spd/scripts/compare_models.py @@ -0,0 +1,393 @@ +"""Model comparison script for geometric similarity analysis. + +This script compares two SPD models by computing geometric similarities between +their learned subcomponents. It's designed for post-hoc analysis of completed runs. + +Usage: + python spd/scripts/compare_models.py --config spd/scripts/compare_models_config.yaml +""" + +import argparse +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +import einops +import torch +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor + +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import extract_batch_data, load_config +from spd.utils.run_utils import save_file + + +class ModelComparator: + """Compare two SPD models for geometric similarity between subcomponents.""" + + def __init__( + self, + current_model_path: str, + reference_model_path: str, + density_threshold: float = 0.0, + device: str = "auto", + comparison_config: dict[str, Any] | None = None, + ): + """Initialize the model comparator. 
+ + Args: + current_model_path: Path to current model (wandb: or local path) + reference_model_path: Path to reference model (wandb: or local path) + density_threshold: Minimum activation density for components to be included + device: Device to run comparison on ("auto", "cuda", "cpu") + comparison_config: Full comparison configuration dict + """ + self.current_model_path = current_model_path + self.reference_model_path = reference_model_path + self.density_threshold = density_threshold + self.comparison_config = comparison_config or {} + + if device == "auto": + self.device = get_device() + else: + self.device = device + + logger.info(f"Loading current model from: {current_model_path}") + self.current_model, self.current_config = self._load_model_and_config(current_model_path) + + logger.info(f"Loading reference model from: {reference_model_path}") + self.reference_model, self.reference_config = self._load_model_and_config( + reference_model_path + ) + + def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, dict[str, Any]]: + """Load model and config using the standard pattern from existing codebase.""" + run_info = SPDRunInfo.from_path(model_path) + model = ComponentModel.from_run_info(run_info) + model.to(self.device) + model.eval() + model.requires_grad_(False) + + config_dict = run_info.config.model_dump() + + return model, config_dict + + def create_eval_data_loader(self, config: dict[str, Any]) -> Iterator[Any]: + """Create evaluation data loader using exact same patterns as decomposition scripts.""" + task_config = config.get("task_config", {}) + task_name = task_config.get("task_name") + + if not task_name: + raise ValueError("task_config.task_name must be set") + + if task_name == "tms": + from spd.experiments.tms.models import TMSTargetRunInfo + from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset + + if "pretrained_model_path" not in config or not config["pretrained_model_path"]: + raise 
ValueError("pretrained_model_path must be set for TMS models") + + target_run_info = TMSTargetRunInfo.from_path(config["pretrained_model_path"]) + n_features = target_run_info.config.tms_model_config.n_features + synced_inputs = target_run_info.config.synced_inputs + + dataset = SparseFeatureDataset( + n_features=n_features, + feature_probability=task_config["feature_probability"], + device=self.device, + data_generation_type=task_config["data_generation_type"], + value_range=(0.0, 1.0), + synced_inputs=synced_inputs, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.comparison_config.get( + "eval_batch_size", 1 + ), # TODO get rid of 'get' pattern + shuffle=self.comparison_config.get("shuffle_data", False), + ) + ) + + elif task_name == "resid_mlp": + from spd.experiments.resid_mlp.models import ResidMLPTargetRunInfo + from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + if "pretrained_model_path" not in config or not config["pretrained_model_path"]: + raise ValueError("pretrained_model_path must be set for ResidMLP models") + + target_run_info = ResidMLPTargetRunInfo.from_path(config["pretrained_model_path"]) + n_features = target_run_info.config.resid_mlp_model_config.n_features + synced_inputs = target_run_info.config.synced_inputs + + dataset = ResidMLPDataset( + n_features=n_features, + feature_probability=task_config["feature_probability"], + device=self.device, + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=None, + synced_inputs=synced_inputs, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.comparison_config.get("eval_batch_size", 1), + shuffle=self.comparison_config.get("shuffle_data", False), + ) + ) + + elif task_name == "lm": + from spd.data import DatasetConfig, create_data_loader + + if "tokenizer_name" not in config or not config["tokenizer_name"]: + raise 
ValueError("tokenizer_name must be set for language models") + + dataset_config = DatasetConfig( + name=task_config["dataset_name"], + hf_tokenizer_path=config["tokenizer_name"], + split=task_config["eval_data_split"], + n_ctx=task_config["max_seq_len"], + is_tokenized=task_config["is_tokenized"], + streaming=task_config["streaming"], + column_name=task_config["column_name"], + shuffle_each_epoch=task_config["shuffle_each_epoch"], + seed=None, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=self.comparison_config.get("eval_batch_size", 1), + buffer_size=task_config["buffer_size"], + global_seed=config["seed"] + 1, + ddp_rank=0, + ddp_world_size=1, + ) + return iter(loader) + + elif task_name == "ih": + from spd.experiments.ih.model import InductionModelTargetRunInfo + from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset + + if "pretrained_model_path" not in config or not config["pretrained_model_path"]: + raise ValueError("pretrained_model_path must be set for Induction Heads models") + + target_run_info = InductionModelTargetRunInfo.from_path(config["pretrained_model_path"]) + vocab_size = target_run_info.config.ih_model_config.vocab_size + seq_len = target_run_info.config.ih_model_config.seq_len + prefix_window = task_config.get("prefix_window") or seq_len - 3 + + dataset = InductionDataset( + vocab_size=vocab_size, + seq_len=seq_len, + prefix_window=prefix_window, + device=self.device, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.comparison_config.get("eval_batch_size", 1), + shuffle=self.comparison_config.get("shuffle_data", False), + ) + ) + + raise ValueError( + f"Unsupported task type: {task_name}. 
Supported types: tms, lm, resid_mlp, ih" + ) + + def compute_activation_densities( + self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int = 5 + ) -> dict[str, Float[Tensor, " C"]]: + """Compute activation densities using same logic as ComponentActivationDensity.""" + # Get config for this model + config_dict = self.current_config if model is self.current_model else self.reference_config + ci_alive_threshold = self.comparison_config.get("ci_alive_threshold", 0.0) + + device = next(iter(model.parameters())).device + n_tokens = 0 + component_activation_counts: dict[str, Float[Tensor, " C"]] = { + module_name: torch.zeros(model.C, device=device) for module_name in model.components + } + + model.eval() + with torch.no_grad(): + for _step in range(n_steps): + batch = extract_batch_data(next(eval_iterator)) + batch = batch.to(self.device) + _, pre_weight_acts = model( + batch, mode="pre_forward_cache", module_names=list(model.components.keys()) + ) + ci, _ci_upper_leaky = model.calc_causal_importances( + pre_weight_acts, + sigmoid_type=config_dict["sigmoid_type"], + sampling=config_dict["sampling"], + ) + + n_tokens_batch = next(iter(ci.values())).shape[:-1].numel() + n_tokens += n_tokens_batch + + for module_name, ci_vals in ci.items(): + active_components = ci_vals > ci_alive_threshold + n_activations_per_component = einops.reduce( + active_components, "... 
C -> C", "sum" + ) + component_activation_counts[module_name] += n_activations_per_component + + densities = { + module_name: component_activation_counts[module_name] / n_tokens + for module_name in model.components + } + + return densities + + def compute_geometric_similarities( + self, activation_densities: dict[str, Float[Tensor, " C"]] + ) -> dict[str, float]: + """Compute geometric similarities between subcomponents.""" + similarities = {} + + for layer_name in self.current_model.components: + if layer_name not in self.reference_model.components: + logger.warning(f"Layer {layer_name} not found in reference model, skipping") + continue + + current_components = self.current_model.components[layer_name] + reference_components = self.reference_model.components[layer_name] + + if current_components.C != reference_components.C: + logger.warning( + f"Component count mismatch for {layer_name}: {current_components.C} vs {reference_components.C}" + ) + continue + + # Extract U and V matrices + C = current_components.C + current_U = current_components.U # Shape: [C, d_out] + current_V = current_components.V # Shape: [d_in, C] + ref_U = reference_components.U + ref_V = reference_components.V + + # Filter out components that aren't active enough in the current model + C_alive = sum(activation_densities[layer_name] > self.density_threshold) + if C_alive == 0: + logger.warning( + f"No components are active enough in {layer_name} for density threshold {self.density_threshold}. Skipping." 
+ ) + continue + + current_U_alive = current_U[activation_densities[layer_name] > self.density_threshold] + current_V_alive = current_V[ + :, activation_densities[layer_name] > self.density_threshold + ] + + # Compute rank-one matrices: V @ U for each component + current_rank_one = einops.einsum( + current_V_alive, + current_U_alive, + "d_in C_alive, C_alive d_out -> C_alive d_in d_out", + ) + ref_rank_one = einops.einsum(ref_V, ref_U, "d_in C, C d_out -> C d_in d_out") + + # Compute cosine similarities between all pairs + current_flat = current_rank_one.reshape(int(C_alive.item()), -1) + ref_flat = ref_rank_one.reshape(C, -1) + + current_norm = F.normalize(current_flat, p=2, dim=1) + ref_norm = F.normalize(ref_flat, p=2, dim=1) + + cosine_sim_matrix = einops.einsum( + current_norm, ref_norm, "C_alive d_in_d_out, C_ref d_in_d_out -> C_alive C_ref" + ) + cosine_sim_matrix = cosine_sim_matrix.abs() + + max_similarities = cosine_sim_matrix.max(dim=1).values + similarities[f"mean_max_abs_cosine_sim/{layer_name}"] = max_similarities.mean().item() + similarities[f"max_abs_cosine_sim_std/{layer_name}"] = max_similarities.std().item() + similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() + similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() + + metric_names = [ + "mean_max_abs_cosine_sim", + "max_abs_cosine_sim_std", + "max_abs_cosine_sim_min", + "max_abs_cosine_sim_max", + ] + + for metric_name in metric_names: + values = [ + similarities[f"{metric_name}/{layer_name}"] + for layer_name in self.current_model.components + if f"{metric_name}/{layer_name}" in similarities + ] + if values: + similarities[f"{metric_name}/all_layers"] = sum(values) / len(values) + + return similarities + + def run_comparison( + self, eval_iterator: Iterator[Any], n_steps: int | None = None + ) -> dict[str, float]: + """Run the full comparison pipeline.""" + if n_steps is None: + n_steps = self.comparison_config.get("n_eval_steps", 
5) + assert isinstance(n_steps, int) # Ensure n_steps is an int for type checking + + logger.info("Computing activation densities for current model...") + activation_densities = self.compute_activation_densities( + self.current_model, eval_iterator, n_steps + ) + + logger.info("Computing geometric similarities...") + similarities = self.compute_geometric_similarities(activation_densities) + + return similarities + + +def main(): + """Main execution function.""" + parser = argparse.ArgumentParser(description="Compare two SPD models for geometric similarity") + parser.add_argument( + "--config", + type=str, + default="spd/scripts/compare_models_config.yaml", + help="Path to configuration file", + ) + args = parser.parse_args() + + config = load_config(args.config, dict) + current_model_path = config["current_model_path"] + reference_model_path = config["reference_model_path"] + density_threshold = config.get("density_threshold", 0.0) + device = config.get("device", "auto") + output_dir = Path(config.get("output_dir", "./comparison_results")) + output_dir.mkdir(parents=True, exist_ok=True) + + comparator = ModelComparator( + current_model_path=current_model_path, + reference_model_path=reference_model_path, + density_threshold=density_threshold, + device=device, + comparison_config=config, + ) + + logger.info("Setting up evaluation data...") + eval_iterator = comparator.create_eval_data_loader(comparator.current_config) + + logger.info("Starting model comparison...") + similarities = comparator.run_comparison(eval_iterator) + + results_file = output_dir / "similarity_results.json" + save_file(similarities, results_file) + + logger.info(f"Comparison complete! 
Results saved to {results_file}") + logger.info("Similarity metrics:") + for key, value in similarities.items(): + logger.info(f" {key}: {value:.4f}") + + +if __name__ == "__main__": + main() diff --git a/spd/scripts/compare_models_config.yaml b/spd/scripts/compare_models_config.yaml new file mode 100644 index 000000000..7dad04b78 --- /dev/null +++ b/spd/scripts/compare_models_config.yaml @@ -0,0 +1,20 @@ +# Configuration for compare_models.py + +# Model paths (supports both wandb: and local paths) +current_model_path: "wandb:goodfire/spd/runs/b5qe6t98" +reference_model_path: "wandb:goodfire/spd/runs/s2b158g1" + +# Analysis parameters +density_threshold: 0.001 # Minimum activation density for components to be included in comparison +n_eval_steps: 5 # Number of evaluation steps to compute activation densities + +# Data loading parameters +eval_batch_size: 32 # Batch size for evaluation data loading +shuffle_data: false # Whether to shuffle the evaluation data +ci_alive_threshold: 0.0 # Threshold for considering components as "alive" + +# Output settings +output_dir: "./comparison_results" # Directory to save results + +# Device settings +device: "auto" # Options: "auto", "cuda", "cpu" From 98a66207977105fa0bfbb6086731cc9ed2f9aa5c Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Thu, 18 Sep 2025 16:32:13 +0000 Subject: [PATCH 08/18] Updated registry to delete old obsolete experiments --- spd/registry.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/spd/registry.py b/spd/registry.py index 94a28295b..d7b0b189e 100644 --- a/spd/registry.py +++ b/spd/registry.py @@ -116,30 +116,6 @@ class ExperimentConfig: config_path=Path("spd/experiments/lm/ss_gpt2_simple_noln_config.yaml"), expected_runtime=330, ), - "ss_llama_single_with_comparison": ExperimentConfig( - task_name="lm", - decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), - config_path=Path("spd/experiments/lm/ss_llama_single_with_comparison_config.yaml"), - expected_runtime=60
* 94, # Same as ss_llama_single - ), - "tms_5-2_geom_comparison": ExperimentConfig( - task_name="tms", - decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), - config_path=Path("spd/experiments/tms/tms_5-2_geom_comparison_config.yaml"), - expected_runtime=4, # Same as tms_5-2 - ), - "tms_5-2-id_geom_comparison": ExperimentConfig( - task_name="tms", - decomp_script=Path("spd/experiments/tms/tms_decomposition.py"), - config_path=Path("spd/experiments/tms/tms_5-2-id_geom_comparison_config.yaml"), - expected_runtime=4, # Same as tms_5-2-id - ), - "resid_mlp2_geom_comparison": ExperimentConfig( - task_name="resid_mlp", - decomp_script=Path("spd/experiments/resid_mlp/resid_mlp_decomposition.py"), - config_path=Path("spd/experiments/resid_mlp/resid_mlp2_geom_comparison_config.yaml"), - expected_runtime=5, # Same as resid_mlp2 - ), # "ss_emb": ExperimentConfig( # task_name="lm", # decomp_script=Path("spd/experiments/lm/lm_decomposition.py"), From 62bd77e3c1eea6a6635ce0ade7debbb14a3f6359 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Mon, 22 Sep 2025 13:45:08 +0000 Subject: [PATCH 09/18] Reorganized compare_models into subdirectory and cleaned up config code --- spd/configs.py | 13 - spd/scripts/compare_models.py | 393 ---------------- spd/scripts/compare_models/README.md | 23 + spd/scripts/compare_models/compare_models.py | 418 ++++++++++++++++++ .../compare_models_config.yaml | 9 +- 5 files changed, 445 insertions(+), 411 deletions(-) delete mode 100644 spd/scripts/compare_models.py create mode 100644 spd/scripts/compare_models/README.md create mode 100644 spd/scripts/compare_models/compare_models.py rename spd/scripts/{ => compare_models}/compare_models_config.yaml (71%) diff --git a/spd/configs.py b/spd/configs.py index 0bfc6d08e..05bf0ebc9 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -325,7 +325,6 @@ def microbatch_size(self) -> PositiveInt: RENAMED_CONFIG_KEYS: ClassVar[dict[str, str]] = { "print_freq": "eval_freq", "pretrained_model_name_hf": 
"pretrained_model_name", - "output_recon_loss_type": "output_loss_type", } @model_validator(mode="before") @@ -348,18 +347,6 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, if "slow_eval_freq" not in config_dict: config_dict["slow_eval_freq"] = config_dict["eval_freq"] - # Add backward compatibility for older model checkpoints - if "output_loss_type" not in config_dict: - config_dict["output_loss_type"] = "kl" # Default value - logger.info("Added missing output_loss_type field with default value 'kl'") - - # Remove forbidden fields from older configs (fields that are not part of current Config schema) - forbidden_fields = ["hidden_act_recon_coeff"] - for field in forbidden_fields: - if field in config_dict: - logger.info(f"Removing forbidden field: {field}") - del config_dict[field] - return config_dict @model_validator(mode="after") diff --git a/spd/scripts/compare_models.py b/spd/scripts/compare_models.py deleted file mode 100644 index 567f62067..000000000 --- a/spd/scripts/compare_models.py +++ /dev/null @@ -1,393 +0,0 @@ -"""Model comparison script for geometric similarity analysis. - -This script compares two SPD models by computing geometric similarities between -their learned subcomponents. It's designed for post-hoc analysis of completed runs. 
- -Usage: - python spd/scripts/compare_models.py --config spd/scripts/compare_models_config.yaml -""" - -import argparse -from collections.abc import Iterator -from pathlib import Path -from typing import Any - -import einops -import torch -import torch.nn.functional as F -from jaxtyping import Float -from torch import Tensor - -from spd.log import logger -from spd.models.component_model import ComponentModel, SPDRunInfo -from spd.utils.distributed_utils import get_device -from spd.utils.general_utils import extract_batch_data, load_config -from spd.utils.run_utils import save_file - - -class ModelComparator: - """Compare two SPD models for geometric similarity between subcomponents.""" - - def __init__( - self, - current_model_path: str, - reference_model_path: str, - density_threshold: float = 0.0, - device: str = "auto", - comparison_config: dict[str, Any] | None = None, - ): - """Initialize the model comparator. - - Args: - current_model_path: Path to current model (wandb: or local path) - reference_model_path: Path to reference model (wandb: or local path) - density_threshold: Minimum activation density for components to be included - device: Device to run comparison on ("auto", "cuda", "cpu") - comparison_config: Full comparison configuration dict - """ - self.current_model_path = current_model_path - self.reference_model_path = reference_model_path - self.density_threshold = density_threshold - self.comparison_config = comparison_config or {} - - if device == "auto": - self.device = get_device() - else: - self.device = device - - logger.info(f"Loading current model from: {current_model_path}") - self.current_model, self.current_config = self._load_model_and_config(current_model_path) - - logger.info(f"Loading reference model from: {reference_model_path}") - self.reference_model, self.reference_config = self._load_model_and_config( - reference_model_path - ) - - def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, dict[str, Any]]: - 
"""Load model and config using the standard pattern from existing codebase.""" - run_info = SPDRunInfo.from_path(model_path) - model = ComponentModel.from_run_info(run_info) - model.to(self.device) - model.eval() - model.requires_grad_(False) - - config_dict = run_info.config.model_dump() - - return model, config_dict - - def create_eval_data_loader(self, config: dict[str, Any]) -> Iterator[Any]: - """Create evaluation data loader using exact same patterns as decomposition scripts.""" - task_config = config.get("task_config", {}) - task_name = task_config.get("task_name") - - if not task_name: - raise ValueError("task_config.task_name must be set") - - if task_name == "tms": - from spd.experiments.tms.models import TMSTargetRunInfo - from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset - - if "pretrained_model_path" not in config or not config["pretrained_model_path"]: - raise ValueError("pretrained_model_path must be set for TMS models") - - target_run_info = TMSTargetRunInfo.from_path(config["pretrained_model_path"]) - n_features = target_run_info.config.tms_model_config.n_features - synced_inputs = target_run_info.config.synced_inputs - - dataset = SparseFeatureDataset( - n_features=n_features, - feature_probability=task_config["feature_probability"], - device=self.device, - data_generation_type=task_config["data_generation_type"], - value_range=(0.0, 1.0), - synced_inputs=synced_inputs, - ) - return iter( - DatasetGeneratedDataLoader( - dataset, - batch_size=self.comparison_config.get( - "eval_batch_size", 1 - ), # TODO get rid of 'get' pattern - shuffle=self.comparison_config.get("shuffle_data", False), - ) - ) - - elif task_name == "resid_mlp": - from spd.experiments.resid_mlp.models import ResidMLPTargetRunInfo - from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset - from spd.utils.data_utils import DatasetGeneratedDataLoader - - if "pretrained_model_path" not in config or not config["pretrained_model_path"]: - 
raise ValueError("pretrained_model_path must be set for ResidMLP models") - - target_run_info = ResidMLPTargetRunInfo.from_path(config["pretrained_model_path"]) - n_features = target_run_info.config.resid_mlp_model_config.n_features - synced_inputs = target_run_info.config.synced_inputs - - dataset = ResidMLPDataset( - n_features=n_features, - feature_probability=task_config["feature_probability"], - device=self.device, - calc_labels=False, - label_type=None, - act_fn_name=None, - label_fn_seed=None, - synced_inputs=synced_inputs, - ) - return iter( - DatasetGeneratedDataLoader( - dataset, - batch_size=self.comparison_config.get("eval_batch_size", 1), - shuffle=self.comparison_config.get("shuffle_data", False), - ) - ) - - elif task_name == "lm": - from spd.data import DatasetConfig, create_data_loader - - if "tokenizer_name" not in config or not config["tokenizer_name"]: - raise ValueError("tokenizer_name must be set for language models") - - dataset_config = DatasetConfig( - name=task_config["dataset_name"], - hf_tokenizer_path=config["tokenizer_name"], - split=task_config["eval_data_split"], - n_ctx=task_config["max_seq_len"], - is_tokenized=task_config["is_tokenized"], - streaming=task_config["streaming"], - column_name=task_config["column_name"], - shuffle_each_epoch=task_config["shuffle_each_epoch"], - seed=None, - ) - loader, _ = create_data_loader( - dataset_config=dataset_config, - batch_size=self.comparison_config.get("eval_batch_size", 1), - buffer_size=task_config["buffer_size"], - global_seed=config["seed"] + 1, - ddp_rank=0, - ddp_world_size=1, - ) - return iter(loader) - - elif task_name == "ih": - from spd.experiments.ih.model import InductionModelTargetRunInfo - from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset - - if "pretrained_model_path" not in config or not config["pretrained_model_path"]: - raise ValueError("pretrained_model_path must be set for Induction Heads models") - - target_run_info = 
InductionModelTargetRunInfo.from_path(config["pretrained_model_path"]) - vocab_size = target_run_info.config.ih_model_config.vocab_size - seq_len = target_run_info.config.ih_model_config.seq_len - prefix_window = task_config.get("prefix_window") or seq_len - 3 - - dataset = InductionDataset( - vocab_size=vocab_size, - seq_len=seq_len, - prefix_window=prefix_window, - device=self.device, - ) - return iter( - DatasetGeneratedDataLoader( - dataset, - batch_size=self.comparison_config.get("eval_batch_size", 1), - shuffle=self.comparison_config.get("shuffle_data", False), - ) - ) - - raise ValueError( - f"Unsupported task type: {task_name}. Supported types: tms, lm, resid_mlp, ih" - ) - - def compute_activation_densities( - self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int = 5 - ) -> dict[str, Float[Tensor, " C"]]: - """Compute activation densities using same logic as ComponentActivationDensity.""" - # Get config for this model - config_dict = self.current_config if model is self.current_model else self.reference_config - ci_alive_threshold = self.comparison_config.get("ci_alive_threshold", 0.0) - - device = next(iter(model.parameters())).device - n_tokens = 0 - component_activation_counts: dict[str, Float[Tensor, " C"]] = { - module_name: torch.zeros(model.C, device=device) for module_name in model.components - } - - model.eval() - with torch.no_grad(): - for _step in range(n_steps): - batch = extract_batch_data(next(eval_iterator)) - batch = batch.to(self.device) - _, pre_weight_acts = model( - batch, mode="pre_forward_cache", module_names=list(model.components.keys()) - ) - ci, _ci_upper_leaky = model.calc_causal_importances( - pre_weight_acts, - sigmoid_type=config_dict["sigmoid_type"], - sampling=config_dict["sampling"], - ) - - n_tokens_batch = next(iter(ci.values())).shape[:-1].numel() - n_tokens += n_tokens_batch - - for module_name, ci_vals in ci.items(): - active_components = ci_vals > ci_alive_threshold - n_activations_per_component = 
einops.reduce( - active_components, "... C -> C", "sum" - ) - component_activation_counts[module_name] += n_activations_per_component - - densities = { - module_name: component_activation_counts[module_name] / n_tokens - for module_name in model.components - } - - return densities - - def compute_geometric_similarities( - self, activation_densities: dict[str, Float[Tensor, " C"]] - ) -> dict[str, float]: - """Compute geometric similarities between subcomponents.""" - similarities = {} - - for layer_name in self.current_model.components: - if layer_name not in self.reference_model.components: - logger.warning(f"Layer {layer_name} not found in reference model, skipping") - continue - - current_components = self.current_model.components[layer_name] - reference_components = self.reference_model.components[layer_name] - - if current_components.C != reference_components.C: - logger.warning( - f"Component count mismatch for {layer_name}: {current_components.C} vs {reference_components.C}" - ) - continue - - # Extract U and V matrices - C = current_components.C - current_U = current_components.U # Shape: [C, d_out] - current_V = current_components.V # Shape: [d_in, C] - ref_U = reference_components.U - ref_V = reference_components.V - - # Filter out components that aren't active enough in the current model - C_alive = sum(activation_densities[layer_name] > self.density_threshold) - if C_alive == 0: - logger.warning( - f"No components are active enough in {layer_name} for density threshold {self.density_threshold}. Skipping." 
- ) - continue - - current_U_alive = current_U[activation_densities[layer_name] > self.density_threshold] - current_V_alive = current_V[ - :, activation_densities[layer_name] > self.density_threshold - ] - - # Compute rank-one matrices: V @ U for each component - current_rank_one = einops.einsum( - current_V_alive, - current_U_alive, - "d_in C_alive, C_alive d_out -> C_alive d_in d_out", - ) - ref_rank_one = einops.einsum(ref_V, ref_U, "d_in C, C d_out -> C d_in d_out") - - # Compute cosine similarities between all pairs - current_flat = current_rank_one.reshape(int(C_alive.item()), -1) - ref_flat = ref_rank_one.reshape(C, -1) - - current_norm = F.normalize(current_flat, p=2, dim=1) - ref_norm = F.normalize(ref_flat, p=2, dim=1) - - cosine_sim_matrix = einops.einsum( - current_norm, ref_norm, "C_alive d_in_d_out, C_ref d_in_d_out -> C_alive C_ref" - ) - cosine_sim_matrix = cosine_sim_matrix.abs() - - max_similarities = cosine_sim_matrix.max(dim=1).values - similarities[f"mean_max_abs_cosine_sim/{layer_name}"] = max_similarities.mean().item() - similarities[f"max_abs_cosine_sim_std/{layer_name}"] = max_similarities.std().item() - similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() - similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() - - metric_names = [ - "mean_max_abs_cosine_sim", - "max_abs_cosine_sim_std", - "max_abs_cosine_sim_min", - "max_abs_cosine_sim_max", - ] - - for metric_name in metric_names: - values = [ - similarities[f"{metric_name}/{layer_name}"] - for layer_name in self.current_model.components - if f"{metric_name}/{layer_name}" in similarities - ] - if values: - similarities[f"{metric_name}/all_layers"] = sum(values) / len(values) - - return similarities - - def run_comparison( - self, eval_iterator: Iterator[Any], n_steps: int | None = None - ) -> dict[str, float]: - """Run the full comparison pipeline.""" - if n_steps is None: - n_steps = self.comparison_config.get("n_eval_steps", 
5) - assert isinstance(n_steps, int) # Ensure n_steps is an int for type checking - - logger.info("Computing activation densities for current model...") - activation_densities = self.compute_activation_densities( - self.current_model, eval_iterator, n_steps - ) - - logger.info("Computing geometric similarities...") - similarities = self.compute_geometric_similarities(activation_densities) - - return similarities - - -def main(): - """Main execution function.""" - parser = argparse.ArgumentParser(description="Compare two SPD models for geometric similarity") - parser.add_argument( - "--config", - type=str, - default="spd/scripts/compare_models_config.yaml", - help="Path to configuration file", - ) - args = parser.parse_args() - - config = load_config(args.config, dict) - current_model_path = config["current_model_path"] - reference_model_path = config["reference_model_path"] - density_threshold = config.get("density_threshold", 0.0) - device = config.get("device", "auto") - output_dir = Path(config.get("output_dir", "./comparison_results")) - output_dir.mkdir(parents=True, exist_ok=True) - - comparator = ModelComparator( - current_model_path=current_model_path, - reference_model_path=reference_model_path, - density_threshold=density_threshold, - device=device, - comparison_config=config, - ) - - logger.info("Setting up evaluation data...") - eval_iterator = comparator.create_eval_data_loader(comparator.current_config) - - logger.info("Starting model comparison...") - similarities = comparator.run_comparison(eval_iterator) - - results_file = output_dir / "similarity_results.json" - save_file(similarities, results_file) - - logger.info(f"Comparison complete! 
Results saved to {results_file}") - logger.info("Similarity metrics:") - for key, value in similarities.items(): - logger.info(f" {key}: {value:.4f}") - - -if __name__ == "__main__": - main() diff --git a/spd/scripts/compare_models/README.md b/spd/scripts/compare_models/README.md new file mode 100644 index 000000000..ffae97831 --- /dev/null +++ b/spd/scripts/compare_models/README.md @@ -0,0 +1,23 @@ +# Model Comparison Script + +This directory contains the model comparison script for geometric similarity analysis. + +## Files + +- `compare_models.py` - Main script for comparing two SPD models +- `compare_models_config.yaml` - Default configuration file +- `out/` - Output directory (created automatically when script runs) + +## Usage + +```bash +# Using config file +python spd/scripts/compare_models/compare_models.py spd/scripts/compare_models/compare_models_config.yaml + +# Using command line arguments +python spd/scripts/compare_models/compare_models.py --current_model_path="wandb:..." --reference_model_path="wandb:..." +``` + +## Output + +Results are saved to the `out/` directory relative to this script's location, ensuring consistent output placement regardless of where the script is invoked from. diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py new file mode 100644 index 000000000..bbe0f7d8b --- /dev/null +++ b/spd/scripts/compare_models/compare_models.py @@ -0,0 +1,418 @@ +"""Model comparison script for geometric similarity analysis. + +This script compares two SPD models by computing geometric similarities between +their learned subcomponents. It's designed for post-hoc analysis of completed runs. + +Usage: + python spd/scripts/compare_models/compare_models.py spd/scripts/compare_models/compare_models_config.yaml + python spd/scripts/compare_models/compare_models.py --current_model_path="wandb:..." --reference_model_path="wandb:..." 
+""" + +from collections.abc import Iterator +from pathlib import Path +from typing import Any + +import einops +import fire +import torch +import torch.nn.functional as F +from jaxtyping import Float +from pydantic import BaseModel, Field +from torch import Tensor + +from spd.configs import Config +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import extract_batch_data, load_config +from spd.utils.run_utils import save_file + + +class CompareModelsConfig(BaseModel): + """Configuration for model comparison script.""" + + current_model_path: str = Field(..., description="Path to current model (wandb: or local path)") + reference_model_path: str = Field( + ..., description="Path to reference model (wandb: or local path)" + ) + + density_threshold: float = Field( + default=0.001, + description="Minimum activation density for components to be included in comparison", + ) + n_eval_steps: int = Field( + default=5, description="Number of evaluation steps to compute activation densities" + ) + + eval_batch_size: int = Field(default=32, description="Batch size for evaluation data loading") + shuffle_data: bool = Field(default=False, description="Whether to shuffle the evaluation data") + ci_alive_threshold: float = Field( + default=0.0, description="Threshold for considering components as 'alive'" + ) + + output_dir: str | None = Field( + default=None, + description="Directory to save results (defaults to 'out' directory relative to script location)", + ) + + device: str = Field( + default="auto", description="Device to run comparison on (Options: 'auto', 'cuda', 'cpu')" + ) + + +class ModelComparator: + """Compare two SPD models for geometric similarity between subcomponents.""" + + def __init__( + self, + config: CompareModelsConfig, + ): + """Initialize the model comparator. 
+ + Args: + config: CompareModelsConfig instance containing all configuration parameters + """ + self.config = config + self.current_model_path = config.current_model_path + self.reference_model_path = config.reference_model_path + self.density_threshold = config.density_threshold + + self.device = get_device() if config.device == "auto" else config.device + + logger.info(f"Loading current model from: {self.current_model_path}") + self.current_model, self.current_config = self._load_model_and_config( + self.current_model_path + ) + + logger.info(f"Loading reference model from: {self.reference_model_path}") + self.reference_model, self.reference_config = self._load_model_and_config( + self.reference_model_path + ) + + def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Config]: + """Load model and config using the standard pattern from existing codebase.""" + run_info = SPDRunInfo.from_path(model_path) + model = ComponentModel.from_run_info(run_info) + model.to(self.device) + model.eval() + model.requires_grad_(False) + + return model, run_info.config + + def create_eval_data_loader(self, config: Config) -> Iterator[Any]: + """Create evaluation data loader using exact same patterns as decomposition scripts.""" + task_config = config.task_config + task_name = task_config.task_name + + if task_name == "tms": + return self._create_tms_data_loader(config) + elif task_name == "resid_mlp": + return self._create_resid_mlp_data_loader(config) + elif task_name == "lm": + return self._create_lm_data_loader(config) + elif task_name == "induction_head": + return self._create_ih_data_loader(config) + else: + raise ValueError( + f"Unsupported task type: {task_name}. 
Supported types: tms, lm, resid_mlp, induction_head" + ) + + def _create_tms_data_loader(self, config: Config) -> Iterator[Any]: + """Create data loader for TMS task.""" + from spd.experiments.tms.configs import TMSTaskConfig + from spd.experiments.tms.models import TMSTargetRunInfo + from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset + + assert isinstance(config.task_config, TMSTaskConfig) + task_config = config.task_config + + assert config.pretrained_model_path, "pretrained_model_path must be set for TMS models" + + target_run_info = TMSTargetRunInfo.from_path(config.pretrained_model_path) + + dataset = SparseFeatureDataset( + n_features=target_run_info.config.tms_model_config.n_features, + feature_probability=task_config.feature_probability, + device=self.device, + data_generation_type=task_config.data_generation_type, + value_range=(0.0, 1.0), + synced_inputs=target_run_info.config.synced_inputs, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.config.eval_batch_size, + shuffle=self.config.shuffle_data, + ) + ) + + def _create_resid_mlp_data_loader(self, config: Config) -> Iterator[Any]: + """Create data loader for ResidMLP task.""" + from spd.experiments.resid_mlp.configs import ResidMLPTaskConfig + from spd.experiments.resid_mlp.models import ResidMLPTargetRunInfo + from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + assert isinstance(config.task_config, ResidMLPTaskConfig) + task_config = config.task_config + + assert config.pretrained_model_path, "pretrained_model_path must be set for ResidMLP models" + + target_run_info = ResidMLPTargetRunInfo.from_path(config.pretrained_model_path) + + dataset = ResidMLPDataset( + n_features=target_run_info.config.resid_mlp_model_config.n_features, + feature_probability=task_config.feature_probability, + device=self.device, + calc_labels=False, + label_type=None, + 
act_fn_name=None, + label_fn_seed=None, + synced_inputs=target_run_info.config.synced_inputs, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.config.eval_batch_size, + shuffle=self.config.shuffle_data, + ) + ) + + def _create_lm_data_loader(self, config: Config) -> Iterator[Any]: + """Create data loader for LM task.""" + from spd.data import DatasetConfig, create_data_loader + from spd.experiments.lm.configs import LMTaskConfig + + assert config.tokenizer_name, "tokenizer_name must be set" + assert isinstance(config.task_config, LMTaskConfig) + task_config = config.task_config + + dataset_config = DatasetConfig( + name=task_config.dataset_name, + hf_tokenizer_path=config.tokenizer_name, + split=task_config.eval_data_split, + n_ctx=task_config.max_seq_len, + is_tokenized=task_config.is_tokenized, + streaming=task_config.streaming, + column_name=task_config.column_name, + shuffle_each_epoch=task_config.shuffle_each_epoch, + seed=None, + ) + loader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=self.config.eval_batch_size, + buffer_size=task_config.buffer_size, + global_seed=config.seed + 1, + ddp_rank=0, + ddp_world_size=1, + ) + return iter(loader) + + def _create_ih_data_loader(self, config: Config) -> Iterator[Any]: + """Create data loader for IH task.""" + from spd.experiments.ih.configs import IHTaskConfig + from spd.experiments.ih.model import InductionModelTargetRunInfo + from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset + + assert isinstance(config.task_config, IHTaskConfig) + task_config = config.task_config + + assert config.pretrained_model_path, ( + "pretrained_model_path must be set for Induction Head models" + ) + + target_run_info = InductionModelTargetRunInfo.from_path(config.pretrained_model_path) + + dataset = InductionDataset( + vocab_size=target_run_info.config.ih_model_config.vocab_size, + seq_len=target_run_info.config.ih_model_config.seq_len, + 
prefix_window=task_config.prefix_window + or target_run_info.config.ih_model_config.seq_len - 3, + device=self.device, + ) + return iter( + DatasetGeneratedDataLoader( + dataset, + batch_size=self.config.eval_batch_size, + shuffle=self.config.shuffle_data, + ) + ) + + def compute_activation_densities( + self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int + ) -> dict[str, Float[Tensor, " C"]]: + """Compute activation densities using same logic as ComponentActivationDensity.""" + + model_config = self.current_config if model is self.current_model else self.reference_config + ci_alive_threshold = self.config.ci_alive_threshold + + device = next(iter(model.parameters())).device + n_tokens = 0 + component_activation_counts: dict[str, Float[Tensor, " C"]] = { + module_name: torch.zeros(model.C, device=device) for module_name in model.components + } + + model.eval() + with torch.no_grad(): + for _step in range(n_steps): + batch = extract_batch_data(next(eval_iterator)) + batch = batch.to(self.device) + _, pre_weight_acts = model( + batch, mode="pre_forward_cache", module_names=list(model.components.keys()) + ) + ci, _ci_upper_leaky = model.calc_causal_importances( + pre_weight_acts, + sigmoid_type=model_config.sigmoid_type, + sampling=model_config.sampling, + ) + + n_tokens_batch = next(iter(ci.values())).shape[:-1].numel() + n_tokens += n_tokens_batch + + for module_name, ci_vals in ci.items(): + active_components = ci_vals > ci_alive_threshold + n_activations_per_component = einops.reduce( + active_components, "... 
C -> C", "sum" + ) + component_activation_counts[module_name] += n_activations_per_component + + densities = { + module_name: component_activation_counts[module_name] / n_tokens + for module_name in model.components + } + + return densities + + def compute_geometric_similarities( + self, activation_densities: dict[str, Float[Tensor, " C"]] + ) -> dict[str, float]: + """Compute geometric similarities between subcomponents.""" + similarities = {} + + for layer_name in self.current_model.components: + if layer_name not in self.reference_model.components: + logger.warning(f"Layer {layer_name} not found in reference model, skipping") + continue + + current_components = self.current_model.components[layer_name] + reference_components = self.reference_model.components[layer_name] + + # Extract U and V matrices + C_ref = reference_components.C + current_U = current_components.U # Shape: [C, d_out] + current_V = current_components.V # Shape: [d_in, C] + ref_U = reference_components.U + ref_V = reference_components.V + + # Filter out components that aren't active enough in the current model + alive_mask = activation_densities[layer_name] > self.density_threshold + C_curr_alive = sum(alive_mask) + if C_curr_alive == 0: + logger.warning( + f"No components are active enough in {layer_name} for density threshold {self.density_threshold}. Skipping." 
+ ) + continue + + current_U_alive = current_U[alive_mask] + current_V_alive = current_V[:, alive_mask] + + # Compute rank-one matrices: V @ U for each component + current_rank_one = einops.einsum( + current_V_alive, + current_U_alive, + "d_in C_curr_alive, C_curr_alive d_out -> C_curr_alive d_in d_out", + ) + ref_rank_one = einops.einsum( + ref_V, ref_U, "d_in C_ref, C_ref d_out -> C_ref d_in d_out" + ) + + # Compute cosine similarities between all pairs + current_flat = current_rank_one.reshape(int(C_curr_alive.item()), -1) + ref_flat = ref_rank_one.reshape(C_ref, -1) + + current_norm = F.normalize(current_flat, p=2, dim=1) + ref_norm = F.normalize(ref_flat, p=2, dim=1) + + cosine_sim_matrix = einops.einsum( + current_norm, + ref_norm, + "C_curr_alive d_in_d_out, C_ref d_in_d_out -> C_curr_alive C_ref", + ) + cosine_sim_matrix = cosine_sim_matrix.abs() + + max_similarities = cosine_sim_matrix.max(dim=1).values + similarities[f"mean_max_abs_cosine_sim/{layer_name}"] = max_similarities.mean().item() + similarities[f"max_abs_cosine_sim_std/{layer_name}"] = max_similarities.std().item() + similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() + similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() + + metric_names = [ + "mean_max_abs_cosine_sim", + "max_abs_cosine_sim_std", + "max_abs_cosine_sim_min", + "max_abs_cosine_sim_max", + ] + + for metric_name in metric_names: + values = [ + similarities[f"{metric_name}/{layer_name}"] + for layer_name in self.current_model.components + if f"{metric_name}/{layer_name}" in similarities + ] + if values: + similarities[f"{metric_name}/all_layers"] = sum(values) / len(values) + + return similarities + + def run_comparison( + self, eval_iterator: Iterator[Any], n_steps: int | None = None + ) -> dict[str, float]: + """Run the full comparison pipeline.""" + if n_steps is None: + n_steps = self.config.n_eval_steps + assert isinstance(n_steps, int) + + logger.info("Computing 
activation densities for current model...") + activation_densities = self.compute_activation_densities( + self.current_model, eval_iterator, n_steps + ) + + logger.info("Computing geometric similarities...") + similarities = self.compute_geometric_similarities(activation_densities) + + return similarities + + +def main(config_path_or_obj: Path | str | CompareModelsConfig) -> None: + """Main execution function. + + Args: + config_path_or_obj: Path to YAML config file, config dict, or CompareModelsConfig instance + """ + config = load_config(config_path_or_obj, config_model=CompareModelsConfig) + + if config.output_dir is None: + output_dir = Path(__file__).parent / "out" + else: + output_dir = Path(config.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + comparator = ModelComparator(config) + + logger.info("Setting up evaluation data...") + eval_iterator = comparator.create_eval_data_loader(comparator.current_config) + + logger.info("Starting model comparison...") + similarities = comparator.run_comparison(eval_iterator) + + results_file = output_dir / "similarity_results.json" + save_file(similarities, results_file) + + logger.info(f"Comparison complete! 
Results saved to {results_file}") + logger.info("Similarity metrics:") + for key, value in similarities.items(): + logger.info(f" {key}: {value:.4f}") + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/spd/scripts/compare_models_config.yaml b/spd/scripts/compare_models/compare_models_config.yaml similarity index 71% rename from spd/scripts/compare_models_config.yaml rename to spd/scripts/compare_models/compare_models_config.yaml index 7dad04b78..99f5504e3 100644 --- a/spd/scripts/compare_models_config.yaml +++ b/spd/scripts/compare_models/compare_models_config.yaml @@ -1,8 +1,10 @@ # Configuration for compare_models.py # Model paths (supports both wandb: and local paths) -current_model_path: "wandb:goodfire/spd/runs/b5qe6t98" -reference_model_path: "wandb:goodfire/spd/runs/s2b158g1" +# current_model_path: "wandb:goodfire/spd/runs/b5qe6t98" +# reference_model_path: "wandb:goodfire/spd/runs/s2b158g1" +current_model_path: "wandb:goodfire/spd/runs/667z2n1b" +reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" # Analysis parameters density_threshold: 0.001 # Minimum activation density for components to be included in comparison @@ -13,8 +15,5 @@ eval_batch_size: 32 # Batch size for evaluation data loading shuffle_data: false # Whether to shuffle the evaluation data ci_alive_threshold: 0.0 # Threshold for considering components as "alive" -# Output settings -output_dir: "./comparison_results" # Directory to save results - # Device settings device: "auto" # Options: "auto", "cuda", "cpu" From 5173a6a9bc70b4d09f7e694d0370d2632a9aa4f7 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Mon, 22 Sep 2025 14:24:34 +0000 Subject: [PATCH 10/18] Updated README.md --- README.md | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2599f9761..8bdd2382b 100644 --- a/README.md +++ b/README.md @@ -85,10 +85,16 @@ python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_5-2_conf For post-hoc analysis of 
completed runs, use the model comparison script to compute geometric similarities between subcomponents: ```bash -python spd/scripts/compare_models.py --config spd/scripts/compare_models_config.yaml +# Using config file +python spd/scripts/compare_models/compare_models.py spd/scripts/compare_models/compare_models_config.yaml + +# Using command line arguments +python spd/scripts/compare_models/compare_models.py --current_model_path="wandb:..." --reference_model_path="wandb:..." ``` -The comparison script supports both wandb and local model paths, and calculates mean max absolute cosine similarity metrics (among others) between learned subcomponents in different runs. See `spd/scripts/compare_models_config.yaml` for configuration options. +The comparison script supports both wandb and local model paths, and calculates mean max absolute cosine similarity metrics between learned subcomponents in different runs. + +See `spd/scripts/compare_models/README.md` for detailed usage instructions. ## Development From 181cac81ae0b8f9beab73c8c0aca26e0c6c90e8d Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Mon, 22 Sep 2025 14:27:35 +0000 Subject: [PATCH 11/18] Added some example models to the config --- .../compare_models/compare_models_config.yaml | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/spd/scripts/compare_models/compare_models_config.yaml b/spd/scripts/compare_models/compare_models_config.yaml index 99f5504e3..cf6b5fb60 100644 --- a/spd/scripts/compare_models/compare_models_config.yaml +++ b/spd/scripts/compare_models/compare_models_config.yaml @@ -1,10 +1,15 @@ # Configuration for compare_models.py # Model paths (supports both wandb: and local paths) -# current_model_path: "wandb:goodfire/spd/runs/b5qe6t98" -# reference_model_path: "wandb:goodfire/spd/runs/s2b158g1" -current_model_path: "wandb:goodfire/spd/runs/667z2n1b" -reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" + +# TMS 5-2-id example models: +# current_model_path: 
"wandb:goodfire/spd/runs/667z2n1b" +# reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" + +# SS LLAMA example models: +current_model_path: "wandb:goodfire/spd/runs/4r8yn2zt" +reference_model_path: "wandb:goodfire/spd/runs/2lq9dpnb" + # Analysis parameters density_threshold: 0.001 # Minimum activation density for components to be included in comparison From 8db7559be4cf2cd9bdd9b4fc447e7d1e4fb37bf4 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Mon, 22 Sep 2025 14:30:08 +0000 Subject: [PATCH 12/18] Getting rid of newline --- spd/configs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spd/configs.py b/spd/configs.py index f501f746f..a6f8cf743 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -337,7 +337,6 @@ def handle_deprecated_config_keys(cls, config_dict: dict[str, Any]) -> dict[str, config_dict["train_log_freq"] = 50 if "slow_eval_freq" not in config_dict: config_dict["slow_eval_freq"] = config_dict["eval_freq"] - return config_dict @model_validator(mode="after") From 0d05f0a53a91c5ca426e69359cf24d9d106cea96 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Tue, 23 Sep 2025 13:18:38 +0000 Subject: [PATCH 13/18] Minor changes to make the PR mergeable --- README.md | 8 +- spd/scripts/compare_models/compare_models.py | 123 +++++++++--------- .../compare_models/compare_models_config.yaml | 11 +- 3 files changed, 64 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 8bdd2382b..768bd4f73 100644 --- a/README.md +++ b/README.md @@ -82,18 +82,12 @@ python spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_5-2_conf ### Model Comparison -For post-hoc analysis of completed runs, use the model comparison script to compute geometric similarities between subcomponents: +Use the model comparison script to analyse (post hoc) the geometric similarities between subcomponents of two different models: ```bash -# Using config file python spd/scripts/compare_models/compare_models.py spd/scripts/compare_models/compare_models_config.yaml - 
-# Using command line arguments -python spd/scripts/compare_models/compare_models.py --current_model_path="wandb:..." --reference_model_path="wandb:..." ``` -The comparison script supports both wandb and local model paths, and calculates mean max absolute cosine similarity metrics between learned subcomponents in different runs. - See `spd/scripts/compare_models/README.md` for detailed usage instructions. ## Development diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index bbe0f7d8b..c481b5a3d 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -8,7 +8,7 @@ python spd/scripts/compare_models/compare_models.py --current_model_path="wandb:..." --reference_model_path="wandb:..." """ -from collections.abc import Iterator +from collections.abc import Callable, Iterator from pathlib import Path from typing import Any @@ -37,17 +37,16 @@ class CompareModelsConfig(BaseModel): ) density_threshold: float = Field( - default=0.001, - description="Minimum activation density for components to be included in comparison", + ..., description="Minimum activation density for components to be included in comparison" ) n_eval_steps: int = Field( - default=5, description="Number of evaluation steps to compute activation densities" + ..., description="Number of evaluation steps to compute activation densities" ) - eval_batch_size: int = Field(default=32, description="Batch size for evaluation data loading") - shuffle_data: bool = Field(default=False, description="Whether to shuffle the evaluation data") + eval_batch_size: int = Field(..., description="Batch size for evaluation data loading") + shuffle_data: bool = Field(..., description="Whether to shuffle the evaluation data") ci_alive_threshold: float = Field( - default=0.0, description="Threshold for considering components as 'alive'" + ..., description="Threshold for considering components as 'alive'" ) output_dir: str 
| None = Field( @@ -55,38 +54,28 @@ class CompareModelsConfig(BaseModel): description="Directory to save results (defaults to 'out' directory relative to script location)", ) - device: str = Field( - default="auto", description="Device to run comparison on (Options: 'auto', 'cuda', 'cpu')" - ) - class ModelComparator: """Compare two SPD models for geometric similarity between subcomponents.""" - def __init__( - self, - config: CompareModelsConfig, - ): + def __init__(self, config: CompareModelsConfig): """Initialize the model comparator. Args: config: CompareModelsConfig instance containing all configuration parameters """ self.config = config - self.current_model_path = config.current_model_path - self.reference_model_path = config.reference_model_path self.density_threshold = config.density_threshold + self.device = get_device() - self.device = get_device() if config.device == "auto" else config.device - - logger.info(f"Loading current model from: {self.current_model_path}") + logger.info(f"Loading current model from: {config.current_model_path}") self.current_model, self.current_config = self._load_model_and_config( - self.current_model_path + config.current_model_path ) - logger.info(f"Loading reference model from: {self.reference_model_path}") + logger.info(f"Loading reference model from: {config.reference_model_path}") self.reference_model, self.reference_config = self._load_model_and_config( - self.reference_model_path + config.reference_model_path ) def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Config]: @@ -99,36 +88,38 @@ def _load_model_and_config(self, model_path: str) -> tuple[ComponentModel, Confi return model, run_info.config - def create_eval_data_loader(self, config: Config) -> Iterator[Any]: + def create_eval_data_loader(self) -> Iterator[Any]: """Create evaluation data loader using exact same patterns as decomposition scripts.""" - task_config = config.task_config - task_name = task_config.task_name - - if task_name == 
"tms": - return self._create_tms_data_loader(config) - elif task_name == "resid_mlp": - return self._create_resid_mlp_data_loader(config) - elif task_name == "lm": - return self._create_lm_data_loader(config) - elif task_name == "induction_head": - return self._create_ih_data_loader(config) - else: + task_name = self.current_config.task_config.task_name + + data_loader_fns: dict[str, Callable[[], Iterator[Any]]] = { + "tms": self._create_tms_data_loader, + "resid_mlp": self._create_resid_mlp_data_loader, + "lm": self._create_lm_data_loader, + "induction_head": self._create_ih_data_loader, + } + + if task_name not in data_loader_fns: raise ValueError( - f"Unsupported task type: {task_name}. Supported types: tms, lm, resid_mlp, induction_head" + f"Unsupported task type: {task_name}. Supported types: {', '.join(data_loader_fns.keys())}" ) - def _create_tms_data_loader(self, config: Config) -> Iterator[Any]: + return data_loader_fns[task_name]() + + def _create_tms_data_loader(self) -> Iterator[Any]: """Create data loader for TMS task.""" from spd.experiments.tms.configs import TMSTaskConfig from spd.experiments.tms.models import TMSTargetRunInfo from spd.utils.data_utils import DatasetGeneratedDataLoader, SparseFeatureDataset - assert isinstance(config.task_config, TMSTaskConfig) - task_config = config.task_config + assert isinstance(self.current_config.task_config, TMSTaskConfig) + task_config = self.current_config.task_config - assert config.pretrained_model_path, "pretrained_model_path must be set for TMS models" + assert self.current_config.pretrained_model_path, ( + "pretrained_model_path must be set for TMS models" + ) - target_run_info = TMSTargetRunInfo.from_path(config.pretrained_model_path) + target_run_info = TMSTargetRunInfo.from_path(self.current_config.pretrained_model_path) dataset = SparseFeatureDataset( n_features=target_run_info.config.tms_model_config.n_features, @@ -146,19 +137,21 @@ def _create_tms_data_loader(self, config: Config) -> 
Iterator[Any]: ) ) - def _create_resid_mlp_data_loader(self, config: Config) -> Iterator[Any]: + def _create_resid_mlp_data_loader(self) -> Iterator[Any]: """Create data loader for ResidMLP task.""" from spd.experiments.resid_mlp.configs import ResidMLPTaskConfig from spd.experiments.resid_mlp.models import ResidMLPTargetRunInfo from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader - assert isinstance(config.task_config, ResidMLPTaskConfig) - task_config = config.task_config + assert isinstance(self.current_config.task_config, ResidMLPTaskConfig) + task_config = self.current_config.task_config - assert config.pretrained_model_path, "pretrained_model_path must be set for ResidMLP models" + assert self.current_config.pretrained_model_path, ( + "pretrained_model_path must be set for ResidMLP models" + ) - target_run_info = ResidMLPTargetRunInfo.from_path(config.pretrained_model_path) + target_run_info = ResidMLPTargetRunInfo.from_path(self.current_config.pretrained_model_path) dataset = ResidMLPDataset( n_features=target_run_info.config.resid_mlp_model_config.n_features, @@ -178,18 +171,18 @@ def _create_resid_mlp_data_loader(self, config: Config) -> Iterator[Any]: ) ) - def _create_lm_data_loader(self, config: Config) -> Iterator[Any]: + def _create_lm_data_loader(self) -> Iterator[Any]: """Create data loader for LM task.""" from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig - assert config.tokenizer_name, "tokenizer_name must be set" - assert isinstance(config.task_config, LMTaskConfig) - task_config = config.task_config + assert self.current_config.tokenizer_name, "tokenizer_name must be set" + assert isinstance(self.current_config.task_config, LMTaskConfig) + task_config = self.current_config.task_config dataset_config = DatasetConfig( name=task_config.dataset_name, - hf_tokenizer_path=config.tokenizer_name, + 
hf_tokenizer_path=self.current_config.tokenizer_name, split=task_config.eval_data_split, n_ctx=task_config.max_seq_len, is_tokenized=task_config.is_tokenized, @@ -202,26 +195,28 @@ def _create_lm_data_loader(self, config: Config) -> Iterator[Any]: dataset_config=dataset_config, batch_size=self.config.eval_batch_size, buffer_size=task_config.buffer_size, - global_seed=config.seed + 1, + global_seed=self.current_config.seed + 1, ddp_rank=0, ddp_world_size=1, ) return iter(loader) - def _create_ih_data_loader(self, config: Config) -> Iterator[Any]: + def _create_ih_data_loader(self) -> Iterator[Any]: """Create data loader for IH task.""" from spd.experiments.ih.configs import IHTaskConfig from spd.experiments.ih.model import InductionModelTargetRunInfo from spd.utils.data_utils import DatasetGeneratedDataLoader, InductionDataset - assert isinstance(config.task_config, IHTaskConfig) - task_config = config.task_config + assert isinstance(self.current_config.task_config, IHTaskConfig) + task_config = self.current_config.task_config - assert config.pretrained_model_path, ( + assert self.current_config.pretrained_model_path, ( "pretrained_model_path must be set for Induction Head models" ) - target_run_info = InductionModelTargetRunInfo.from_path(config.pretrained_model_path) + target_run_info = InductionModelTargetRunInfo.from_path( + self.current_config.pretrained_model_path + ) dataset = InductionDataset( vocab_size=target_run_info.config.ih_model_config.vocab_size, @@ -260,7 +255,7 @@ def compute_activation_densities( _, pre_weight_acts = model( batch, mode="pre_forward_cache", module_names=list(model.components.keys()) ) - ci, _ci_upper_leaky = model.calc_causal_importances( + ci, _ = model.calc_causal_importances( pre_weight_acts, sigmoid_type=model_config.sigmoid_type, sampling=model_config.sampling, @@ -305,11 +300,11 @@ def compute_geometric_similarities( ref_V = reference_components.V # Filter out components that aren't active enough in the current model - 
alive_mask = activation_densities[layer_name] > self.density_threshold - C_curr_alive = sum(alive_mask) + alive_mask = activation_densities[layer_name] > self.config.density_threshold + C_curr_alive = int(alive_mask.sum().item()) if C_curr_alive == 0: logger.warning( - f"No components are active enough in {layer_name} for density threshold {self.density_threshold}. Skipping." + f"No components are active enough in {layer_name} for density threshold {self.config.density_threshold}. Skipping." ) continue @@ -327,7 +322,7 @@ def compute_geometric_similarities( ) # Compute cosine similarities between all pairs - current_flat = current_rank_one.reshape(int(C_curr_alive.item()), -1) + current_flat = current_rank_one.reshape(C_curr_alive, -1) ref_flat = ref_rank_one.reshape(C_ref, -1) current_norm = F.normalize(current_flat, p=2, dim=1) @@ -400,7 +395,7 @@ def main(config_path_or_obj: Path | str | CompareModelsConfig) -> None: comparator = ModelComparator(config) logger.info("Setting up evaluation data...") - eval_iterator = comparator.create_eval_data_loader(comparator.current_config) + eval_iterator = comparator.create_eval_data_loader() logger.info("Starting model comparison...") similarities = comparator.run_comparison(eval_iterator) diff --git a/spd/scripts/compare_models/compare_models_config.yaml b/spd/scripts/compare_models/compare_models_config.yaml index cf6b5fb60..cb89abcbf 100644 --- a/spd/scripts/compare_models/compare_models_config.yaml +++ b/spd/scripts/compare_models/compare_models_config.yaml @@ -3,12 +3,12 @@ # Model paths (supports both wandb: and local paths) # TMS 5-2-id example models: -# current_model_path: "wandb:goodfire/spd/runs/667z2n1b" -# reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" +current_model_path: "wandb:goodfire/spd/runs/667z2n1b" +reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" # SS LLAMA example models: -current_model_path: "wandb:goodfire/spd/runs/4r8yn2zt" -reference_model_path: 
"wandb:goodfire/spd/runs/2lq9dpnb" +# current_model_path: "wandb:goodfire/spd/runs/4r8yn2zt" +# reference_model_path: "wandb:goodfire/spd/runs/2lq9dpnb" # Analysis parameters @@ -19,6 +19,3 @@ n_eval_steps: 5 # Number of evaluation steps to compute activation densities eval_batch_size: 32 # Batch size for evaluation data loading shuffle_data: false # Whether to shuffle the evaluation data ci_alive_threshold: 0.0 # Threshold for considering components as "alive" - -# Device settings -device: "auto" # Options: "auto", "cuda", "cpu" From bf4048c6cd8a56cb836aa4a3b8595424a975231f Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Fri, 21 Nov 2025 16:28:04 +0000 Subject: [PATCH 14/18] wip: Add gradient clipping support to SPD optimizer --- spd/configs.py | 4 ++++ spd/run_spd.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/spd/configs.py b/spd/configs.py index 7f0c5b562..33de31af2 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -296,6 +296,10 @@ def all_module_patterns(self): default=1, description="Number of steps to accumulate gradients over before updating parameters", ) + grad_clip_norm: PositiveFloat | None = Field( + default=None, + description="If set, clip gradient norm to this value before each optimiser step", + ) # --- Faithfulness Warmup --- faithfulness_warmup_steps: NonNegativeInt = Field( diff --git a/spd/run_spd.py b/spd/run_spd.py index 6de94bd55..8ee63022c 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -14,6 +14,7 @@ from jaxtyping import Float, Int from PIL import Image from torch import Tensor +from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader from tqdm import tqdm @@ -199,7 +200,8 @@ def create_pgd_data_iter() -> ( assert len(component_params) > 0, "No parameters found in components to optimize" - optimizer = optim.AdamW(component_params + ci_fn_params, lr=config.lr, weight_decay=0) + optimized_params = component_params + ci_fn_params + optimizer = 
optim.AdamW(optimized_params, lr=config.lr, weight_decay=0) lr_schedule_fn = get_lr_schedule_fn(config.lr_schedule, config.lr_exponential_halflife) logger.info(f"Base LR scheduler created: {config.lr_schedule}") @@ -393,6 +395,8 @@ def create_pgd_data_iter() -> ( # Skip gradient step if we are at the last step (last step just for plotting and logging) if step != config.steps: sync_across_processes() + if config.grad_clip_norm is not None: + clip_grad_norm_(optimized_params, config.grad_clip_norm) optimizer.step() if is_main_process(): From ae3a6355514e9b0a1c8bca8f20055729b8578d69 Mon Sep 17 00:00:00 2001 From: Lucius Bushnaq Date: Sat, 22 Nov 2025 18:37:18 +0000 Subject: [PATCH 15/18] Update compare models script and config --- spd/scripts/compare_models/compare_models.py | 207 ++++++++++++++---- .../compare_models/compare_models_config.yaml | 14 +- 2 files changed, 172 insertions(+), 49 deletions(-) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index af7ea07f7..a655f06c4 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -37,11 +37,14 @@ class CompareModelsConfig(BaseConfig): ..., description="Path to reference model (wandb: or local path)" ) - density_threshold: float = Field( - ..., description="Minimum activation density for components to be included in comparison" + mean_ci_threshold: float = Field( + ..., + ge=0.0, + le=1.0, + description="Minimum mean causal importance for components to be included in comparison", ) n_eval_steps: int = Field( - ..., description="Number of evaluation steps to compute activation densities" + ..., description="Number of evaluation steps to compute mean causal importances" ) eval_batch_size: int = Field(..., description="Batch size for evaluation data loading") @@ -66,7 +69,7 @@ def __init__(self, config: CompareModelsConfig): config: CompareModelsConfig instance containing all configuration parameters """ 
self.config = config - self.density_threshold = config.density_threshold + self.mean_ci_threshold = config.mean_ci_threshold self.device = get_device() logger.info(f"Loading current model from: {config.current_model_path}") @@ -234,51 +237,125 @@ def _create_ih_data_loader(self) -> Iterator[Any]: ) ) - def compute_activation_densities( - self, model: ComponentModel, eval_iterator: Iterator[Any], n_steps: int - ) -> dict[str, Float[Tensor, " C"]]: - """Compute activation densities using same logic as ComponentActivationDensity.""" + def compute_ci_statistics( + self, batches: list[Any] + ) -> tuple[dict[str, Float[Tensor, " C"]], dict[str, Tensor]]: + """Compute mean causal importances and cosine similarity matrices per component.""" - model_config = self.current_config if model is self.current_model else self.reference_config - ci_alive_threshold = self.config.ci_alive_threshold + if not batches: + raise ValueError("No evaluation batches provided for CI statistics computation.") - device = get_obj_device(model) - n_tokens = 0 - component_activation_counts: dict[str, Float[Tensor, " C"]] = { - module_name: torch.zeros(model.C, device=device) for module_name in model.components - } + device = get_obj_device(self.current_model) + + component_ci_sums: dict[str, Float[Tensor, " C"]] = {} + component_example_counts: dict[str, Tensor] = {} + ci_cross_dot_products: dict[str, Tensor] = {} + ci_current_sq_sums: dict[str, Float[Tensor, " C"]] = {} + ci_reference_sq_sums: dict[str, Tensor] = {} + + for module_name, current_module in self.current_model.components.items(): + component_dim_current = current_module.C + component_ci_sums[module_name] = torch.zeros(component_dim_current, device=device) + component_example_counts[module_name] = torch.tensor(0.0, device=device) + ci_current_sq_sums[module_name] = torch.zeros(component_dim_current, device=device) + + reference_module = self.reference_model.components.get(module_name) + if reference_module is not None: + 
ci_cross_dot_products[module_name] = torch.zeros( + component_dim_current, reference_module.C, device=device + ) + ci_reference_sq_sums[module_name] = torch.zeros(reference_module.C, device=device) + + self.current_model.eval() + self.reference_model.eval() - model.eval() with torch.no_grad(): - for _step in range(n_steps): - batch = extract_batch_data(next(eval_iterator)) + for batch in batches: batch = batch.to(self.device) - pre_weight_acts = model(batch, cache_type="input").cache - ci = model.calc_causal_importances( - pre_weight_acts, - sampling=model_config.sampling, + pre_weight_current = self.current_model(batch, cache_type="input").cache + ci_current = self.current_model.calc_causal_importances( + pre_weight_current, + sampling=self.current_config.sampling, + ).lower_leaky + + pre_weight_reference = self.reference_model(batch, cache_type="input").cache + ci_reference = self.reference_model.calc_causal_importances( + pre_weight_reference, + sampling=self.reference_config.sampling, ).lower_leaky - n_tokens_batch = next(iter(ci.values())).shape[:-1].numel() - n_tokens += n_tokens_batch + for module_name, ci_vals_current in ci_current.items(): + ci_vals_current_fp32 = ci_vals_current.to(device=device, dtype=torch.float32) + + n_leading_dims = ci_vals_current_fp32.ndim - 1 + leading_dim_idxs = tuple(range(n_leading_dims)) + n_examples = float(ci_vals_current_fp32.shape[:n_leading_dims].numel()) + + component_ci_sums[module_name] += ci_vals_current_fp32.sum(dim=leading_dim_idxs) + component_example_counts[module_name] += n_examples + + if module_name not in ci_cross_dot_products: + continue + + if module_name not in ci_reference: + logger.warning( + "Module %s not found in reference CI outputs. Skipping cosine similarity.", + module_name, + ) + continue + + ci_vals_reference = ci_reference[module_name] + if ci_vals_current.shape != ci_vals_reference.shape: + logger.warning( + "Shape mismatch for module %s between current and reference CI outputs " + "(%s vs %s). 
Skipping cosine similarity.", + module_name, + ci_vals_current.shape, + ci_vals_reference.shape, + ) + continue + + ci_vals_reference_fp32 = ci_vals_reference.to( + device=device, dtype=torch.float32 + ) + + ci_current_flat = ci_vals_current_fp32.reshape( + -1, ci_vals_current_fp32.shape[-1] + ) + ci_reference_flat = ci_vals_reference_fp32.reshape( + -1, ci_vals_reference_fp32.shape[-1] + ) - for module_name, ci_vals in ci.items(): - active_components = ci_vals > ci_alive_threshold - n_activations_per_component = einops.reduce( - active_components, "... C -> C", "sum" + ci_cross_dot_products[module_name] += ( + ci_current_flat.transpose(0, 1) @ ci_reference_flat ) - component_activation_counts[module_name] += n_activations_per_component + ci_current_sq_sums[module_name] += (ci_current_flat.square()).sum(dim=0) + ci_reference_sq_sums[module_name] += (ci_reference_flat.square()).sum(dim=0) - densities = { - module_name: component_activation_counts[module_name] / n_tokens - for module_name in model.components + mean_component_cis = { + module_name: component_ci_sums[module_name] + / component_example_counts[module_name].clamp_min(1.0) + for module_name in component_ci_sums } - return densities + ci_cosine_matrices: dict[str, Tensor] = {} + eps = 1e-12 + for module_name, dot_products in ci_cross_dot_products.items(): + current_norm = torch.sqrt(ci_current_sq_sums[module_name]).clamp_min(eps) + reference_norm = torch.sqrt(ci_reference_sq_sums[module_name]).clamp_min(eps) + denom = torch.outer(current_norm, reference_norm) + cos_matrix = torch.zeros_like(dot_products) + nonzero_mask = denom > 0 + cos_matrix[nonzero_mask] = dot_products[nonzero_mask] / denom[nonzero_mask] + ci_cosine_matrices[module_name] = cos_matrix + + return mean_component_cis, ci_cosine_matrices def compute_geometric_similarities( - self, activation_densities: dict[str, Float[Tensor, " C"]] + self, + mean_component_cis: dict[str, Float[Tensor, " C"]], + ci_cosine_similarities: dict[str, Tensor], ) -> 
dict[str, float]: """Compute geometric similarities between subcomponents.""" similarities = {} @@ -299,11 +376,15 @@ def compute_geometric_similarities( ref_V = reference_components.V # Filter out components that aren't active enough in the current model - alive_mask = activation_densities[layer_name] > self.config.density_threshold + alive_mask = mean_component_cis[layer_name] > self.mean_ci_threshold C_curr_alive = int(alive_mask.sum().item()) + logger.info( + f"Layer {layer_name}: {C_curr_alive} components above mean CI threshold " + f"{self.mean_ci_threshold}" + ) if C_curr_alive == 0: logger.warning( - f"No components are active enough in {layer_name} for density threshold {self.config.density_threshold}. Skipping." + f"No components meet the mean CI threshold {self.mean_ci_threshold} in {layer_name}. Skipping." ) continue @@ -340,6 +421,26 @@ def compute_geometric_similarities( similarities[f"max_abs_cosine_sim_min/{layer_name}"] = max_similarities.min().item() similarities[f"max_abs_cosine_sim_max/{layer_name}"] = max_similarities.max().item() + if layer_name in ci_cosine_similarities: + ci_cos_matrix = ci_cosine_similarities[layer_name] + if ci_cos_matrix.shape[0] != alive_mask.shape[0]: + logger.warning( + "Mismatch between CI cosine matrix rows (%s) and component count (%s) for %s.", + ci_cos_matrix.shape[0], + alive_mask.shape[0], + layer_name, + ) + else: + ci_cos_alive = ci_cos_matrix[alive_mask] + if ci_cos_alive.numel() > 0: + ci_cos_max = ci_cos_alive.max(dim=1).values + similarities[f"ci_cosine_mean/{layer_name}"] = ci_cos_max.mean().item() + similarities[f"ci_cosine_std/{layer_name}"] = ci_cos_max.std( + unbiased=False + ).item() + similarities[f"ci_cosine_min/{layer_name}"] = ci_cos_max.min().item() + similarities[f"ci_cosine_max/{layer_name}"] = ci_cos_max.max().item() + metric_names = [ "mean_max_abs_cosine_sim", "max_abs_cosine_sim_std", @@ -347,7 +448,14 @@ def compute_geometric_similarities( "max_abs_cosine_sim_max", ] - for metric_name in 
metric_names: + cosine_metric_names = [ + "ci_cosine_mean", + "ci_cosine_std", + "ci_cosine_min", + "ci_cosine_max", + ] + + for metric_name in metric_names + cosine_metric_names: values = [ similarities[f"{metric_name}/{layer_name}"] for layer_name in self.current_model.components @@ -366,13 +474,28 @@ def run_comparison( n_steps = self.config.n_eval_steps assert isinstance(n_steps, int) - logger.info("Computing activation densities for current model...") - activation_densities = self.compute_activation_densities( - self.current_model, eval_iterator, n_steps - ) + batches: list[Any] = [] + for step in range(n_steps): + try: + batch = extract_batch_data(next(eval_iterator)) + except StopIteration: + if step == 0: + raise ValueError("Evaluation iterator provided no batches.") from None + logger.warning( + "Evaluation iterator exhausted after %s steps (requested %s).", + step, + n_steps, + ) + break + batches.append(batch) + + logger.info("Computing causal importance statistics for current and reference models...") + mean_component_cis, ci_cosine_similarities = self.compute_ci_statistics(batches) logger.info("Computing geometric similarities...") - similarities = self.compute_geometric_similarities(activation_densities) + similarities = self.compute_geometric_similarities( + mean_component_cis, ci_cosine_similarities + ) return similarities diff --git a/spd/scripts/compare_models/compare_models_config.yaml b/spd/scripts/compare_models/compare_models_config.yaml index cb89abcbf..4545ada15 100644 --- a/spd/scripts/compare_models/compare_models_config.yaml +++ b/spd/scripts/compare_models/compare_models_config.yaml @@ -3,19 +3,19 @@ # Model paths (supports both wandb: and local paths) # TMS 5-2-id example models: -current_model_path: "wandb:goodfire/spd/runs/667z2n1b" -reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" +# current_model_path: "wandb:goodfire/spd/runs/667z2n1b" +# reference_model_path: "wandb:goodfire/spd/runs/vh4yszsd" # SS LLAMA example models: -# 
current_model_path: "wandb:goodfire/spd/runs/4r8yn2zt" -# reference_model_path: "wandb:goodfire/spd/runs/2lq9dpnb" +current_model_path: "wandb:goodfire/spd/runs/ifg7jmm2" +reference_model_path: "wandb:goodfire/spd/runs/9i4u2kqa" # Analysis parameters -density_threshold: 0.001 # Minimum activation density for components to be included in comparison -n_eval_steps: 5 # Number of evaluation steps to compute activation densities +mean_ci_threshold: 1e-5 # Minimum mean causal importance (0-1) for components to be included +n_eval_steps: 5 # Number of evaluation steps to compute mean causal importances # Data loading parameters -eval_batch_size: 32 # Batch size for evaluation data loading +eval_batch_size: 128 # Batch size for evaluation data loading shuffle_data: false # Whether to shuffle the evaluation data ci_alive_threshold: 0.0 # Threshold for considering components as "alive" From caaa1e0e9c078d9448532bcc0a6e06bef4e9f55a Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Tue, 25 Nov 2025 03:31:47 +0000 Subject: [PATCH 16/18] Add comprehensive Claude Code documentation and checklist MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added two documentation files to help AI assistants work effectively with the SPD codebase: - CLAUDE_COMPREHENSIVE.md: Complete reference guide covering development philosophy, coding standards, architecture patterns, workflows, and collaboration practices - CLAUDE_CHECKLIST.md: Pre-submission checklist for verifying code changes meet SPD standards before committing These documents ensure consistent code quality and help future AI assistants understand project conventions, reducing onboarding time and maintaining codebase consistency. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- CLAUDE_CHECKLIST.md | 134 ++++++++ CLAUDE_COMPREHENSIVE.md | 669 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 803 insertions(+) create mode 100644 CLAUDE_CHECKLIST.md create mode 100644 CLAUDE_COMPREHENSIVE.md diff --git a/CLAUDE_CHECKLIST.md b/CLAUDE_CHECKLIST.md new file mode 100644 index 000000000..ac43425e9 --- /dev/null +++ b/CLAUDE_CHECKLIST.md @@ -0,0 +1,134 @@ +# CLAUDE_CHECKLIST.md - Pre-Submission Checklist + +Use this checklist before submitting any code changes to ensure your contribution meets SPD repository standards. + +As you work through this checklist, you might notice something and then get distracted when fixing it. You need to restart the checklist again after your fixes. You might therefore want to keep a running list of changes to make, then make them, then start the checklist again for all of them. + +## Code Style & Formatting + +### Naming +- [ ] **Files & modules**: `snake_case.py` +- [ ] **Functions & variables**: `snake_case` +- [ ] **Classes**: `PascalCase` +- [ ] **Constants**: `UPPERCASE_WITH_UNDERSCORES` +- [ ] **Private functions**: Prefixed with `_` +- [ ] **Abbreviations**: Uppercase (e.g., `CI`, `L0`, `KL`) + +### Type Annotations +- [ ] **Used jaxtyping for tensors** - `Float[Tensor, "... C d_in"]` format (runtime checking not yet enabled) +- [ ] **Used PEP 604 unions** - `str | None` NOT `Optional[str]` +- [ ] **Used lowercase generics** - `dict`, `list`, `tuple` NOT `Dict`, `List`, `Tuple` +- [ ] **Avoided redundant annotations** - Don't write `my_thing: Thing = Thing()` or `name: str = "John"` +- [ ] **Type checking passes with no errors** - Run `make type` successfully and fix all issues (uses basedpyright, NOT mypy) + +### Comments & Documentation +- [ ] **No obvious comments** - If code is self-explanatory, no comment needed. 
(Temp comments during development are fine if cleaned up before committing) +- [ ] **Complex logic explained** - Comments focus on "why" not "what" +- [ ] **Google-style docstrings** - Used `Args:`, `Returns:`, `Raises:` sections where needed +- [ ] **Non-obvious information only** - Docstrings don't repeat what's obvious from signature + +### Formatting +- [ ] **Ruff formatting applied** - Run `make format` before committing + +## Code Quality + +### Error Handling (Fail Fast) +- [ ] **Liberal assertions** - Assert all assumptions about data/state +- [ ] **Clear error messages** - Assertions include descriptive messages +- [ ] **Explicit error types** - Use `ValueError`, `NotImplementedError`, `RuntimeError` appropriately +- [ ] **Fail immediately** - Code fails when wrong, doesn't recover silently +- [ ] **Use try-except only for expected errors** - Assertions for invariants/assumptions. Try-except only when errors are expected and handled (e.g., path resolution, file not found) + +### Tensor Operations +- [ ] **Used einops by default** - Preferred over raw einsum for clarity +- [ ] **Asserted shapes liberally** - Verify tensor dimensions +- [ ] **Documented complex operations** - Explain non-obvious tensor manipulations + +### Design Patterns +- [ ] **Followed existing patterns** - Match architecture style of surrounding code (ABC for interfaces, Protocol for metrics, Pydantic for configs) +- [ ] **Metrics decoupled** - Each metric in its own file within `spd/metrics/` directory. 
Figures in `spd/figures.py` +- [ ] **Used Pydantic for configs** - Configs are frozen (`frozen=True`) and forbid extras (`extra="forbid"`) +- [ ] **Config paths handled correctly** - If handling paths in configs, support both relative paths and `wandb:` prefix format +- [ ] **New experiments registered** - If adding new experiment, added to `spd/registry.py` with proper structure +- [ ] **Experiment structure followed** - Experiments have `models.py`, `configs.py`, `{task}_decomposition.py` in flat structure + +## Testing + +- [ ] **Tests written** - Unit tests for new functionality. Regression tests for bug fixes. +- [ ] **Tests run successfully** - Run `make test` (or `make test-all` if relevant) +- [ ] **Test files named correctly** - `test_*.py` format +- [ ] **Test functions named correctly** - `def test_*():` with descriptive names +- [ ] **Slow tests marked** - Used `@pytest.mark.slow` for slow tests +- [ ] **Focus on unit tests** - Not production code (no deployment). Integration tests often too much overhead for research code. Interactive use catches issues at low cost. Add integration tests only if testing complex interactions that can't be validated in units. 
+ +## Pre-Commit Checks + +- [ ] **Ran `make check`** - Full pre-commit suite passes (format + type check) +- [ ] **No type errors** - basedpyright reports no issues +- [ ] **No lint errors** - ruff reports no issues + +## Git & Version Control + +### Before Committing +- [ ] **Reviewed every line of the diff** - Understand every change being committed +- [ ] **Only relevant files staged** - Don't commit unrelated changes or all files +- [ ] **No secrets committed** - No `.env`, `credentials.json`, or similar files +- [ ] **Used correct branch name** - Format: `refactor/X`, `feature/Y`, or `fix/Z` + +### Commit Message +- [ ] **Explains "what" and "why"** - Not just describing the diff +- [ ] **Clear and descriptive** - Focused on relevant changes +- [ ] **Explains purpose** - Why this change is needed + +### Committing +- [ ] **NOT using `--no-verify`** - Almost never appropriate. Pre-commit checks exist for a reason. +- [ ] **Pre-commit hooks run** - Automatically runs `make format` and `make type` +- [ ] **All hooks passed** - No failures from pre-commit checks + +## Pull Request (if creating) + +### PR Content +- [ ] **Analyzed all changes** - Reviewed git diff and git status before creating PR +- [ ] **Title is clear** - Concise summary of changes +- [ ] **Used PR template** - Filled out all sections in `.github/pull_request_template.md`: + - Description - What changed + - Related Issue - "Closes #XX" format if applicable + - Motivation and Context - Why needed + - Testing - How tested + - Breaking Changes - Listed if any + +### PR Quality +- [ ] **All CI checks pass** - GitHub Actions successful +- [ ] **Merged latest from dev** - Branch is up to date +- [ ] **Only relevant files** - No unrelated changes included +- [ ] **Self-reviewed** - Went through diff yourself first + +## Cluster Usage (if applicable) + +If running experiments on the cluster: +- [ ] **NOT exceeding 8 GPUs total** - Including all sweeps/evals combined +- [ ] **Monitored jobs** - Used 
`squeue` to check current usage +- [ ] **Used appropriate resources** - GPU vs CPU flags set correctly + +## Final Self-Review + +- [ ] **Code is simple** - Straightforward for researchers with varying experience +- [ ] **No over-engineering** - Only made changes directly requested or clearly necessary +- [ ] **No unnecessary features** - Didn't add extra functionality beyond the task +- [ ] **No premature abstraction** - Didn't create helpers/utilities for one-time operations +- [ ] **No backwards-compatibility hacks** - Removed unused code completely instead of commenting +- [ ] **Followed fail-fast principle** - Code fails immediately when assumptions violated +- [ ] **Type safety maintained** - All functions properly typed +- [ ] **Tests are sufficient** - Core functionality tested, not over-tested + +## Common Mistakes to Avoid + +- ❌ Forgetting to remove obvious comments like `# get dataloader` +- ❌ Committing without running `make check` +- ❌ Using `--no-verify` flag +- ❌ Recovering silently from errors instead of failing +- ❌ Adding type annotations to obvious assignments like `name: str = "John"` +- ❌ Committing all files instead of only relevant changes +- ❌ Using more than 8 GPUs on cluster (total across all jobs) +- ❌ Failing to consult CLAUDE_COMPREHENSIVE.md for clarification in cases where the checklist is unclear. +- ❌ Starting this checklist, noticing an issue, fixing it, and then forgetting to start the checklist **from the start** again. diff --git a/CLAUDE_COMPREHENSIVE.md b/CLAUDE_COMPREHENSIVE.md new file mode 100644 index 000000000..2d387a674 --- /dev/null +++ b/CLAUDE_COMPREHENSIVE.md @@ -0,0 +1,669 @@ +# CLAUDE_COMPREHENSIVE.md - Complete Development Guide for SPD + +This guide covers everything needed to understand, develop, and contribute to the SPD (Stochastic Parameter Decomposition) codebase. + +## 1. Introduction + +For AI assistants and developers. 
Covers: +- Environment setup and project structure +- Development philosophy and coding standards +- Architecture patterns and design principles +- Common workflows and usage patterns +- Testing, deployment, and collaboration practices + +### How to Use This Guide + +**Two Documents:** +- **CLAUDE_COMPREHENSIVE.md** (this document) - Complete reference for understanding the codebase, architecture, and development practices. Read this to learn how the project works. +- **CLAUDE_CHECKLIST.md** - Pre-submission checklist for verifying your code changes meet SPD standards. Use this before committing to ensure your work follows all conventions. + +**Workflow:** Read the comprehensive guide to understand context, then use the checklist to verify your changes before submission. + +## 2. Environment Setup & Quick Start + +**IMPORTANT**: Always activate the virtual environment before running Python or git operations: +```bash +source .venv/bin/activate +``` + +**Installation:** +```bash +make install-dev # Install with dev dependencies and pre-commit hooks +make install # Install package only (pip install -e .) +``` + +**Environment:** +- `.env` file with WandB credentials (see `.env.example`) +- WandB for experiment tracking and model storage +- Runs generate timestamped output directories (configs, models, plots) + +## 3. Project Overview + +SPD is a research framework for analyzing neural network components through sparse parameter decomposition. 
Supports experimental domains: +- **TMS** (Toy Model of Superposition) +- **ResidualMLP** (residual MLP analysis) +- **Language Models** +- **Identity Insertion** + +### Available Experiments + +Defined in `spd/registry.py`: + +- `tms_5-2`, `tms_5-2-id` - TMS: 5 features, 2 hidden dims (id = fixed identity in-between) +- `tms_40-10`, `tms_40-10-id` - TMS: 40 features, 10 hidden dims +- `resid_mlp1`, `resid_mlp2`, `resid_mlp3` - ResidualMLP: 1-3 layers +- `ss_emb` - Language models (from HuggingFace) + +### Research Papers + +**Stochastic Parameter Decomposition (SPD)** +- [`papers/Stochastic_Parameter_Decomposition/spd_paper.md`](papers/Stochastic_Parameter_Decomposition/spd_paper.md) +- Introduces core SPD framework, stochastic masking, and optimization techniques +- Note: Development has continued beyond the paper implementation + +**Attribution-based Parameter Decomposition (APD)** +- [`papers/Attribution_based_Parameter_Decomposition/apd_paper.md`](papers/Attribution_based_Parameter_Decomposition/apd_paper.md) +- Precursor to SPD, first linear parameter decomposition +- High-level conceptual insights and theoretical foundations + +### Key Data Flow + +1. Experiments load pretrained target models via WandB or local paths +2. Target models are wrapped in ComponentModel with specified target modules +3. SPD optimization runs via `spd.run_spd.optimize()` with config-driven loss combination +4. Results include component masks, causal importance scores, and visualizations + +### Component Analysis + +- Components = sparse decompositions of model parameters +- Stochastic masking enables differentiable sparsity +- Causal importance quantifies contributions +- Loss terms balance faithfulness, reconstruction, sparsity + +## 4. Development Philosophy & Principles + +### Core Principles (TLDR) + +1. **Simplicity First** - Code for researchers with varying experience. Prioritize simple, straightforward code. + +2. 
**Type Safety** - Use types, einops, jaxtyping, liberal assertions, Pydantic validation, strict pyright. + +3. **Fail Fast** - Code fails immediately when wrong, not silently. Liberal assertions, clear errors, explicit types. + +4. **Documentation** - Comments for complex logic only. Skip obvious comments. + +5. **Modularity** - Registry pattern, abstract interfaces, protocols. Decouple metrics from core. + +6. **Reproducibility** - Centralized configs, seed management, WandB tracking. + +7. **Performance** - Distributed training, parallel testing, optimized CI/CD. + +8. **Maintainability** - Consistent naming, clear architecture, comprehensive tooling. + +## 5. Development Workflow & Commands + +**Package Manager:** uv (NOT pip/poetry) + +### Make Targets + +```bash +make install # Install package only +make install-dev # Install with dev deps and pre-commit hooks +make check # Run full pre-commit suite (format + type check) +make format # Ruff lint + format +make type # BasedPyright type checking +make test # Run tests (excluding slow tests) +make test-all # Run all tests including slow ones +make coverage # Generate coverage reports +``` + +### Pre-commit Hooks + +Automatically run `make format` and `make type` before commits (install with `make install-dev`) + +### CI/CD Pipeline (GitHub Actions) + +1. Checkout code +2. Set up Python 3.13 via uv +3. Install dependencies with CPU-only PyTorch +4. Run basedpyright type checking +5. Run ruff lint and format +6. Run pytest with parallel execution (max 4 workers) + +**Special CI install:** +```bash +make install-ci # Uses CPU wheels, unsafe-best-match index strategy +``` + +## 6. 
Code Style & Formatting + +### Naming Conventions + +- **Files & modules**: `snake_case.py` (e.g., `component_model.py`) +- **Functions & variables**: `snake_case` (e.g., `create_data_loader()`) +- **Classes**: `PascalCase` (e.g., `ComponentModel`) +- **Constants**: `UPPERCASE_WITH_UNDERSCORES` (e.g., `REPO_ROOT`) +- **Private functions**: Prefix with underscore (e.g., `_infer_backend()`) +- **Abbreviations**: Uppercase in variables (e.g., `CI`, `L0`, `KL`) + +### Formatting Rules + +- **Line length**: 100 characters (strict, enforced by ruff) +- **Formatter**: ruff (configured in pyproject.toml) +- **Import organization**: stdlib → third-party → local +- **Import sorting**: Handled by ruff/isort + +**Ruff Configuration:** +- Enabled rules: pycodestyle (E), Pyflakes (F), pyupgrade (UP), flake8-bugbear (B), flake8-simplify (SIM), isort (I) +- Ignored: F722 (jaxtyping incompatibility), E731 (lambda functions allowed), E501 (long lines) + +## 7. Type Annotations + +### Core Principles + +- Use **jaxtyping** for tensor shapes: `Float[Tensor, "... C d_in"]` (runtime checking not yet enabled) +- Use **PEP 604 union syntax**: `str | None` (NOT `Optional[str]`) +- Use **lowercase generic types**: `dict`, `list`, `tuple` (NOT `Dict`, `List`, `Tuple`) +- **Don't annotate when redundant**: `my_thing = Thing()` not `my_thing: Thing = Thing()`, or `name = "John"` not `name: str = "John"` + +### Examples + +```python +# Good - jaxtyping with explicit dimensions +def forward(self, x: Float[Tensor, "... C d_in"]) -> Float[Tensor, "... C d_out"]: + return einops.einsum(x, self.W, "... C d_in, C d_in d_out -> ... 
C d_out") + self.b + +# Good - PEP 604 union syntax +def load_model(path: str | Path) -> Model | None: + pass + +# Bad - old style +from typing import Optional, Dict +def load_model(path: Optional[str]) -> Dict[str, Any]: + pass +``` + +### Type Checking + +- Uses **basedpyright** (NOT mypy) - forked pyright for better performance +- Strict mode enabled: `strictListInference`, `strictDictionaryInference`, `strictSetInference` +- Reports: `MissingTypeArgument`, `UnknownParameterType`, `IncompatibleMethodOverride`, `ImportCycles` +- Excluded: `wandb` directory, third-party code, frontend +- Run with `make type` + +## 8. Documentation & Comments + +### Philosophy: Don't Write Obvious Comments + +Your first instinct should be: **"If I couldn't write any comments, how would I write this code?"** + +If code is self-explanatory, skip the comment. Only comment to explain complex logic, focusing on **"why" not "what"**. + +If you find it helps you develop, you can write whatever comments you like when developing, so long as you remember to come back and fix them later. + +### Bad (Obvious): +```python +# get dataloader +dataloader = get_dataloader(config) +``` + +### Good (Explains Complex Logic): +```python +# We need to mask out future positions for causal attention +# Upper triangular matrix excludes the diagonal (hence k=1) +causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) +``` + +### Docstring Format + +Use **Google-style** with `Args:`, `Returns:`, `Raises:` sections. Single-line for simple functions, multi-line for complex. Focus on non-obvious information. + +```python +def tokenize_and_concatenate(dataset: Dataset, tokenizer: PreTrainedTokenizer, ...) -> Dataset: + """Tokenize and concatenate a dataset of text. + + Args: + dataset: HuggingFace dataset to tokenize + tokenizer: Pretrained tokenizer to use + ... + + Returns: + Tokenized and concatenated dataset + """ +``` + +## 9. 
Architecture & Design Patterns + +### Core Pattern: Wrapper + Registry + Config + +1. **ComponentModel**: Wraps PyTorch models and injects components +2. **Registry** (`registry.py`): Centralized experiment configuration +3. **Config System** (Pydantic): Type-safe config loading/validation + +### Design Principle: Decouple Metrics from Core + +Metric and figure code encapsulated in `spd/metrics.py` and `spd/figures.py`. + +### Key Design Patterns + +**1. Abstract Base Classes for Interfaces** +```python +class LoadableModule(nn.Module, ABC): + @classmethod + @abstractmethod + def from_pretrained(cls, _path: ModelPath) -> "LoadableModule": + raise NotImplementedError("Subclasses must implement from_pretrained method.") +``` + +**2. Protocol-Based Design** +```python +class Metric(Protocol): + slow: ClassVar[bool] = False + metric_section: ClassVar[str] + + def update(...) -> None: ... + def compute(self) -> Any: ... +``` + +**3. Dataclass-Based Configuration** +```python +@dataclass +class ExperimentConfig: + task_name: TaskName + decomp_script: Path + config_path: Path + expected_runtime: int + canonical_run: str | None = None +``` + +**4. Pydantic for Runtime Validation** +```python +class BaseConfig(BaseModel): + model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + + @classmethod + def from_file(cls, path: Path | str) -> Self: + """Load config from path to a JSON or YAML file.""" +``` + +### Core Architecture Components + +- `spd/run_spd.py` - Main SPD optimization logic +- `spd/configs.py` - Pydantic config classes +- `spd/registry.py` - Centralized experiment registry +- `spd/models/component_model.py` - ComponentModel wrapper +- `spd/models/components.py` - Component types (Linear, Embedding, etc.) +- `spd/losses.py` - Loss functions (faithfulness, reconstruction, importance minimality) +- `spd/metrics.py` - Metrics (CI-L0, KL divergence, etc.) +- `spd/figures.py` - Figures (CI histograms, Identity plots, etc.) + +## 10. 
Project Structure + +``` +spd/ +├── spd/ # Main package +│ ├── models/ # Core model classes +│ ├── metrics/ # Metric implementations +│ ├── utils/ # Utilities (distributed, logging, data) +│ ├── experiments/ # Experiment implementations +│ │ ├── tms/ # Toy Model of Superposition +│ │ ├── resid_mlp/ # Residual MLP +│ │ ├── lm/ # Language models +│ │ └── ih/ # Identity insertion +│ ├── app/ # Streamlit application +│ │ ├── backend/ +│ │ └── frontend/ +│ ├── scripts/ # CLI entry points +│ └── [core modules] +├── tests/ # Test suite +│ ├── metrics/ # Metric tests +│ ├── scripts_run/ # Integration tests +│ └── [unit tests] +├── papers/ # Research papers (markdown) +├── typings/ # Type stubs +└── [configuration files] +``` + +### Organizational Principles + +- **Flat within experiments**: Each has `models.py`, `configs.py`, `{task}_decomposition.py`, `train_*.py`, `*_config.yaml`, `plotting.py` +- **Centralized registry**: `spd/registry.py` manages experiment configs +- **Clear separation**: Core logic vs metrics vs experiments +- **Modular metrics**: Each metric in its own file + +## 11. Configuration System + +### Multi-layered Configuration + +1. **YAML config files** define experiment parameters +2. **Pydantic config classes** provide type safety and validation +3. **Environment variables** can override runtime settings +4. **Nested config objects** for task-specific configs + +### Key Conventions + +- Paths: relative to repo root or `"wandb:"` prefix for WandB paths +- Configs **immutable** (`frozen=True`) and **forbid extra fields** (`extra="forbid"`) +- `ModelPath` type validates and normalizes paths automatically +- Pydantic validators handle deprecated keys and path resolution + +### Example Config + +```yaml +wandb_project: spd +seed: 0 +C: 1200 +n_mask_samples: 1 +ci_fn_type: "shared_mlp" +ci_fn_hidden_dims: [1000] +loss_metric_configs: + - classname: "ImportanceMinimalityLoss" + coeff: 0.004 + pnorm: 2.0 +``` + + +## 12. 
Error Handling & Fail Fast + +### Fail-Fast Philosophy (Negative Space Programming) + +Code should fail immediately when assumptions are violated, preventing bugs from propagating. + +### Assertions + +**If there's an assumption you're making while writing code, assert it:** +- If you were right, then it won't matter. If you were wrong, then the code **should** fail + +```python +assert component_params, "component_params is empty" +assert x.shape[-1] == 1, "Last dimension should be 1 after the final layer" +assert cfg.coeff is not None, "All loss metric configs must have a coeff" +``` + +### Explicit Error Types + +```python +raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") +raise NotImplementedError("Subclasses must implement from_pretrained method.") +raise RuntimeError("Embedding modules not supported for identity insertion") +``` + +### Try-Except for Expected Errors + +```python +try: + return path.relative_to(REPO_ROOT) +except ValueError: + # If the path is not relative to REPO_ROOT, return the original path + return path +``` + +## 13. Tensor Operations + +### Use Einops for Clarity + +- Try to use **einops** by default for clarity over raw einsum +- **Assert shapes liberally** +- **Document complex tensor manipulations** + +**Example:** +```python +# Preferred - clear dimensions +result = einops.einsum(x, self.W, "... C d_in, C d_in d_out -> ... C d_out") + self.b + +# Also good - assert shapes +assert x.shape[-1] == d_in, f"Expected last dim to be {d_in}, got {x.shape[-1]}" +``` + +## 14. Testing Strategy + +### Testing Philosophy + +Tests ensure code works as expected, not for production (no deployment). Focus on unit tests for core functionality. Don't worry about integration/end-to-end tests - too much overhead for research code. Interactive use catches issues at low cost. 
+ +**Framework:** pytest with pytest-xdist for parallel execution + +### Test Organization + +- **Test files**: `test_*.py` +- **Test functions**: `def test_*():` with descriptive names +- **Tests mirror source structure**: `tests/metrics/`, `tests/scripts_run/` +- **Fixtures centralized** in `conftest.py` and `metrics/fixtures.py` + +### Test Markers + +- `@pytest.mark.slow` - Excluded by default, run with `make test-all` +- `@pytest.mark.requires_wandb` - Tests requiring WandB access + +## 15. Logging + +Use `spd.log.logger` with special methods: `.info()`, `.warning()`, `.error()` (standard), `.values()` (dict of metrics), `.section()` (visual separator), `.set_format()` (swap formatter). + +```python +from spd.log import logger +logger.values({"loss": 0.42}, msg="Training metrics") +logger.section("Evaluation Phase") +``` + +**Config:** Console (INFO), File (WARNING → `logs/logs.log`), named "spd" + +## 16. Common Usage Patterns + +### Running SPD Experiments + +Use `spd-run` command: + +```bash +spd-run --experiments tms_5-2 # Specific experiment +spd-run --experiments tms_5-2,resid_mlp1 # Multiple experiments +spd-run # All experiments +``` + +Or run directly: +```bash +uv run spd/experiments/tms/tms_decomposition.py spd/experiments/tms/tms_5-2_config.yaml +``` + +Outputs: losses and figure paths for analysis. + +### Metrics and Figures + +Defined in `spd/metrics.py` and `spd/figures.py` as dictionaries of functions. Select and parameterize in experiment configs for easy extension without modifying core framework. 
+
+### Running Sweeps
+
+Run hyperparameter sweeps using WandB on the GPU cluster:
+
+```bash
+spd-run --experiments <experiments> --sweep --n-agents <n_agents> [--cpu] [--job_suffix <suffix>]
+```
+
+**Examples:**
+```bash
+spd-run --experiments tms_5-2 --sweep --n-agents 4 # Run TMS 5-2 sweep with 4 GPU agents
+spd-run --experiments resid_mlp2 --sweep --n-agents 3 --cpu # Run ResidualMLP2 sweep with 3 CPU agents
+spd-run --sweep --n-agents 10 # Sweep all experiments with 10 agents
+spd-run --experiments tms_5-2 --sweep custom.yaml --n-agents 2 # Use custom sweep params file
+```
+
+**How it works:** Creates WandB sweep from `spd/scripts/sweep_params.yaml` (or custom), deploys SLURM agents (GPU by default, `--cpu` for CPU), git snapshot for consistency.
+
+**Sweep parameters:** Load from `sweep_params.yaml` or custom file. Supports global and experiment-specific configs:
+
+```yaml
+# Global parameters applied to all experiments
+global:
+  seed:
+    values: [0, 1, 2]
+  lr:
+    values: [0.001, 0.01]
+
+# Experiment-specific parameters (override global)
+tms_5-2:
+  seed:
+    values: [100, 200]  # Overrides global seed
+  task_config:
+    feature_probability:
+      values: [0.05, 0.1]
+```
+
+**Logs:** Agent logs are found in `~/slurm_logs/slurm-<job_id>_<array_task_id>.out`
+
+### Evaluation Runs
+
+Run with default hyperparameters:
+
+```bash
+spd-run # All experiments
+spd-run --experiments tms_5-2-id,resid_mlp2,resid_mlp3 # Specific experiments
+```
+
+Multiple experiments without `--sweep` creates a W&B report with aggregated visualizations.
+
+### Additional Options
+
+```bash
+spd-run --project my-project # Use custom W&B project
+spd-run --job_suffix test # Add suffix to SLURM job names
+spd-run --no-create_report # Skip W&B report creation
+```
+
+### Cluster Usage Guidelines
+
+**IMPORTANT:**
+- **DO NOT use more than 8 GPUs at one time**
+- This includes not setting off multiple sweeps/evals that total >8 GPUs
+- Monitor jobs with: `squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me`
+
+## 17. 
Distributed Training + +### DistributedState Management + +```python +@dataclass(frozen=True, slots=True) +class DistributedState: + rank: int + world_size: int + local_rank: int + backend: Literal["nccl", "gloo"] +``` + +### Conventions + +- **MPI-based** rank initialization +- **NCCL backend** for GPU, **gloo** for CPU +- Utilities in `spd/utils/distributed.py`: gradient sync, metric averaging, device detection +- `torch.nn.parallel.DistributedDataParallel` for multi-GPU training + +## 18. Git & Pull Request Workflow + +### Branch Naming + +- `refactor/X` - Refactoring work +- `feature/Y` - New features +- `fix/Z` - Bug fixes + +### Using GitHub CLI + +- To view issues and PRs: `gh issue view 28` or `gh pr view 30` +- Use the PR template defined in `.github/pull_request_template.md` +- Important: You should almost never use --no-verify. The pre-commit checks are there for a reason. + +### PR Checklist + +- Review every line of the diff +- All CI checks pass +- Merge latest changes from dev branch +- Use "Closes #XX" format for issue linking +- Only commit files that include relevant changes, don't commit all files + +### Commit Messages + +Explain "what" and "why". Clear, descriptive, focused on relevant changes. Explain purpose, not just the diff. + +### PR Template Sections + +1. Description - What changed +2. Related Issue - Use "Closes #XX" format +3. Motivation and Context - Why needed +4. Testing - How tested +5. Breaking Changes + +## 19. 
Key Dependencies & Tools + +### Core Stack + +- **PyTorch** (>=2.6) +- **Transformers** - HuggingFace models and tokenizers +- **WandB** (>=0.20.1) - Optional, disable with `wandb_project=None` +- **Pydantic** (<2.12) +- **jaxtyping** - Type annotations for tensors +- **einops** - Tensor operations (preferred over einsum) +- **Fire** - CLI argument parsing + +### Development Tooling + +- **ruff** - Linter and formatter (NOT black + flake8 + isort) +- **basedpyright** - Type checker (NOT mypy) +- **pytest + pytest-xdist** - Testing with parallelization +- **uv** - Package manager (NOT pip/poetry) +- **pre-commit** - Git hooks + +### Additional Libraries + +- **datasets** (>=2.21.0) - HuggingFace data loading +- **streamlit** - Web UI +- **python-dotenv** - Environment variables +- **torchvision** (>=0.23,<0.24) + +## 20. Quick Reference + +### Key Principles Summary + +1. **Simplicity** - Code for researchers with varying experience +2. **Type Safety** - jaxtyping, Pydantic, strict basedpyright +3. **Fail Fast** - Liberal assertions, explicit errors +4. **Minimal Comments** - Complex logic only +5. **Modularity** - Registry pattern, interfaces, protocols +6. **Decouple Metrics** - Separate from core +7. **Reproducibility** - Centralized configs, seeds, WandB +8. **Research Testing** - Unit tests, minimal integration +9. **Clear Architecture** - Wrapper + Registry + Config +10. 
**Consistent Style** - 100 char, snake_case, PEP 604 + +### Common Commands Cheatsheet + +```bash +# Setup +source .venv/bin/activate +make install-dev + +# Development +make check # Format + type check +make format # Ruff lint and format +make type # Type check only +make test # Run tests (fast) +make test-all # Run all tests + +# Running experiments +spd-run --experiments tms_5-2 +spd-run --experiments tms_5-2 --sweep --n-agents 4 + +# Git/GitHub +gh issue view 28 +gh pr view 30 +git checkout -b feature/my-feature + +# Monitoring cluster +squeue --format="%.18i %.9P %.15j %.12u %.12T %.10M %.9l %.6D %b %R" --me +``` + +### File Locations Reference + +- **Core SPD**: `spd/run_spd.py`, `spd/configs.py`, `spd/registry.py` +- **Models**: `spd/models/component_model.py`, `spd/models/components.py` +- **Metrics**: `spd/metrics.py`, `spd/figures.py` +- **Experiments**: `spd/experiments/{tms,resid_mlp,lm,ih}/` +- **Tests**: `tests/`, `tests/metrics/`, `tests/scripts_run/` +- **Configs**: `spd/experiments/*/\*_config.yaml` +- **Papers**: `papers/Stochastic_Parameter_Decomposition/`, `papers/Attribution_based_Parameter_Decomposition/` From 9be1829f317b8e62e783e43e23a4bc6f98657014 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Tue, 25 Nov 2025 03:46:55 +0000 Subject: [PATCH 17/18] Add checklist cues to prevent common omissions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added two checklist items to prevent future AI assistants from forgetting important steps: - "Checked existing patterns" item to ensure new files follow existing conventions - "Restarted checklist after any changes" with explicit STOP instruction to prevent incomplete verification Also fixed references from "dev branch" to "main branch" throughout both documentation files, as the repository uses main as the primary development branch. These changes address feedback from PR review process where these steps were accidentally omitted. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- CLAUDE_CHECKLIST.md | 4 +++- CLAUDE_COMPREHENSIVE.md | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/CLAUDE_CHECKLIST.md b/CLAUDE_CHECKLIST.md index ac43425e9..9c86c0ca3 100644 --- a/CLAUDE_CHECKLIST.md +++ b/CLAUDE_CHECKLIST.md @@ -70,6 +70,7 @@ As you work through this checklist, you might notice something and then get dist ## Git & Version Control ### Before Committing +- [ ] **Checked existing patterns** - If adding new files (docs, configs, tests, etc.), looked at similar existing files for formatting/structure conventions to follow - [ ] **Reviewed every line of the diff** - Understand every change being committed - [ ] **Only relevant files staged** - Don't commit unrelated changes or all files - [ ] **No secrets committed** - No `.env`, `credentials.json`, or similar files @@ -99,7 +100,7 @@ As you work through this checklist, you might notice something and then get dist ### PR Quality - [ ] **All CI checks pass** - GitHub Actions successful -- [ ] **Merged latest from dev** - Branch is up to date +- [ ] **Merged latest from main** - Branch is up to date - [ ] **Only relevant files** - No unrelated changes included - [ ] **Self-reviewed** - Went through diff yourself first @@ -112,6 +113,7 @@ If running experiments on the cluster: ## Final Self-Review +- [ ] **Restarted checklist after any changes** - If you made ANY changes while going through this checklist, you MUST restart from the beginning. Did you restart? If not, STOP and restart now. 
- [ ] **Code is simple** - Straightforward for researchers with varying experience - [ ] **No over-engineering** - Only made changes directly requested or clearly necessary - [ ] **No unnecessary features** - Didn't add extra functionality beyond the task diff --git a/CLAUDE_COMPREHENSIVE.md b/CLAUDE_COMPREHENSIVE.md index 2d387a674..cc97df976 100644 --- a/CLAUDE_COMPREHENSIVE.md +++ b/CLAUDE_COMPREHENSIVE.md @@ -573,7 +573,7 @@ class DistributedState: - Review every line of the diff - All CI checks pass -- Merge latest changes from dev branch +- Merge latest changes from main branch - Use "Closes #XX" format for issue linking - Only commit files that include relevant changes, don't commit all files From 783b6e0f94d3ea520d3c2e0b8b34ee354b70a690 Mon Sep 17 00:00:00 2001 From: Lee Sharkey Date: Tue, 25 Nov 2025 06:00:08 +0000 Subject: [PATCH 18/18] Remove obvious comment from compare_models.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per CLAUDE_CHECKLIST.md, removed redundant comment that was obvious from the code itself. The line `alive_mask = mean_component_cis[layer_name] > self.mean_ci_threshold` is self-explanatory. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- spd/scripts/compare_models/compare_models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/spd/scripts/compare_models/compare_models.py b/spd/scripts/compare_models/compare_models.py index a655f06c4..83128afee 100644 --- a/spd/scripts/compare_models/compare_models.py +++ b/spd/scripts/compare_models/compare_models.py @@ -375,7 +375,6 @@ def compute_geometric_similarities( ref_U = reference_components.U ref_V = reference_components.V - # Filter out components that aren't active enough in the current model alive_mask = mean_component_cis[layer_name] > self.mean_ci_threshold C_curr_alive = int(alive_mask.sum().item()) logger.info(