Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,19 @@
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "ss_llama_simple_mlp-1L",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/spd/experiments/lm/lm_decomposition.py",
"args": "${workspaceFolder}/spd/experiments/lm/ss_llama_simple_mlp-1L.yaml",
"python": "${command:python.interpreterPath}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "ss_gpt2",
"type": "debugpy",
Expand Down
14 changes: 7 additions & 7 deletions spd/app/backend/lib/activation_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,18 +182,18 @@ def get_activations_data(
ci_at_active = window_ci_values[:, n_tokens_either_side]

# Move to CPU/numpy once (faster than .tolist())
batch_idx_np = batch_idx.cpu().numpy()
seq_idx_np = seq_idx.cpu().numpy()
comp_idx_np = comp_idx.cpu().numpy()
window_token_ids_np = window_token_ids.cpu().numpy()
window_ci_values_np = window_ci_values.cpu().numpy()
ci_at_active_np = ci_at_active.cpu().numpy()
batch_idx_np = batch_idx.cpu().float().numpy()
seq_idx_np = seq_idx.cpu().float().numpy()
comp_idx_np = comp_idx.cpu().float().numpy()
window_token_ids_np = window_token_ids.cpu().float().numpy()
window_ci_values_np = window_ci_values.cpu().float().numpy()
ci_at_active_np = ci_at_active.cpu().float().numpy()

# Get token IDs at active position for token counting
active_token_ids = window_token_ids_np[:, n_tokens_either_side]

# Get predicted tokens at each firing position
firing_predicted_tokens = predicted_token_ids[batch_idx, seq_idx].cpu().numpy()
firing_predicted_tokens = predicted_token_ids[batch_idx, seq_idx].cpu().float().numpy()

# Process by component - group firings and use batch add
unique_components = np.unique(comp_idx_np)
Expand Down
6 changes: 3 additions & 3 deletions spd/app/frontend/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { svelte } from "@sveltejs/vite-plugin-svelte";
// https://vite.dev/config/
export default defineConfig({
plugins: [svelte()],
// server: {
// hmr: false,
// },
server: {
hmr: false,
},
});
6 changes: 6 additions & 0 deletions spd/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ class UVPlotsConfig(BaseConfig):

SamplingType = Literal["continuous", "binomial"]

DType = Literal["float32", "bfloat16"]


class Config(BaseConfig):
# --- WandB
Expand All @@ -245,6 +247,10 @@ class Config(BaseConfig):

# --- General ---
seed: int = Field(default=0, description="Random seed for reproducibility")
dtype: DType = Field(
default="bfloat16",
description="Default torch dtype for computation. Supports 'float32' and 'bfloat16'.",
)
C: PositiveInt = Field(
...,
description="The number of subcomponents per layer",
Expand Down
4 changes: 2 additions & 2 deletions spd/experiments/ih/train_ih.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,8 @@ def plot_attention_maps_post_training(

for layer_index in range(model.config.n_layers):
for head_index in range(model.config.n_heads):
avg_attn = avg_attn_weights[layer_index, head_index, :, :].cpu().numpy()
max_attn = max_attn_weights[layer_index, head_index, :, :].cpu().numpy()
avg_attn = avg_attn_weights[layer_index, head_index, :, :].cpu().float().numpy()
max_attn = max_attn_weights[layer_index, head_index, :, :].cpu().float().numpy()

fig, ax = plt.subplots(1, 2, figsize=(12, 6))
assert isinstance(ax, np.ndarray), "Expected ax to be a numpy array of axes"
Expand Down
10 changes: 9 additions & 1 deletion spd/experiments/lm/lm_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from pathlib import Path

import fire
import torch
import wandb
from simple_stories_train.run_info import RunInfo as SSRunInfo

from spd.configs import Config
from spd.configs import Config, DType
from spd.data import DatasetConfig, create_data_loader
from spd.experiments.lm.configs import LMTaskConfig
from spd.log import logger
Expand All @@ -26,6 +27,11 @@
from spd.utils.run_utils import get_output_dir
from spd.utils.wandb_utils import init_wandb

DTYPE_MAP: dict[DType, torch.dtype] = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}


@with_distributed_cleanup
def main(
Expand Down Expand Up @@ -107,6 +113,8 @@ def main(
pretrained_model_class.from_pretrained, # pyright: ignore[reportAttributeAccessIssue]
config.pretrained_model_name,
)

# target_model.to(DTYPE_MAP[config.dtype])
target_model.eval()

if is_main_process():
Expand Down
20 changes: 13 additions & 7 deletions spd/experiments/lm/ss_llama_simple_mlp-1L.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ ci_fn_hidden_dims:
sampling: continuous
sigmoid_type: leaky_hard
target_module_patterns:
- "h.*.mlp.c_fc"
- "h.*.mlp.down_proj"
- "h.*.attn.q_proj"
- "h.*.attn.k_proj"
- "h.*.attn.v_proj"
- "h.*.attn.o_proj"
- h.*.mlp.c_fc
- h.*.mlp.down_proj
- h.*.attn.q_proj
- h.*.attn.k_proj
- h.*.attn.v_proj
- h.*.attn.o_proj
identity_module_patterns: null
use_delta_component: true
loss_metric_configs:
Expand All @@ -28,19 +28,25 @@ loss_metric_configs:
eps: 1.0e-12
- coeff: 0.5
classname: StochasticReconSubsetLoss
routing:
type: uniform_k_subset
- coeff: 0.5
init: random
step_size: 1.0
n_steps: 1
mask_scope: shared_across_batch
classname: PGDReconSubsetLoss
routing:
type: uniform_k_subset
- coeff: 1000000.0
classname: FaithfulnessLoss
output_loss_type: kl
lr: 0.0002
steps: 400000
batch_size: 64
gradient_accumulation_steps: 1
grad_clip_norm_components: null
grad_clip_norm_ci_fns: null
faithfulness_warmup_steps: 200
faithfulness_warmup_lr: 0.001
faithfulness_warmup_weight_decay: 0.0
Expand All @@ -54,7 +60,7 @@ eval_batch_size: 64
slow_eval_freq: 10000
n_eval_steps: 5
slow_eval_on_first_step: true
save_freq: 60000
save_freq: null
eval_metric_configs:
- classname: CIHistograms
n_batches_accum: 5
Expand Down
2 changes: 1 addition & 1 deletion spd/experiments/resid_mlp/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def plot_individual_feature_response(
color=cmap_viridis(f / n_features),
marker=".",
s=s[order],
alpha=alpha[order].numpy(), # pyright: ignore[reportArgumentType]
alpha=alpha[order].float().numpy(), # pyright: ignore[reportArgumentType]
# According to the announcement, alpha is allowed to be an iterable since v3.4.0,
# but the docs & type annotations seem to be wrong. Here's the announcement:
# https://matplotlib.org/stable/users/prev_whats_new/whats_new_3.4.0.html#transparency-alpha-can-be-set-as-an-array-in-collections
Expand Down
8 changes: 4 additions & 4 deletions spd/experiments/resid_mlp/resid_mlp_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,8 +406,8 @@ def plot_neuron_contribution_pairs(

# Plot points separately for each layer with different colors
for layer in range(n_layers):
x_values = relu_conns[layer].flatten().cpu().detach().numpy()
y_values = max_component_contributions[layer].flatten().cpu().detach().numpy()
x_values = relu_conns[layer].flatten().cpu().detach().float().numpy()
y_values = max_component_contributions[layer].flatten().cpu().detach().float().numpy()

layer_label = {0: "First MLP", 1: "Second MLP", 2: "Third MLP"}.get(layer, f"Layer {layer}")

Expand Down Expand Up @@ -451,8 +451,8 @@ def plot_neuron_contribution_pairs(

# Add some statistics to the plot
# Calculate correlation for all points combined
all_x = relu_conns.flatten().cpu().detach().numpy()
all_y = max_component_contributions.flatten().cpu().detach().numpy()
all_x = relu_conns.flatten().cpu().detach().float().numpy()
all_y = max_component_contributions.flatten().cpu().detach().float().numpy()
correlation = np.corrcoef(all_x, all_y)[0, 1]
ax.text(
0.05,
Expand Down
2 changes: 1 addition & 1 deletion spd/experiments/resid_mlp/train_resid_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def train(
loss = loss.mean()
final_losses.append(loss)
final_losses = torch.stack(final_losses).mean().cpu().detach()
logger.info(f"Final losses: {final_losses.numpy()}")
logger.info(f"Final losses: {final_losses.float().numpy()}")
return final_losses


Expand Down
26 changes: 15 additions & 11 deletions spd/experiments/tms/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def filter_significant_subnets(
# Filter subnets
filtered_subnets = subnets[mask]

subnets_indices = subnet_feature_norms_order[:n_significant].cpu().numpy()
subnets_indices = subnet_feature_norms_order[:n_significant].cpu().float().numpy()

return filtered_subnets, subnets_indices, n_significant

Expand All @@ -148,7 +148,7 @@ def plot(

for subnet_idx in range(n_subnets):
ax = axs[subnet_idx]
self._plot_single_vector(ax, subnets[subnet_idx].cpu().detach().numpy(), colors)
self._plot_single_vector(ax, subnets[subnet_idx].cpu().detach().float().numpy(), colors)
self._style_axis(ax)

ax.set_title(
Expand Down Expand Up @@ -225,7 +225,7 @@ def plot(
ax = axs[subnet_idx]
self._plot_single_network(
ax,
subnets_abs[subnet_idx].cpu().detach().numpy(),
subnets_abs[subnet_idx].cpu().detach().float().numpy(),
subnets_abs.max().item(),
n_features,
n_hidden,
Expand Down Expand Up @@ -464,9 +464,13 @@ def plot(
plot_configs.append(
{
"title": "Target model",
"linear1_weights": model.linear1.weight.T.detach().cpu().numpy(),
"linear1_weights": model.linear1.weight.T.detach().cpu().float().numpy(),
"hidden_weights": [
cast(torch.nn.Linear, model.hidden_layers[i]).weight.T.detach().cpu().numpy()
cast(torch.nn.Linear, model.hidden_layers[i])
.weight.T.detach()
.cpu()
.float()
.numpy()
for i in range(config.n_hidden_layers)
]
if config.n_hidden_layers > 0 and model.hidden_layers is not None
Expand All @@ -476,10 +480,10 @@ def plot(
)

# Sum of components
sum_linear1 = linear1_subnets.sum(dim=0).numpy()
sum_linear1 = linear1_subnets.sum(dim=0).float().numpy()
sum_hidden = None
if hidden_layer_components:
sum_hidden = [hw.sum(dim=0).numpy() for hw in hidden_layer_components]
sum_hidden = [hw.sum(dim=0).float().numpy() for hw in hidden_layer_components]
plot_configs.append(
{
"title": "Sum of components",
Expand All @@ -494,7 +498,7 @@ def plot(
comp_type = component_types[idx]
if comp_type == "linear":
# Linear component: show weights in linear1/2, zeros in hidden
linear_weights = linear1_subnets[idx].numpy()
linear_weights = linear1_subnets[idx].float().numpy()
hidden_weights = None
if config.n_hidden_layers > 0 and model.hidden_layers is not None:
# Show zeros for hidden layers (not identity)
Expand All @@ -507,7 +511,7 @@ def plot(
linear_weights = np.zeros((config.n_features, config.n_hidden))
hidden_weights = None
if hidden_layer_components is not None:
hidden_weights = [hw[idx].numpy() for hw in hidden_layer_components]
hidden_weights = [hw[idx].float().numpy() for hw in hidden_layer_components]

plot_configs.append(
{
Expand Down Expand Up @@ -789,7 +793,7 @@ def _plot_heatmaps(

for idx in range(n_subnets):
ax = axs[idx]
ax.imshow(weights[idx].cpu().detach().numpy(), cmap=cmap, vmin=vmin, vmax=vmax)
ax.imshow(weights[idx].cpu().detach().float().numpy(), cmap=cmap, vmin=vmin, vmax=vmax)

# Set title
if idx == 0:
Expand Down Expand Up @@ -923,7 +927,7 @@ def plot_cosine_similarity_analysis(self) -> Figure:
_, max_cosine_sim, _ = self.analyzer.compute_cosine_similarities()

fig, ax = plt.subplots()
ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().numpy())
ax.bar(range(max_cosine_sim.shape[0]), max_cosine_sim.cpu().detach().float().numpy())
ax.axhline(1, color="grey", linestyle="--")
ax.set_xlabel("Input feature index")
ax.set_ylabel("Max cosine similarity")
Expand Down
4 changes: 2 additions & 2 deletions spd/experiments/tms/train_tms.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def plot_intro_diagram(model: TMSModel, filepath: Path) -> None:
plt.rcParams["figure.dpi"] = 200
_, ax = plt.subplots(1, 1, figsize=(2, 2))

W = WA.cpu().detach().numpy()
W = WA.cpu().detach().float().numpy()
ax.scatter(W[:, 0], W[:, 1], c=color)
ax.set_aspect("equal")
ax.add_collection(
Expand Down Expand Up @@ -135,7 +135,7 @@ def plot_cosine_similarity_distribution(

_, ax = plt.subplots(1, 1, figsize=(4, 4))

sims = masked_sims.cpu().numpy()
sims = masked_sims.cpu().float().numpy()
ax.scatter(sims, np.zeros_like(sims), alpha=0.5)
ax.set_xlim(-1, 1)
ax.set_xlabel("Cosine Similarity")
Expand Down
2 changes: 1 addition & 1 deletion spd/metrics/ce_and_kl_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def kl_vs_target(logits: Tensor) -> float:

# Rounded causal importances as masks
rounded_mask_infos = make_mask_infos(
{k: (v > self.rounding_threshold).float() for k, v in ci.items()}
{k: (v > self.rounding_threshold).to(v.dtype) for k, v in ci.items()}
)
rounded_masked_logits = self.model(batch, mask_infos=rounded_mask_infos)
rounded_masked_ce_loss = ce_vs_labels(rounded_masked_logits)
Expand Down
1 change: 1 addition & 0 deletions spd/models/component_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(
)

self.target_model = target_model
self.target_model.compile()
self.C = C
self.pretrained_model_output_attr = pretrained_model_output_attr
self.target_module_paths = get_target_module_paths(target_model, target_module_patterns)
Expand Down
12 changes: 6 additions & 6 deletions spd/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _plot_causal_importances_figure(
images = []
for j, (mask_name, mask) in enumerate(ci_vals.items()):
# mask has shape (batch, C) or (batch, pos, C)
mask_data = mask.detach().cpu().numpy()
mask_data = mask.detach().cpu().float().numpy()
if has_pos_dim:
assert mask_data.ndim == 3
mask_data = mask_data[:, 0, :]
Expand Down Expand Up @@ -127,7 +127,7 @@ def plot_mean_component_cis_both_scales(
processed_data = []
for module_name, mean_component_ci in mean_component_cis.items():
sorted_components = torch.sort(mean_component_ci, descending=True)[0]
processed_data.append((module_name, sorted_components.detach().cpu().numpy()))
processed_data.append((module_name, sorted_components.detach().cpu().float().numpy()))

# Create both figures
images = []
Expand Down Expand Up @@ -326,7 +326,7 @@ def plot_UV_matrices(
for j, (name, component) in enumerate(sorted(components.items())):
# Plot V matrix
V = component.V if all_perm_indices is None else component.V[:, all_perm_indices[name]]
V_np = V.detach().cpu().numpy()
V_np = V.detach().cpu().float().numpy()
im = axs[j, 0].matshow(V_np, aspect="auto", cmap="coolwarm")
axs[j, 0].set_ylabel("d_in index")
axs[j, 0].set_xlabel("Component index")
Expand All @@ -335,7 +335,7 @@ def plot_UV_matrices(

# Plot U matrix
U = component.U if all_perm_indices is None else component.U[all_perm_indices[name], :]
U_np = U.detach().cpu().numpy()
U_np = U.detach().cpu().float().numpy()
im = axs[j, 1].matshow(U_np, aspect="auto", cmap="coolwarm")
axs[j, 1].set_ylabel("Component index")
axs[j, 1].set_xlabel("d_out index")
Expand Down Expand Up @@ -400,7 +400,7 @@ def plot_component_activation_density(
col = i // n_rows
ax = axs[row, col]

data = density.detach().cpu().numpy()
data = density.detach().cpu().float().numpy()
ax.hist(data, bins=bins)
ax.set_yscale("log") # Beware, memory leak unless gc.collect() is called after eval loop
ax.set_title(module_name) # Add module name as title to each subplot
Expand Down Expand Up @@ -469,7 +469,7 @@ def plot_ci_values_histograms(
col = i // n_rows
ax = axs[row, col]

data = layer_ci.flatten().cpu().numpy()
data = layer_ci.flatten().cpu().float().numpy()
ax.hist(data, bins=bins)
ax.set_yscale("log") # Beware, memory leak unless gc.collect() is called after eval loop
ax.set_title(f"Causal importances for {layer_name}")
Expand Down
Loading
Loading