diff --git a/.vscode/launch.json b/.vscode/launch.json index 75c8edbb2..86646f0de 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -164,6 +164,19 @@ "PYDEVD_DISABLE_FILE_VALIDATION": "1" } }, + { + "name": "ss_llama_simple_mlp-1L", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py", + "args": "${workspaceFolder}/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml", + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + }, { "name": "ss_gpt2", "type": "debugpy", diff --git a/spd/app/backend/lib/activation_contexts.py b/spd/app/backend/lib/activation_contexts.py index a2938a408..276d1c958 100644 --- a/spd/app/backend/lib/activation_contexts.py +++ b/spd/app/backend/lib/activation_contexts.py @@ -182,18 +182,19 @@ def get_activations_data( ci_at_active = window_ci_values[:, n_tokens_either_side] # Move to CPU/numpy once (faster than .tolist()) + # Upcast only the CI values (possibly bfloat16, which numpy cannot represent); index and token-id tensors are integer and must stay integer. batch_idx_np = batch_idx.cpu().numpy() seq_idx_np = seq_idx.cpu().numpy() comp_idx_np = comp_idx.cpu().numpy() window_token_ids_np = window_token_ids.cpu().numpy() - window_ci_values_np = window_ci_values.cpu().numpy() - ci_at_active_np = ci_at_active.cpu().numpy() + window_ci_values_np = window_ci_values.cpu().float().numpy() + ci_at_active_np = ci_at_active.cpu().float().numpy() # Get token IDs at active position for token counting active_token_ids = window_token_ids_np[:, n_tokens_either_side] # Get predicted tokens at each firing position firing_predicted_tokens = predicted_token_ids[batch_idx, seq_idx].cpu().numpy() # Process by component - group 
firings and use batch add unique_components = np.unique(comp_idx_np) diff --git a/spd/app/frontend/vite.config.ts b/spd/app/frontend/vite.config.ts index bfec3a4ab..e9c93e7fd 100644 --- a/spd/app/frontend/vite.config.ts +++ b/spd/app/frontend/vite.config.ts @@ -4,7 +4,7 @@ import { svelte } from "@sveltejs/vite-plugin-svelte"; // https://vite.dev/config/ export default defineConfig({ plugins: [svelte()], - // server: { - // hmr: false, - // }, + server: { + hmr: false, + }, }); diff --git a/spd/configs.py b/spd/configs.py index b2f9647df..1713a99b3 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -227,6 +227,8 @@ class UVPlotsConfig(BaseConfig): SamplingType = Literal["continuous", "binomial"] +DType = Literal["float32", "bfloat16"] + class Config(BaseConfig): # --- WandB @@ -245,6 +247,10 @@ class Config(BaseConfig): # --- General --- seed: int = Field(default=0, description="Random seed for reproducibility") + dtype: DType = Field( + default="bfloat16", + description="Default torch dtype for computation. 
Supports 'float32' and 'bfloat16'.", + ) C: PositiveInt = Field( ..., description="The number of subcomponents per layer", diff --git a/spd/experiments/ih/train_ih.py b/spd/experiments/ih/train_ih.py index fee39f3c8..4bc6a4c86 100644 --- a/spd/experiments/ih/train_ih.py +++ b/spd/experiments/ih/train_ih.py @@ -230,8 +230,8 @@ def plot_attention_maps_post_training( for layer_index in range(model.config.n_layers): for head_index in range(model.config.n_heads): - avg_attn = avg_attn_weights[layer_index, head_index, :, :].cpu().numpy() - max_attn = max_attn_weights[layer_index, head_index, :, :].cpu().numpy() + avg_attn = avg_attn_weights[layer_index, head_index, :, :].cpu().float().numpy() + max_attn = max_attn_weights[layer_index, head_index, :, :].cpu().float().numpy() fig, ax = plt.subplots(1, 2, figsize=(12, 6)) assert isinstance(ax, np.ndarray), "Expected ax to be a numpy array of axes" diff --git a/spd/experiments/lm/lm_decomposition.py b/spd/experiments/lm/lm_decomposition.py index 59aa8120e..a463eaa23 100644 --- a/spd/experiments/lm/lm_decomposition.py +++ b/spd/experiments/lm/lm_decomposition.py @@ -5,10 +5,11 @@ from pathlib import Path import fire +import torch import wandb from simple_stories_train.run_info import RunInfo as SSRunInfo -from spd.configs import Config +from spd.configs import Config, DType from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig from spd.log import logger @@ -26,6 +27,11 @@ from spd.utils.run_utils import get_output_dir from spd.utils.wandb_utils import init_wandb +DTYPE_MAP: dict[DType, torch.dtype] = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + @with_distributed_cleanup def main( @@ -107,6 +113,8 @@ def main( pretrained_model_class.from_pretrained, # pyright: ignore[reportAttributeAccessIssue] config.pretrained_model_name, ) + + # target_model.to(DTYPE_MAP[config.dtype]) target_model.eval() if is_main_process(): diff --git 
a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml index fc25649d6..9f5181128 100644 --- a/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml +++ b/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml @@ -10,12 +10,12 @@ ci_fn_hidden_dims: sampling: continuous sigmoid_type: leaky_hard target_module_patterns: - - "h.*.mlp.c_fc" - - "h.*.mlp.down_proj" - - "h.*.attn.q_proj" - - "h.*.attn.k_proj" - - "h.*.attn.v_proj" - - "h.*.attn.o_proj" +- h.*.mlp.c_fc +- h.*.mlp.down_proj +- h.*.attn.q_proj +- h.*.attn.k_proj +- h.*.attn.v_proj +- h.*.attn.o_proj identity_module_patterns: null use_delta_component: true loss_metric_configs: @@ -28,12 +28,16 @@ loss_metric_configs: eps: 1.0e-12 - coeff: 0.5 classname: StochasticReconSubsetLoss + routing: + type: uniform_k_subset - coeff: 0.5 init: random step_size: 1.0 n_steps: 1 mask_scope: shared_across_batch classname: PGDReconSubsetLoss + routing: + type: uniform_k_subset - coeff: 1000000.0 classname: FaithfulnessLoss output_loss_type: kl @@ -41,6 +45,8 @@ lr: 0.0002 steps: 400000 batch_size: 64 gradient_accumulation_steps: 1 +grad_clip_norm_components: null +grad_clip_norm_ci_fns: null faithfulness_warmup_steps: 200 faithfulness_warmup_lr: 0.001 faithfulness_warmup_weight_decay: 0.0 @@ -54,7 +60,7 @@ eval_batch_size: 64 slow_eval_freq: 10000 n_eval_steps: 5 slow_eval_on_first_step: true -save_freq: 60000 +save_freq: null eval_metric_configs: - classname: CIHistograms n_batches_accum: 5 diff --git a/spd/experiments/resid_mlp/plotting.py b/spd/experiments/resid_mlp/plotting.py index 489043ac8..f3fde00f1 100644 --- a/spd/experiments/resid_mlp/plotting.py +++ b/spd/experiments/resid_mlp/plotting.py @@ -65,7 +65,7 @@ def plot_individual_feature_response( color=cmap_viridis(f / n_features), marker=".", s=s[order], - alpha=alpha[order].numpy(), # pyright: ignore[reportArgumentType] + alpha=alpha[order].float().numpy(), # pyright: ignore[reportArgumentType] # According to the announcement, alpha is 
allowed to be an iterable since v3.4.0, # but the docs & type annotations seem to be wrong. Here's the announcement: # https://matplotlib.org/stable/users/prev_whats_new/whats_new_3.4.0.html#transparency-alpha-can-be-set-as-an-array-in-collections diff --git a/spd/experiments/resid_mlp/resid_mlp_interp.py b/spd/experiments/resid_mlp/resid_mlp_interp.py index fcbbf5bc6..44b778668 100644 --- a/spd/experiments/resid_mlp/resid_mlp_interp.py +++ b/spd/experiments/resid_mlp/resid_mlp_interp.py @@ -406,8 +406,8 @@ def plot_neuron_contribution_pairs( # Plot points separately for each layer with different colors for layer in range(n_layers): - x_values = relu_conns[layer].flatten().cpu().detach().numpy() - y_values = max_component_contributions[layer].flatten().cpu().detach().numpy() + x_values = relu_conns[layer].flatten().cpu().detach().float().numpy() + y_values = max_component_contributions[layer].flatten().cpu().detach().float().numpy() layer_label = {0: "First MLP", 1: "Second MLP", 2: "Third MLP"}.get(layer, f"Layer {layer}") @@ -451,8 +451,8 @@ def plot_neuron_contribution_pairs( # Add some statistics to the plot # Calculate correlation for all points combined - all_x = relu_conns.flatten().cpu().detach().numpy() - all_y = max_component_contributions.flatten().cpu().detach().numpy() + all_x = relu_conns.flatten().cpu().detach().float().numpy() + all_y = max_component_contributions.flatten().cpu().detach().float().numpy() correlation = np.corrcoef(all_x, all_y)[0, 1] ax.text( 0.05, diff --git a/spd/experiments/resid_mlp/train_resid_mlp.py b/spd/experiments/resid_mlp/train_resid_mlp.py index 3a51ec7d8..409f54cc5 100644 --- a/spd/experiments/resid_mlp/train_resid_mlp.py +++ b/spd/experiments/resid_mlp/train_resid_mlp.py @@ -130,7 +130,7 @@ def train( loss = loss.mean() final_losses.append(loss) final_losses = torch.stack(final_losses).mean().cpu().detach() - logger.info(f"Final losses: {final_losses.numpy()}") + logger.info(f"Final losses: 
{final_losses.float().numpy()}") return final_losses diff --git a/spd/experiments/tms/plotting.py b/spd/experiments/tms/plotting.py index daa7ec57b..8133a2bf8 100644 --- a/spd/experiments/tms/plotting.py +++ b/spd/experiments/tms/plotting.py @@ -148,7 +148,7 @@ def plot( for subnet_idx in range(n_subnets): ax = axs[subnet_idx] - self._plot_single_vector(ax, subnets[subnet_idx].cpu().detach().numpy(), colors) + self._plot_single_vector(ax, subnets[subnet_idx].cpu().detach().float().numpy(), colors) self._style_axis(ax) ax.set_title( @@ -225,7 +225,7 @@ def plot( ax = axs[subnet_idx] self._plot_single_network( ax, - subnets_abs[subnet_idx].cpu().detach().numpy(), + subnets_abs[subnet_idx].cpu().detach().float().numpy(), subnets_abs.max().item(), n_features, n_hidden, @@ -464,9 +464,13 @@ def plot( plot_configs.append( { "title": "Target model", - "linear1_weights": model.linear1.weight.T.detach().cpu().numpy(), + "linear1_weights": model.linear1.weight.T.detach().cpu().float().numpy(), "hidden_weights": [ - cast(torch.nn.Linear, model.hidden_layers[i]).weight.T.detach().cpu().numpy() + cast(torch.nn.Linear, model.hidden_layers[i]) + .weight.T.detach() + .cpu() + .float() + .numpy() for i in range(config.n_hidden_layers) ] if config.n_hidden_layers > 0 and model.hidden_layers is not None @@ -476,10 +480,10 @@ def plot( ) # Sum of components - sum_linear1 = linear1_subnets.sum(dim=0).numpy() + sum_linear1 = linear1_subnets.sum(dim=0).float().numpy() sum_hidden = None if hidden_layer_components: - sum_hidden = [hw.sum(dim=0).numpy() for hw in hidden_layer_components] + sum_hidden = [hw.sum(dim=0).float().numpy() for hw in hidden_layer_components] 
plot_configs.append( { "title": "Sum of components", @@ -494,7 +498,7 @@ def plot( comp_type = component_types[idx] if comp_type == "linear": # Linear component: show weights in linear1/2, zeros in hidden - linear_weights = linear1_subnets[idx].numpy() + linear_weights = linear1_subnets[idx].float().numpy() hidden_weights = None if config.n_hidden_layers > 0 and model.hidden_layers is not None: # Show zeros for hidden layers (not identity) @@ -507,7 +511,7 @@ def plot( linear_weights = np.zeros((config.n_features, config.n_hidden)) hidden_weights = None if hidden_layer_components is not None: - hidden_weights = [hw[idx].numpy() for hw in hidden_layer_components] + hidden_weights = [hw[idx].float().numpy() for hw in hidden_layer_components] plot_configs.append( { @@ -789,7 +793,7 @@ def _plot_heatmaps( for idx in range(n_subnets): ax = axs[idx] - ax.imshow(weights[idx].cpu().detach().numpy(), cmap=cmap, vmin=vmin, vmax=vmax) + ax.imshow(weights[idx].cpu().detach().float().numpy(), cmap=cmap, vmin=vmin, vmax=vmax) # Set title if idx == 0: @@ -923,7 +927,7 @@ def plot_cosine_similarity_analysis(self) -> Figure: _, max_cosine_sim, _ = self.analyzer.compute_cosine_similarities() fig, ax = plt.subplots() - ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().numpy()) + ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().float().numpy()) ax.axhline(1, color="grey", linestyle="--") ax.set_xlabel("Input feature index") ax.set_ylabel("Max cosine similarity") diff --git a/spd/experiments/tms/train_tms.py b/spd/experiments/tms/train_tms.py index e63d9fbd1..8bfefe9c8 100644 --- a/spd/experiments/tms/train_tms.py +++ b/spd/experiments/tms/train_tms.py @@ -97,7 +97,7 @@ def plot_intro_diagram(model: TMSModel, filepath: Path) -> None: plt.rcParams["figure.dpi"] = 200 _, ax = plt.subplots(1, 1, figsize=(2, 2)) - W = WA.cpu().detach().numpy() + W = WA.cpu().detach().float().numpy() ax.scatter(W[:, 0], W[:, 1], c=color) ax.set_aspect("equal") 
ax.add_collection( @@ -135,7 +135,7 @@ def plot_cosine_similarity_distribution( _, ax = plt.subplots(1, 1, figsize=(4, 4)) - sims = masked_sims.cpu().numpy() + sims = masked_sims.cpu().float().numpy() ax.scatter(sims, np.zeros_like(sims), alpha=0.5) ax.set_xlim(-1, 1) ax.set_xlabel("Cosine Similarity") diff --git a/spd/metrics/ce_and_kl_losses.py b/spd/metrics/ce_and_kl_losses.py index d93dcbc86..8d061bc08 100644 --- a/spd/metrics/ce_and_kl_losses.py +++ b/spd/metrics/ce_and_kl_losses.py @@ -141,7 +141,7 @@ def kl_vs_target(logits: Tensor) -> float: # Rounded causal importances as masks rounded_mask_infos = make_mask_infos( - {k: (v > self.rounding_threshold).float() for k, v in ci.items()} + {k: (v > self.rounding_threshold).to(v.dtype) for k, v in ci.items()} ) rounded_masked_logits = self.model(batch, mask_infos=rounded_mask_infos) rounded_masked_ce_loss = ce_vs_labels(rounded_masked_logits) diff --git a/spd/models/component_model.py b/spd/models/component_model.py index a3fcc4fac..e62d5d97b 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -91,6 +91,7 @@ def __init__( ) self.target_model = target_model + self.target_model.compile() self.C = C self.pretrained_model_output_attr = pretrained_model_output_attr self.target_module_paths = get_target_module_paths(target_model, target_module_patterns) diff --git a/spd/plotting.py b/spd/plotting.py index 81c9c1d5d..9303f7f05 100644 --- a/spd/plotting.py +++ b/spd/plotting.py @@ -61,7 +61,7 @@ def _plot_causal_importances_figure( images = [] for j, (mask_name, mask) in enumerate(ci_vals.items()): # mask has shape (batch, C) or (batch, pos, C) - mask_data = mask.detach().cpu().numpy() + mask_data = mask.detach().cpu().float().numpy() if has_pos_dim: assert mask_data.ndim == 3 mask_data = mask_data[:, 0, :] @@ -127,7 +127,7 @@ def plot_mean_component_cis_both_scales( processed_data = [] for module_name, mean_component_ci in mean_component_cis.items(): sorted_components = 
torch.sort(mean_component_ci, descending=True)[0] - processed_data.append((module_name, sorted_components.detach().cpu().numpy())) + processed_data.append((module_name, sorted_components.detach().cpu().float().numpy())) # Create both figures images = [] @@ -326,7 +326,7 @@ def plot_UV_matrices( for j, (name, component) in enumerate(sorted(components.items())): # Plot V matrix V = component.V if all_perm_indices is None else component.V[:, all_perm_indices[name]] - V_np = V.detach().cpu().numpy() + V_np = V.detach().cpu().float().numpy() im = axs[j, 0].matshow(V_np, aspect="auto", cmap="coolwarm") axs[j, 0].set_ylabel("d_in index") axs[j, 0].set_xlabel("Component index") @@ -335,7 +335,7 @@ def plot_UV_matrices( # Plot U matrix U = component.U if all_perm_indices is None else component.U[all_perm_indices[name], :] - U_np = U.detach().cpu().numpy() + U_np = U.detach().cpu().float().numpy() im = axs[j, 1].matshow(U_np, aspect="auto", cmap="coolwarm") axs[j, 1].set_ylabel("Component index") axs[j, 1].set_xlabel("d_out index") @@ -400,7 +400,7 @@ def plot_component_activation_density( col = i // n_rows ax = axs[row, col] - data = density.detach().cpu().numpy() + data = density.detach().cpu().float().numpy() ax.hist(data, bins=bins) ax.set_yscale("log") # Beware, memory leak unless gc.collect() is called after eval loop ax.set_title(module_name) # Add module name as title to each subplot @@ -469,7 +469,7 @@ def plot_ci_values_histograms( col = i // n_rows ax = axs[row, col] - data = layer_ci.flatten().cpu().numpy() + data = layer_ci.flatten().cpu().float().numpy() ax.hist(data, bins=bins) ax.set_yscale("log") # Beware, memory leak unless gc.collect() is called after eval loop ax.set_title(f"Causal importances for {layer_name}") diff --git a/spd/run_spd.py b/spd/run_spd.py index 925d1dd0e..43a846843 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -20,6 +20,7 @@ from spd.configs import ( Config, + DType, LossMetricConfigType, MetricConfigType, PGDMultiBatchConfig, @@ 
-52,6 +53,11 @@ from spd.utils.run_utils import save_file from spd.utils.wandb_utils import try_wandb +DTYPE_MAP: dict[DType, torch.dtype] = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + def run_faithfulness_warmup( component_model: ComponentModel, @@ -127,6 +133,11 @@ def optimize( ) -> None: """Run the optimization loop for LM decomposition.""" + # # Set default dtype for all tensor operations + # torch_dtype = DTYPE_MAP[config.dtype] + # torch.set_default_dtype(torch_dtype) + # logger.info(f"Set default torch dtype to {config.dtype}") + train_iterator = loop_dataloader(train_loader) eval_iterator = loop_dataloader(eval_loader) @@ -258,33 +269,35 @@ def create_pgd_data_iter() -> ( for _ in range(config.gradient_accumulation_steps): microbatch = extract_batch_data(next(train_iterator)).to(device) - # NOTE: we need to call the wrapped_model at least once each step in order to setup - # the DDP gradient syncing for all parameters in the component model. Gradients will - # sync regardless of whether the parameters are used in this call to wrapped_model. - target_model_output: OutputWithCache = wrapped_model(microbatch, cache_type="input") + # Autocast the forward pass and loss calcs in bf16 only when config.dtype is "bfloat16"; device_type must be a bare type ("cuda"), not an indexed device ("cuda:0") + with torch.autocast(device_type=torch.device(device).type, dtype=torch.bfloat16, enabled=config.dtype == "bfloat16"): + # NOTE: we need to call the wrapped_model at least once each step in order to setup + # the DDP gradient syncing for all parameters in the component model. Gradients will + # sync regardless of whether the parameters are used in this call to wrapped_model. 
+ target_model_output: OutputWithCache = wrapped_model(microbatch, cache_type="input") + + ci = component_model.calc_causal_importances( + pre_weight_acts=target_model_output.cache, + detach_inputs=False, + sampling=config.sampling, + ) - ci = component_model.calc_causal_importances( - pre_weight_acts=target_model_output.cache, - detach_inputs=False, - sampling=config.sampling, - ) + alive_tracker.update(ci=ci.lower_leaky) - alive_tracker.update(ci=ci.lower_leaky) - - microbatch_total_loss, microbatch_loss_terms = compute_total_loss( - loss_metric_configs=config.loss_metric_configs, - model=component_model, - batch=microbatch, - ci=ci, - target_out=target_model_output.output, - weight_deltas=weight_deltas, - pre_weight_acts=target_model_output.cache, - current_frac_of_training=step / config.steps, - sampling=config.sampling, - use_delta_component=config.use_delta_component, - n_mask_samples=config.n_mask_samples, - output_loss_type=config.output_loss_type, - ) + microbatch_total_loss, microbatch_loss_terms = compute_total_loss( + loss_metric_configs=config.loss_metric_configs, + model=component_model, + batch=microbatch, + ci=ci, + target_out=target_model_output.output, + weight_deltas=weight_deltas, + pre_weight_acts=target_model_output.cache, + current_frac_of_training=step / config.steps, + sampling=config.sampling, + use_delta_component=config.use_delta_component, + n_mask_samples=config.n_mask_samples, + output_loss_type=config.output_loss_type, + ) microbatch_total_loss.div_(config.gradient_accumulation_steps).backward() for loss_name, loss_value in microbatch_loss_terms.items(): diff --git a/spd/utils/component_utils.py b/spd/utils/component_utils.py index 9936ba10e..511fb67e2 100644 --- a/spd/utils/component_utils.py +++ b/spd/utils/component_utils.py @@ -22,9 +22,9 @@ def calc_stochastic_component_mask_info( for layer, ci in causal_importances.items(): match component_mask_sampling: case "binomial": - stochastic_source = torch.randint(0, 2, ci.shape, 
device=device).float() + stochastic_source = torch.randint(0, 2, ci.shape, device=device).to(dtype) case "continuous": - stochastic_source = torch.rand_like(ci) + stochastic_source = torch.rand_like(ci).to(dtype) component_masks[layer] = ci + (1 - ci) * stochastic_source weight_deltas_and_masks: dict[str, WeightDeltaAndMask] | None = None diff --git a/spd/utils/target_ci_solutions.py b/spd/utils/target_ci_solutions.py index 87ad929c2..97d2540f0 100644 --- a/spd/utils/target_ci_solutions.py +++ b/spd/utils/target_ci_solutions.py @@ -71,7 +71,7 @@ def permute_to_identity_hungarian( effective_rows = min(batch, C) # Hungarian algorithm on the effective_rows x C submatrix - cost_matrix = -ci_vals[:effective_rows].detach().cpu().numpy() + cost_matrix = -ci_vals[:effective_rows].detach().cpu().float().numpy() _, col_indices = linear_sum_assignment(cost_matrix) # Build complete permutation