def _get_component_filter_values(
    activations: ActivationsTensor,
    filter_stat: DeadComponentFilterStat,
) -> Float[Tensor, " c"]:
    """Reduce activations over dim 0 using the requested dead-component statistic.

    Args:
        activations: Activation tensor; the reduction runs over its first dimension.
        filter_stat: ``"max"`` for the maximum over dim 0, ``"mean"`` for the mean.

    Returns:
        The tensor produced by reducing ``activations`` over dim 0.

    Raises:
        AssertionError: If ``filter_stat`` is neither ``"max"`` nor ``"mean"``.
    """
    wants_max = filter_stat == "max"
    if not wants_max:
        assert filter_stat == "mean", f"Unsupported dead component filter stat: {filter_stat}"
        return activations.mean(dim=0)

    # Tensor.max(dim=...) yields a (values, indices) pair; only the values are needed.
    reduced, _ = activations.max(dim=0)
    return reduced
class MembershipBuilder:
    """Streaming builder for compressed sample memberships.

    This stores only active sample ids per component plus a small dense preview
    for plots/logging. It assumes thresholded boolean merge semantics: a sample
    belongs to a component iff its activation is strictly greater than
    ``activation_threshold``.

    Call :meth:`add_batch` once per batch, then :meth:`finalize` once.
    """

    def __init__(
        self,
        *,
        activation_threshold: float,
        filter_dead_threshold: float,
        filter_dead_stat: DeadComponentFilterStat,
        filter_modules: ModuleFilterFunc | None,
        preview_n_samples: int = 256,
    ) -> None:
        """Store the configuration and initialise empty per-module accumulators.

        Args:
            activation_threshold: Activations strictly above this count as "active".
            filter_dead_threshold: Components whose statistic is below this are dead;
                a value ``<= 0`` disables dead-component filtering.
            filter_dead_stat: Statistic compared against ``filter_dead_threshold``.
            filter_modules: Optional predicate on module names; modules for which it
                returns a falsy value are ignored.
            preview_n_samples: Number of leading samples kept densely for previews.
        """
        self.activation_threshold = activation_threshold
        self.filter_dead_threshold = filter_dead_threshold
        self.filter_dead_stat = filter_dead_stat
        self.filter_modules = filter_modules
        self.preview_n_samples = preview_n_samples

        self.n_samples = 0
        self.module_component_counts: dict[str, int] = {}
        self.max_activations: dict[str, Float[Tensor, " c"]] = {}
        self.sum_activations: dict[str, Float[Tensor, " c"]] = {}
        # sample_idx_chunks[module][component] -> list of arrays of global sample ids
        self.sample_idx_chunks: dict[str, list[list[np.ndarray]]] = {}
        self.preview_chunks: dict[str, list[Tensor]] = {}
        self.module_order: list[str] = []
        self._preview_rows = 0

    def _ensure_module(self, key: str, n_components: int) -> None:
        """Register ``key`` on first sight; afterwards only check its component count."""
        if key in self.module_component_counts:
            assert self.module_component_counts[key] == n_components, (
                f"Inconsistent component count for module '{key}': "
                f"{self.module_component_counts[key]} vs {n_components}"
            )
            return

        self.module_component_counts[key] = n_components
        self.max_activations[key] = torch.full((n_components,), float("-inf"))
        # float64 accumulator keeps the running sum accurate over many batches
        self.sum_activations[key] = torch.zeros((n_components,), dtype=torch.float64)
        self.sample_idx_chunks[key] = [[] for _ in range(n_components)]
        self.preview_chunks[key] = []
        self.module_order.append(key)

    def add_batch(
        self,
        activations: dict[str, Float[Tensor, "samples C"]],
    ) -> None:
        """Add a batch of per-module activations shaped [samples, components].

        Updates the running max/sum statistics, the dense preview (until
        ``preview_n_samples`` rows were seen) and the per-component lists of
        active global sample ids.

        Raises:
            AssertionError: If an activation tensor is not 2D, if modules in the same
                batch disagree on the number of samples, or if a module's component
                count differs from an earlier batch.
        """
        filtered = (
            {key: act for key, act in activations.items() if self.filter_modules(key)}
            if self.filter_modules is not None
            else activations
        )
        if not filtered:
            return

        batch_n_samples = next(iter(filtered.values())).shape[0]
        sample_offset = self.n_samples
        # Chunks stay int32 (half the memory) only while every global sample id of this
        # batch fits; adding the offset to an int32 array would otherwise overflow.
        index_dtype = (
            np.int32 if sample_offset + batch_n_samples <= np.iinfo(np.int32).max else np.int64
        )

        for key, act in filtered.items():
            act_cpu = act.detach().cpu()
            assert act_cpu.ndim == 2, f"Expected 2D activations, got shape {tuple(act_cpu.shape)}"
            # All modules must describe the same samples, otherwise global ids misalign.
            assert act_cpu.shape[0] == batch_n_samples, (
                f"Inconsistent sample count for module '{key}': "
                f"{act_cpu.shape[0]} vs {batch_n_samples}"
            )
            self._ensure_module(key, act_cpu.shape[1])

            self.max_activations[key] = torch.maximum(
                self.max_activations[key], act_cpu.max(dim=0).values
            )
            self.sum_activations[key] += act_cpu.sum(dim=0, dtype=torch.float64)

            if self._preview_rows < self.preview_n_samples:
                remaining = self.preview_n_samples - self._preview_rows
                self.preview_chunks[key].append(act_cpu[:remaining].clone())

            # Transposed so np.nonzero returns entries grouped by component, rows ascending.
            mask_np = (act_cpu.numpy() > self.activation_threshold).T
            comp_indices, row_indices = np.nonzero(mask_np)
            if comp_indices.size > 0:
                row_indices = row_indices.astype(index_dtype, copy=False) + index_dtype(
                    sample_offset
                )
                # Positions where the component id changes delimit one group per component.
                split_points = np.flatnonzero(np.diff(comp_indices)) + 1
                row_groups = np.split(row_indices, split_points)
                comp_groups = np.split(comp_indices, split_points)
                for comp_group, row_group in zip(comp_groups, row_groups, strict=True):
                    self.sample_idx_chunks[key][int(comp_group[0])].append(row_group)

        self.n_samples += batch_n_samples
        self._preview_rows = min(self.n_samples, self.preview_n_samples)

    def finalize(self) -> ProcessedMemberships:
        """Apply dead-component filtering and build the final memberships.

        Returns:
            A validated ``ProcessedMemberships`` holding one compressed membership per
            alive component (labelled ``"<module>:<index>"``), the dead labels, and a
            dense preview of the alive components when any were collected.

        Raises:
            AssertionError: If ``filter_dead_stat`` is not ``"max"`` or ``"mean"``, or
                if the result fails ``ProcessedMemberships.validate``.
        """
        # Same contract as _get_component_filter_values: reject unknown statistics
        # instead of silently treating them as "mean".
        assert self.filter_dead_stat in ("max", "mean"), (
            f"Unsupported dead component filter stat: {self.filter_dead_stat}"
        )

        module_alive_counts: dict[str, int] = {}
        alive_labels = ComponentLabels(list())
        dead_labels = ComponentLabels(list())
        memberships: list[CompressedMembership] = []

        preview_module_component_counts: dict[str, int] = {}
        preview_module_alive_counts: dict[str, int] = {}
        preview_chunks_alive: list[Tensor] = []

        for key in self.module_order:
            filter_values = (
                self.max_activations[key]
                if self.filter_dead_stat == "max"
                else (self.sum_activations[key] / self.n_samples).to(
                    self.max_activations[key].dtype
                )
            )
            n_components = self.module_component_counts[key]
            alive = (
                filter_values >= self.filter_dead_threshold
                if self.filter_dead_threshold > 0
                else torch.ones(n_components, dtype=torch.bool)
            )
            n_alive = int(alive.sum().item())
            module_alive_counts[key] = n_alive
            preview_module_component_counts[key] = n_components
            preview_module_alive_counts[key] = n_alive

            preview_tensor = (
                torch.cat(self.preview_chunks[key], dim=0)
                if self.preview_chunks[key]
                else torch.empty((0, n_components), dtype=filter_values.dtype)
            )

            # One bulk conversion instead of a tensor element lookup per component.
            alive_flags: list[bool] = alive.tolist()
            for comp_idx, is_alive in enumerate(alive_flags):
                label = f"{key}:{comp_idx}"
                if not is_alive:
                    dead_labels.append(label)
                    continue

                alive_labels.append(label)
                chunks = self.sample_idx_chunks[key][comp_idx]
                sample_ids = (
                    np.concatenate(chunks).astype(np.int64, copy=False)
                    if chunks
                    else np.empty((0,), dtype=np.int64)
                )
                memberships.append(
                    CompressedMembership.from_sample_indices(
                        sample_indices=sample_ids,
                        n_samples=self.n_samples,
                    )
                )

            if n_alive > 0:
                preview_chunks_alive.append(preview_tensor[:, alive])

        preview: ProcessedActivations | None = None
        if preview_chunks_alive:
            preview = ProcessedActivations(
                module_component_counts=preview_module_component_counts,
                module_alive_counts=preview_module_alive_counts,
                activations=torch.cat(preview_chunks_alive, dim=1),
                labels=ComponentLabels(alive_labels.copy()),
                dead_components_lst=ComponentLabels(dead_labels.copy()) if dead_labels else None,
            )

        result = ProcessedMemberships(
            # Copy so later use of the builder cannot mutate the frozen result's dict.
            module_component_counts=dict(self.module_component_counts),
            module_alive_counts=module_alive_counts,
            labels=alive_labels,
            dead_components_lst=dead_labels if dead_labels else None,
            memberships=memberships,
            n_samples=self.n_samples,
            preview=preview,
        )
        result.validate()
        return result
def collect_memberships_resid_mlp(
    model: ComponentModel,
    dataloader: DataLoader[Any],
    n_samples: int,
    device: torch.device | str,
    activation_threshold: float,
    filter_dead_threshold: float,
    filter_dead_stat: DeadComponentFilterStat = "max",
    filter_modules: ModuleFilterFunc | None = None,
    preview_n_samples: int = 256,
) -> ProcessedMemberships:
    """Collect ResidMLP activations across batches into compressed memberships.

    Batches are read until ``n_samples`` rows were fed to the builder; the last
    batch is truncated so that exactly ``n_samples`` rows are used.

    Raises:
        AssertionError: If the dataloader runs out before ``n_samples`` rows were seen.
    """
    membership_builder = MembershipBuilder(
        activation_threshold=activation_threshold,
        filter_dead_threshold=filter_dead_threshold,
        filter_dead_stat=filter_dead_stat,
        filter_modules=filter_modules,
        preview_n_samples=preview_n_samples,
    )
    progress = tqdm(dataloader, desc="Collecting activations", unit="batch")

    total_taken = 0
    # Each dataloader item is a pair; only its first element is used as model input.
    for inputs, _ in progress:
        per_module = component_activations(model=model, batch=inputs, device=device)

        # Never take more rows than are still missing.
        take = min(inputs.shape[0], n_samples - total_taken)
        trimmed = {}
        for module_name, module_act in per_module.items():
            trimmed[module_name] = module_act[:take]
        membership_builder.add_batch(trimmed)

        total_taken += take
        progress.set_postfix(samples=f"{total_taken}/{n_samples}")
        if total_taken >= n_samples:
            break

    assert total_taken >= n_samples, (
        f"Dataloader exhausted: collected {total_taken} samples but needed {n_samples}"
    )
    logger.info(f"Collected {total_taken} resid_mlp activations (requested {n_samples})")
    return membership_builder.finalize()
def recompute_coacts_merge_pair_memberships(
    coact: ClusterCoactivationShaped,
    merges: GroupMerge,
    merge_pair: MergePair,
    memberships: list[CompressedMembership],
    component_activity_csr: sparse.csr_matrix,
) -> tuple[
    GroupMerge,
    Float[Tensor, "k_groups-1 k_groups-1"],
    list[CompressedMembership],
]:
    """Recompute coactivations after a merge using compressed memberships.

    Args:
        coact: Square ``k_groups x k_groups`` coactivation matrix before the merge.
        merges: Current group assignment; not modified.
        merge_pair: The two (distinct) group indices to merge.
        memberships: One membership per current group; not modified.
        component_activity_csr: Sparse activity matrix forwarded to
            ``count_group_overlaps_from_component_rows``.

    Returns:
        ``(merge_new, coact_new, memberships_new)``. The merged group lives at
        ``min(merge_pair)`` and the entry at ``max(merge_pair)`` is removed, so both
        ``coact_new`` and ``memberships_new`` have one group fewer.

    Raises:
        AssertionError: If ``coact`` is not square, ``memberships`` does not match its
            size, or both indices of ``merge_pair`` are equal.
    """
    k_groups: int = coact.shape[0]
    assert coact.shape[1] == k_groups, "Coactivation matrix must be square"
    assert len(memberships) == k_groups, "Memberships must match coactivation matrix shape"
    # A self-merge would write the merged membership and then pop that same slot.
    assert merge_pair[0] != merge_pair[1], "Cannot merge a group with itself"

    new_group_idx: int = min(merge_pair)
    remove_idx: int = max(merge_pair)
    merged_membership = memberships[merge_pair[0]].union(memberships[merge_pair[1]])

    merge_new: GroupMerge = merges.merge_groups(
        merge_pair[0],
        merge_pair[1],
    )

    mask: Bool[Tensor, " k_groups"] = torch.ones(k_groups, dtype=torch.bool, device=coact.device)
    mask[remove_idx] = False
    # Boolean-mask indexing already materialises a new tensor, so the result can be
    # written to in place without an additional clone of the k x k matrix.
    coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact[mask, :][:, mask]

    merged_rows = merged_membership.to_sample_indices()
    coact_with_merge_np = count_group_overlaps_from_component_rows(
        merged_rows=merged_rows,
        component_activity_csr=component_activity_csr,
        group_idxs=merge_new.group_idxs.cpu().numpy(),
        n_groups=merge_new.k_groups,
    )
    coact_with_merge = torch.tensor(
        coact_with_merge_np,
        dtype=coact.dtype,
        device=coact.device,
    )

    # Overwrite the merged group's row and column, then pin the diagonal to the exact
    # size of the merged membership.
    coact_new[new_group_idx, :] = coact_with_merge
    coact_new[:, new_group_idx] = coact_with_merge
    coact_new[new_group_idx, new_group_idx] = float(merged_membership.count())

    memberships_new = memberships.copy()
    memberships_new[new_group_idx] = merged_membership
    # remove_idx > new_group_idx, so popping it leaves the merged slot in place.
    memberships_new.pop(remove_idx)

    return merge_new, coact_new, memberships_new
@dataclass(frozen=True, slots=True)
class MembershipSnapshot:
    """Disk-friendly sparse membership snapshot for repeatable merge benchmarks."""

    # Sparse matrix whose rows index samples and whose columns index components
    # (see n_samples / n_components below).
    matrix_csc: sparse.csc_matrix
    # Labels stored alongside the matrix in the snapshot metadata.
    labels: ComponentLabels

    @property
    def n_samples(self) -> int:
        """Number of rows of the membership matrix."""
        shape = self.matrix_csc.shape
        assert shape is not None
        return int(shape[0])

    @property
    def n_components(self) -> int:
        """Number of columns of the membership matrix."""
        shape = self.matrix_csc.shape
        assert shape is not None
        return int(shape[1])

    def to_memberships(self) -> list[CompressedMembership]:
        """Convert every column into a ``CompressedMembership``.

        Row ids are handed over in ascending order per column: the sparse
        intersection/union helpers in ``sample_membership`` walk both id arrays in
        lockstep and therefore need sorted input, which a matrix loaded from disk
        does not guarantee by itself.
        """
        matrix = self.matrix_csc
        if not matrix.has_sorted_indices:
            # sorted_indices() returns a sorted copy; the stored matrix stays untouched.
            matrix = matrix.sorted_indices()

        indptr = matrix.indptr
        indices = matrix.indices
        n_samples = self.n_samples
        return [
            CompressedMembership.from_sample_indices(
                sample_indices=indices[indptr[col_idx] : indptr[col_idx + 1]].astype(
                    np.int64, copy=False
                ),
                n_samples=n_samples,
            )
            for col_idx in range(self.n_components)
        ]

    def to_csr(self) -> sparse.sparray | sparse.spmatrix:
        """Return the membership matrix converted to CSR format."""
        return self.matrix_csc.tocsr()
def save_membership_snapshot(
    output_dir: Path,
    *,
    memberships: list[CompressedMembership],
    labels: ComponentLabels,
    n_samples: int,
) -> Path:
    """Write memberships and labels to ``output_dir`` and return that directory.

    Two files are created: ``memberships.npz`` (sparse matrix) and
    ``metadata.json`` (``n_samples``, ``n_components`` and ``labels``).

    Raises:
        AssertionError: If the number of labels differs from the number of
            memberships; checked before anything is written so that no unloadable
            snapshot is left on disk.
    """
    assert len(labels) == len(memberships), (
        f"Label count {len(labels)} != membership count {len(memberships)}"
    )

    output_dir.mkdir(parents=True, exist_ok=True)
    matrix_path = output_dir / "memberships.npz"
    metadata_path = output_dir / "metadata.json"

    matrix_csc = memberships_to_csc(memberships, n_samples=n_samples)
    sparse.save_npz(matrix_path, matrix_csc)
    metadata_path.write_text(
        json.dumps(
            {
                "n_samples": n_samples,
                "n_components": len(labels),
                "labels": list(labels),
            },
            indent=2,
        ),
        # Explicit encoding keeps snapshots portable across machines/locales.
        encoding="utf-8",
    )
    return output_dir


def load_membership_snapshot(path: Path) -> MembershipSnapshot:
    """Load a snapshot directory written by ``save_membership_snapshot``.

    Raises:
        AssertionError: If the matrix shape disagrees with the stored ``n_samples`` /
            ``n_components``, or the number of labels differs from the number of
            matrix columns.
    """
    matrix_path = path / "memberships.npz"
    metadata_path = path / "metadata.json"
    matrix_csc = sparse.load_npz(matrix_path).tocsc()
    metadata = json.loads(metadata_path.read_text(encoding="utf-8"))
    labels = ComponentLabels(metadata["labels"])

    n_rows, n_cols = matrix_csc.shape
    assert n_rows == metadata["n_samples"], (
        f"Snapshot matrix has {n_rows} rows but metadata says n_samples={metadata['n_samples']}"
    )
    assert n_cols == metadata["n_components"], (
        f"Snapshot matrix has {n_cols} columns but metadata says "
        f"n_components={metadata['n_components']}"
    )
    assert len(labels) == n_cols, (
        f"Snapshot has {len(labels)} labels for {n_cols} components"
    )
    return MembershipSnapshot(matrix_csc=matrix_csc, labels=labels)
def _choose_coact_device(coact: ClusterCoactivationShaped) -> torch.device:
    """Pick the device for the dense cost-matrix math.

    The tensor's current device is kept when CUDA is unavailable, when the tensor
    already lives on CUDA, or when the free CUDA memory is below the estimate
    computed here; otherwise ``cuda`` is returned.
    """
    current_device = coact.device
    if not torch.cuda.is_available() or current_device.type == "cuda":
        return current_device

    free_bytes = torch.cuda.mem_get_info()[0]
    matrix_bytes = coact.element_size() * coact.numel()
    # Current coact, a temporary clone during recompute, and the costs tensor dominate.
    headroom_bytes = 512 * 1024**2
    needed_bytes = 3 * matrix_bytes + headroom_bytes
    return torch.device("cuda") if free_bytes >= needed_bytes else current_device
current_memberships = memberships.copy() + k_groups: int = c_components + + merge_history: MergeHistory = MergeHistory.from_config( + merge_config=merge_config, + labels=component_labels, + ) + + pbar: tqdm[int] = tqdm( + range(num_iters), + unit="iter", + total=num_iters, + ) + merge_start = time.perf_counter() + log_every = min(10, num_iters) + for iter_idx in pbar: + costs: ClusterCoactivationShaped = compute_merge_costs( + coact=current_coact / n_samples, + merges=current_merge, + alpha=merge_config.alpha, + ) + + merge_pair: MergePair = merge_config.merge_pair_sample(costs) + + current_merge, current_coact, current_memberships = recompute_coacts_merge_pair_memberships( + coact=current_coact, + merges=current_merge, + merge_pair=merge_pair, + memberships=current_memberships, + component_activity_csr=component_activity_csr, + ) + + merge_history.add_iteration( + idx=iter_idx, + selected_pair=merge_pair, + current_merge=current_merge, + ) + + diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) + mdl_loss: float = compute_mdl_cost( + acts=diag_acts, + merges=current_merge, + alpha=merge_config.alpha, + ) + mdl_loss_norm: float = mdl_loss / n_samples + merge_pair_cost: float = float(costs[merge_pair].item()) + + pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") + + if log_callback is not None: + log_callback( + iter_idx=iter_idx, + current_coact=current_coact, + component_labels=component_labels, + current_merge=current_merge, + costs=costs, + merge_history=merge_history, + k_groups=k_groups, + merge_pair_cost=merge_pair_cost, + mdl_loss=mdl_loss, + mdl_loss_norm=mdl_loss_norm, + diag_acts=diag_acts, + ) + + if (iter_idx + 1) % log_every == 0 or iter_idx == 0 or iter_idx + 1 == num_iters: + elapsed = time.perf_counter() - merge_start + logger.info( + "Compressed merge progress: " + f"iter={iter_idx + 1}/{num_iters}, " + f"elapsed={elapsed:.2f}s, " + f"sec_per_iter={elapsed / (iter_idx + 1):.4f}, " + 
f"k_groups={k_groups - 1}" + ) + + k_groups -= 1 + assert current_coact.shape[0] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_coact.shape[1] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert len(current_memberships) == k_groups, ( + "Membership count should match number of groups" + ) + + if k_groups <= 3: + warnings.warn( + f"Stopping early at iteration {iter_idx} as only {k_groups} groups left", + stacklevel=2, + ) + break + + return merge_history diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index f471879b2..c4e246946 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -14,7 +14,7 @@ MergePairSampler, MergePairSamplerKey, ) -from spd.clustering.util import ModuleFilterFunc, ModuleFilterSource +from spd.clustering.util import DeadComponentFilterStat, ModuleFilterFunc, ModuleFilterSource from spd.spd_types import Probability MergeConfigKey = Literal[ @@ -24,6 +24,7 @@ "merge_pair_sampling_method", "merge_pair_sampling_kwargs", "filter_dead_threshold", + "filter_dead_stat", ] @@ -66,7 +67,11 @@ class MergeConfig(BaseConfig): ) filter_dead_threshold: float = Field( default=0.001, - description="Threshold for filtering out dead components. 
If a component's activation is below this threshold, it is considered dead and not included in the merge.", + description="Threshold for filtering out dead components using the statistic selected by filter_dead_stat.", + ) + filter_dead_stat: DeadComponentFilterStat = Field( + default="max", + description="Statistic used to determine whether a component is dead before clustering.", ) module_name_filter: ModuleFilterSource = Field( default=None, diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py new file mode 100644 index 000000000..efdfd089e --- /dev/null +++ b/spd/clustering/sample_membership.py @@ -0,0 +1,366 @@ +from dataclasses import dataclass +from typing import Literal + +import numpy as np +import torch +from numba import njit +from scipy import sparse + +from spd.clustering.consts import ClusterCoactivationShaped + +_POPCOUNT_TABLE = np.unpackbits(np.arange(256, dtype=np.uint8)[:, None], axis=1).sum(axis=1) + + +def _index_dtype_for(n_samples: int) -> np.dtype[np.unsignedinteger]: + return np.dtype(np.uint32) if n_samples <= np.iinfo(np.uint32).max else np.dtype(np.uint64) + + +def _bytes_for_bitset(n_samples: int) -> int: + return (n_samples + 7) // 8 + + +def _prefer_sparse(n_samples: int, active_count: int, dtype: np.dtype[np.generic]) -> bool: + return active_count * np.dtype(dtype).itemsize <= _bytes_for_bitset(n_samples) + + +def _sample_indices_to_bits( + sample_indices: np.ndarray, + n_samples: int, +) -> np.ndarray: + n_bytes = _bytes_for_bitset(n_samples) + bits = np.zeros(n_bytes, dtype=np.uint8) + if sample_indices.size == 0: + return bits + + sample_indices = sample_indices.astype(np.int64, copy=False) + byte_indices = sample_indices // 8 + bit_offsets = sample_indices % 8 + np.bitwise_or.at(bits, byte_indices, (1 << bit_offsets).astype(np.uint8)) + return bits + + +def _count_sparse_sparse_intersection(a: np.ndarray, b: np.ndarray) -> int: + i = 0 + j = 0 + count = 0 + len_a = len(a) + len_b = len(b) + 
while i < len_a and j < len_b: + a_i = int(a[i]) + b_j = int(b[j]) + if a_i == b_j: + count += 1 + i += 1 + j += 1 + elif a_i < b_j: + i += 1 + else: + j += 1 + return count + + +def _union_sparse_sparse(a: np.ndarray, b: np.ndarray) -> np.ndarray: + i = 0 + j = 0 + out: list[int] = [] + len_a = len(a) + len_b = len(b) + while i < len_a and j < len_b: + a_i = int(a[i]) + b_j = int(b[j]) + if a_i == b_j: + out.append(a_i) + i += 1 + j += 1 + elif a_i < b_j: + out.append(a_i) + i += 1 + else: + out.append(b_j) + j += 1 + while i < len_a: + out.append(int(a[i])) + i += 1 + while j < len_b: + out.append(int(b[j])) + j += 1 + dtype = a.dtype if a.size > 0 else b.dtype + return np.asarray(out, dtype=dtype) + + +def _count_sparse_bitset_intersection( + sample_indices: np.ndarray, + bits: np.ndarray, +) -> int: + if sample_indices.size == 0: + return 0 + sample_indices = sample_indices.astype(np.int64, copy=False) + byte_indices = sample_indices // 8 + bit_offsets = sample_indices % 8 + return int(np.count_nonzero(bits[byte_indices] & ((1 << bit_offsets).astype(np.uint8)))) + + +def _bitset_to_sample_indices(bits: np.ndarray, n_samples: int) -> np.ndarray: + nonzero_byte_idxs = np.flatnonzero(bits) + if nonzero_byte_idxs.size == 0: + return np.empty((0,), dtype=_index_dtype_for(n_samples)) + + unpacked = np.unpackbits(bits[nonzero_byte_idxs], bitorder="little") + active_bit_positions = np.flatnonzero(unpacked) + sample_indices = nonzero_byte_idxs[active_bit_positions // 8].astype( + np.int64, copy=False + ) * 8 + (active_bit_positions % 8) + sample_indices = sample_indices[sample_indices < n_samples] + return sample_indices.astype(_index_dtype_for(n_samples), copy=False) + + +@njit(cache=True) # pyright: ignore[reportUntypedFunctionDecorator] +def _count_group_overlaps_rows_numba( + merged_rows: np.ndarray, + indptr: np.ndarray, + indices: np.ndarray, + group_idxs: np.ndarray, + n_groups: int, +) -> np.ndarray: + counts = np.zeros(n_groups, dtype=np.int64) + seen = 
np.full(n_groups, -1, dtype=np.int64) + stamp = 0 + for row in merged_rows: + stamp += 1 + start = indptr[row] + end = indptr[row + 1] + for pos in range(start, end): + group_idx = group_idxs[indices[pos]] + if seen[group_idx] == stamp: + continue + seen[group_idx] = stamp + counts[group_idx] += 1 + return counts + + +@dataclass(frozen=True, slots=True) +class CompressedMembership: + """Exact sample memberships stored sparsely when cheaper, otherwise as a bitset.""" + + n_samples: int + active_count: int + sample_indices: np.ndarray | None = None + bits: np.ndarray | None = None + + def __post_init__(self) -> None: + has_sparse = self.sample_indices is not None + has_bits = self.bits is not None + assert has_sparse != has_bits, "Membership must use exactly one representation" + + @classmethod + def empty(cls, n_samples: int) -> "CompressedMembership": + dtype = _index_dtype_for(n_samples) + return cls(n_samples=n_samples, active_count=0, sample_indices=np.empty((0,), dtype=dtype)) + + @classmethod + def from_sample_indices( + cls, + sample_indices: np.ndarray, + n_samples: int, + ) -> "CompressedMembership": + dtype = _index_dtype_for(n_samples) + sample_indices = sample_indices.astype(dtype, copy=False) + active_count = int(sample_indices.size) + if _prefer_sparse(n_samples, active_count, dtype): + return cls( + n_samples=n_samples, active_count=active_count, sample_indices=sample_indices + ) + return cls( + n_samples=n_samples, + active_count=active_count, + bits=_sample_indices_to_bits(sample_indices, n_samples), + ) + + @classmethod + def from_bits( + cls, + bits: np.ndarray, + n_samples: int, + active_count: int, + ) -> "CompressedMembership": + return cls( + n_samples=n_samples, + active_count=active_count, + bits=bits.astype(np.uint8, copy=False), + ) + + @property + def is_sparse(self) -> bool: + return self.sample_indices is not None + + def count(self) -> int: + return self.active_count + + def intersection_count(self, other: "CompressedMembership") -> 
int: + assert self.n_samples == other.n_samples, "Memberships must share sample space" + if self.sample_indices is not None and other.sample_indices is not None: + return _count_sparse_sparse_intersection(self.sample_indices, other.sample_indices) + if self.sample_indices is not None: + assert other.bits is not None + return _count_sparse_bitset_intersection(self.sample_indices, other.bits) + if other.sample_indices is not None: + assert self.bits is not None + return _count_sparse_bitset_intersection(other.sample_indices, self.bits) + + assert self.bits is not None and other.bits is not None + overlap = np.bitwise_and(self.bits, other.bits) + return int(_POPCOUNT_TABLE[overlap].sum(dtype=np.uint64)) + + def union(self, other: "CompressedMembership") -> "CompressedMembership": + assert self.n_samples == other.n_samples, "Memberships must share sample space" + overlap_count = self.intersection_count(other) + union_count = self.active_count + other.active_count - overlap_count + + if self.sample_indices is not None and other.sample_indices is not None: + union_indices = _union_sparse_sparse(self.sample_indices, other.sample_indices) + return CompressedMembership.from_sample_indices(union_indices, self.n_samples) + + if self.bits is not None and other.bits is not None: + return CompressedMembership.from_bits( + bits=np.bitwise_or(self.bits, other.bits), + n_samples=self.n_samples, + active_count=union_count, + ) + + if self.bits is not None: + base_bits = self.bits.copy() + sparse_indices = other.sample_indices + else: + assert other.bits is not None + base_bits = other.bits.copy() + sparse_indices = self.sample_indices + + assert sparse_indices is not None + if sparse_indices.size > 0: + sparse_indices_i64 = sparse_indices.astype(np.int64, copy=False) + byte_indices = sparse_indices_i64 // 8 + bit_offsets = sparse_indices_i64 % 8 + np.bitwise_or.at(base_bits, byte_indices, (1 << bit_offsets).astype(np.uint8)) + + return CompressedMembership.from_bits( + 
bits=base_bits, + n_samples=self.n_samples, + active_count=union_count, + ) + + def to_bool_array(self) -> np.ndarray: + if self.sample_indices is not None: + result = np.zeros(self.n_samples, dtype=bool) + result[self.sample_indices.astype(np.int64, copy=False)] = True + return result + + assert self.bits is not None + unpacked = np.unpackbits(self.bits, bitorder="little") + return unpacked[: self.n_samples].astype(bool, copy=False) + + def to_sample_indices(self) -> np.ndarray: + if self.sample_indices is not None: + return self.sample_indices + + assert self.bits is not None + return _bitset_to_sample_indices(self.bits, self.n_samples) + + +def memberships_to_sample_component_matrix( + memberships: list[CompressedMembership], + *, + fmt: Literal["csr", "csc"] = "csr", +) -> sparse.csr_matrix | sparse.csc_matrix: + """Build a binary sample-by-component sparse matrix from memberships.""" + n_groups = len(memberships) + if n_groups == 0: + empty = sparse.csr_matrix((0, 0), dtype=np.uint8) + return empty if fmt == "csr" else sparse.csc_matrix(empty) + + n_samples = memberships[0].n_samples + assert all(membership.n_samples == n_samples for membership in memberships), ( + "Memberships must share sample space" + ) + + nnz = sum(membership.count() for membership in memberships) + row_indices = np.empty(nnz, dtype=np.int64) + col_indices = np.empty(nnz, dtype=np.int32) + + offset = 0 + for group_idx, membership in enumerate(memberships): + sample_indices = membership.to_sample_indices().astype(np.int64, copy=False) + group_nnz = sample_indices.size + row_indices[offset : offset + group_nnz] = sample_indices + col_indices[offset : offset + group_nnz] = group_idx + offset += group_nnz + + values = np.ones(nnz, dtype=np.uint8) + matrix = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(n_samples, n_groups), + dtype=np.uint8, + ) + return matrix if fmt == "csr" else sparse.csc_matrix(matrix) + + +def memberships_to_sample_component_csr( + memberships: 
list[CompressedMembership], +) -> sparse.csr_matrix: + """Build a binary sample-by-component CSR matrix from memberships.""" + matrix = memberships_to_sample_component_matrix(memberships, fmt="csr") + assert isinstance(matrix, sparse.csr_matrix) + return matrix + + +def count_group_overlaps_from_component_rows( + merged_rows: np.ndarray, + component_activity_csr: sparse.csr_matrix, + group_idxs: np.ndarray, + n_groups: int, +) -> np.ndarray: + """Count exact merged-group overlaps by scanning original component rows. + + `group_idxs` maps original component indices to the current group index after + the candidate merge has been applied. + """ + merged_rows_i64 = merged_rows.astype(np.int64, copy=False) + indptr = component_activity_csr.indptr.astype(np.int64, copy=False) + indices = component_activity_csr.indices.astype(np.int32, copy=False) + group_idxs_i64 = group_idxs.astype(np.int64, copy=False) + return _count_group_overlaps_rows_numba( + merged_rows=merged_rows_i64, + indptr=indptr, + indices=indices, + group_idxs=group_idxs_i64, + n_groups=n_groups, + ) + + +def compute_coactivation_matrix_from_csr( + component_activity_csr: sparse.csr_matrix, +) -> ClusterCoactivationShaped: + """Compute the full coactivation matrix from a sample-by-component CSR matrix.""" + activation_matrix = component_activity_csr.astype(np.int32, copy=False) + coact = (activation_matrix.T @ activation_matrix).toarray() + return torch.from_numpy(coact.astype(np.float32, copy=False)) + + +def compute_coactivation_matrix( + memberships: list[CompressedMembership], +) -> ClusterCoactivationShaped: + """Compute the full coactivation matrix from compressed memberships. + + This builds a sparse sample-by-component matrix and computes X.T @ X, + which is much faster than Python-level pairwise intersections in the + typical highly sparse regime. 
+ """ + n_groups = len(memberships) + if n_groups == 0: + return torch.empty((0, 0), dtype=torch.float32) + + n_samples = memberships[0].n_samples + assert all(membership.n_samples == n_samples for membership in memberships), ( + "Memberships must share sample space" + ) + + return compute_coactivation_matrix_from_csr(memberships_to_sample_component_csr(memberships)) diff --git a/spd/clustering/scripts/benchmark_merge_snapshot.py b/spd/clustering/scripts/benchmark_merge_snapshot.py new file mode 100644 index 000000000..bd253d807 --- /dev/null +++ b/spd/clustering/scripts/benchmark_merge_snapshot.py @@ -0,0 +1,86 @@ +"""Benchmark exact merge performance from a cached membership snapshot.""" + +import argparse +import time +from pathlib import Path +from typing import Any + +import numpy as np + +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.membership_snapshot import load_membership_snapshot +from spd.clustering.merge import merge_iteration_memberships +from spd.clustering.sample_membership import count_group_overlaps_from_component_rows +from spd.utils.general_utils import replace_pydantic_model + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--snapshot-dir", type=Path, required=True) + parser.add_argument("--config", type=Path, required=True) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--profile-overlap", action="store_true") + args = parser.parse_args() + + snapshot = load_membership_snapshot(args.snapshot_dir) + memberships = snapshot.to_memberships() + + if args.profile_overlap: + row_csr: Any = snapshot.to_csr() + merged = memberships[0].union(memberships[1]) + merged_indices = merged.to_sample_indices().astype(np.int64, copy=False) + group_idxs = np.arange(snapshot.n_components, dtype=np.int64) + + st = time.time() + current = np.array([merged.intersection_count(m) for m in memberships], dtype=np.int64) + current_s = 
time.time() - st + + st = time.time() + row_counts = count_group_overlaps_from_component_rows( + merged_rows=merged_indices, + component_activity_csr=row_csr, + group_idxs=group_idxs, + n_groups=snapshot.n_components, + ) + row_counts_s = time.time() - st + + assert np.array_equal(current, row_counts) + print( + { + "phase": "overlap_profile", + "merged_size": int(merged_indices.size), + "current_s": round(current_s, 4), + "row_counts_s": round(row_counts_s, 4), + "speedup": round(current_s / row_counts_s, 2), + } + ) + + run_config = ClusteringRunConfig.from_file(args.config) + merge_config = replace_pydantic_model(run_config.merge_config, {"iters": args.iters}) + + st = time.time() + history = merge_iteration_memberships( + merge_config=merge_config, + memberships=memberships, + n_samples=snapshot.n_samples, + component_labels=snapshot.labels, + log_callback=None, + ) + elapsed = time.time() - st + print( + { + "phase": "merge_benchmark", + "snapshot_dir": str(args.snapshot_dir), + "n_samples": snapshot.n_samples, + "n_components": snapshot.n_components, + "iters": history.n_iters_current, + "elapsed_s": round(elapsed, 2), + "sec_per_iter": round(elapsed / history.n_iters_current, 4) + if history.n_iters_current + else None, + } + ) + + +if __name__ == "__main__": + main() diff --git a/spd/clustering/scripts/cache_membership_snapshot.py b/spd/clustering/scripts/cache_membership_snapshot.py new file mode 100644 index 000000000..2d65caa08 --- /dev/null +++ b/spd/clustering/scripts/cache_membership_snapshot.py @@ -0,0 +1,93 @@ +"""Collect compressed memberships once and save them for repeated merge benchmarking.""" + +import argparse +from pathlib import Path + +from spd.clustering.activations import collect_memberships_lm, collect_memberships_resid_mlp +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.dataset import create_clustering_dataloader +from spd.clustering.membership_snapshot import save_membership_snapshot +from 
spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import replace_pydantic_model + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--config", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--n-tokens", type=int, default=None) + parser.add_argument("--n-samples", type=int, default=None) + parser.add_argument("--batch-size", type=int, default=None) + args = parser.parse_args() + + run_config = ClusteringRunConfig.from_file(args.config) + overrides: dict[str, int] = {} + if args.n_tokens is not None: + overrides["n_tokens"] = args.n_tokens + if args.n_samples is not None: + overrides["n_samples"] = args.n_samples + if args.batch_size is not None: + overrides["batch_size"] = args.batch_size + if overrides: + run_config = replace_pydantic_model(run_config, overrides) + + assert run_config.merge_config.activation_threshold is not None, ( + "Snapshotting only supports thresholded compressed memberships" + ) + + spd_run = SPDRunInfo.from_path(run_config.model_path) + task_name: TaskName = spd_run.config.task_config.task_name + device = get_device() + model = ComponentModel.from_run_info(spd_run).to(device) + dataloader = create_clustering_dataloader( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + ) + + if task_name == "lm": + assert run_config.n_tokens is not None + assert run_config.n_tokens_per_seq is not None + processed = collect_memberships_lm( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + activation_threshold=run_config.merge_config.activation_threshold, + 
filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + else: + n_samples = run_config.n_samples or run_config.batch_size + processed = collect_memberships_resid_mlp( + model=model, + dataloader=dataloader, + n_samples=n_samples, + device=device, + activation_threshold=run_config.merge_config.activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + + save_membership_snapshot( + args.output_dir, + memberships=processed.memberships, + labels=processed.labels, + n_samples=processed.n_samples, + ) + print( + { + "output_dir": str(args.output_dir), + "n_samples": processed.n_samples, + "n_components": len(processed.labels), + } + ) + + +if __name__ == "__main__": + main() diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 844365681..eea1007dc 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -29,7 +29,10 @@ from spd.clustering.activations import ( ProcessedActivations, + ProcessedMemberships, collect_activations, + collect_memberships_lm, + collect_memberships_resid_mlp, component_activations, process_activations, ) @@ -43,7 +46,7 @@ from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.math.semilog import semilog -from spd.clustering.merge import merge_iteration +from spd.clustering.merge import merge_iteration, merge_iteration_memberships from spd.clustering.merge_history import MergeHistory from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration @@ -290,44 +293,99 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info("Loading model") model = 
ComponentModel.from_run_info(spd_run).to(device) + processed_activations: ProcessedActivations | None = None + processed_memberships: ProcessedMemberships | None = None + activations: ActivationsTensor | None = None + component_labels: ComponentLabels + # 4. Compute activations logger.info("Computing activations") - if task_name == "lm": - assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" - assert run_config.n_tokens_per_seq is not None, "n_tokens_per_seq must be set for LM tasks" - activations_dict = collect_activations( - model=model, - dataloader=dataloader, - n_tokens=run_config.n_tokens, - n_tokens_per_seq=run_config.n_tokens_per_seq, - device=device, - seed=run_config.dataset_seed, - ) + use_compressed_merge = run_config.merge_config.activation_threshold is not None + if use_compressed_merge: + activation_threshold = run_config.merge_config.activation_threshold + assert activation_threshold is not None + if task_name == "lm": + assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" + assert run_config.n_tokens_per_seq is not None, ( + "n_tokens_per_seq must be set for LM tasks" + ) + processed_memberships = collect_memberships_lm( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + activation_threshold=activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, + filter_modules=run_config.merge_config.filter_modules, + ) + else: + processed_memberships = collect_memberships_resid_mlp( + model=model, + dataloader=dataloader, + n_samples=run_config.n_samples or run_config.batch_size, + device=device, + activation_threshold=activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, + 
filter_modules=run_config.merge_config.filter_modules, + ) + processed_activations = processed_memberships.preview + component_labels = ComponentLabels(processed_memberships.labels.copy()) else: - # resid_mlp: single batch, no sequence dimension - batch_data = next(iter(dataloader)) - batch, _ = batch_data # DatasetGeneratedDataLoader yields (batch, labels) - activations_dict = component_activations( - model=model, - batch=batch, - device=device, + if task_name == "lm": + assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" + assert run_config.n_tokens_per_seq is not None, ( + "n_tokens_per_seq must be set for LM tasks" + ) + activations_dict = collect_activations( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + ) + else: + n_samples_target = run_config.n_samples or run_config.batch_size + collected_batches: dict[str, list[Tensor]] = {} + n_collected = 0 + for batch_data in dataloader: + batch, _ = batch_data + batch_take = min(batch.shape[0], n_samples_target - n_collected) + acts = component_activations(model=model, batch=batch[:batch_take], device=device) + for key, act in acts.items(): + collected_batches.setdefault(key, []).append(act.cpu()) + n_collected += batch_take + if n_collected >= n_samples_target: + break + assert n_collected >= n_samples_target, ( + f"Dataloader exhausted: collected {n_collected} samples but needed {n_samples_target}" + ) + activations_dict = { + key: torch.cat(chunks, dim=0) for key, chunks in collected_batches.items() + } + + logger.info("Processing activations") + processed_activations = process_activations( + activations=activations_dict, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, + seq_mode=None, + filter_modules=run_config.merge_config.filter_modules, ) + activations = 
processed_activations.activations.to(device) + component_labels = ComponentLabels(processed_activations.labels.copy()) + del activations_dict - # 5. Process activations - logger.info("Processing activations") - processed_activations: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode=None, - filter_modules=run_config.merge_config.filter_modules, - ) - - # 6. Log activations (if WandB enabled) - if wandb_run is not None: + # 5. Log activations preview (if WandB enabled) + if wandb_run is not None and processed_activations is not None: logger.info("Plotting activations") plot_activations( processed_activations=processed_activations, - save_dir=None, # Don't save to disk, only WandB + save_dir=None, n_samples_max=256, wandb_run=wandb_run, ) @@ -339,11 +397,6 @@ def main(run_config: ClusteringRunConfig) -> Path: single=True, ) - # Extract what we need, then free the model and temporary objects - activations: ActivationsTensor = processed_activations.activations.to(device) - component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) - del processed_activations - del activations_dict del model gc.collect() torch.cuda.empty_cache() @@ -356,12 +409,22 @@ def main(run_config: ClusteringRunConfig) -> Path: else None ) - history: MergeHistory = merge_iteration( - merge_config=run_config.merge_config, - activations=activations, - component_labels=component_labels, - log_callback=log_callback, - ) + if processed_memberships is not None: + history = merge_iteration_memberships( + merge_config=run_config.merge_config, + memberships=processed_memberships.memberships, + n_samples=processed_memberships.n_samples, + component_labels=component_labels, + log_callback=log_callback, + ) + else: + assert activations is not None + history = merge_iteration( + merge_config=run_config.merge_config, + activations=activations, + 
component_labels=component_labels, + log_callback=log_callback, + ) # 8. Save merge history diff --git a/spd/clustering/util.py b/spd/clustering/util.py index bd11e2fd4..a84e3d78f 100644 --- a/spd/clustering/util.py +++ b/spd/clustering/util.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Literal def format_scientific_latex(value: float) -> str: @@ -16,3 +17,4 @@ def format_scientific_latex(value: float) -> str: ModuleFilterSource = str | Callable[[str], bool] | set[str] | None ModuleFilterFunc = Callable[[str], bool] +DeadComponentFilterStat = Literal["max", "mean"] diff --git a/tests/clustering/test_filter_dead_components.py b/tests/clustering/test_filter_dead_components.py index 654631f37..456db84e3 100644 --- a/tests/clustering/test_filter_dead_components.py +++ b/tests/clustering/test_filter_dead_components.py @@ -129,3 +129,49 @@ def test_linear_gradient_thresholds(threshold: float) -> None: assert result.n_alive == expected_alive assert result.n_dead == n_components - expected_alive + + +def test_filter_dead_components_mean_stat() -> None: + """Mean-based filtering keeps components whose average activation clears the threshold.""" + activations = torch.tensor( + [ + [1e-5, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + ] + ) + labels = ComponentLabels(["spiky", "dead", "steady"]) + + result: FilteredActivations = filter_dead_components( + activations=activations, + labels=labels, + filter_dead_threshold=5e-6, + filter_dead_stat="mean", + ) + + assert result.labels == ["steady"] + assert result.dead_components_labels == ["spiky", "dead"] + + +def test_filter_dead_components_max_stat_preserves_spikes() -> None: + """Max-based filtering preserves components with a single large activation.""" + activations = torch.tensor( + [ + [1e-5, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + ] + ) + labels = ComponentLabels(["spiky", "dead", "steady"]) + + result: FilteredActivations = 
filter_dead_components( + activations=activations, + labels=labels, + filter_dead_threshold=5e-6, + filter_dead_stat="max", + ) + + assert result.labels == ["spiky", "steady"] + assert result.dead_components_labels == ["dead"] diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 63f4e88f7..38ffa0fc1 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -15,6 +15,7 @@ def test_default_config(self): assert config.merge_pair_sampling_method == "range" assert config.merge_pair_sampling_kwargs == {"threshold": 0.05} + assert config.filter_dead_stat == "max" def test_range_sampler_config(self): """Test MergeConfig with range sampler.""" @@ -75,6 +76,7 @@ def test_config_with_all_parameters(self): merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.5}, filter_dead_threshold=0.001, + filter_dead_stat="mean", module_name_filter="model.layers", ) @@ -84,6 +86,7 @@ def test_config_with_all_parameters(self): assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} assert config.filter_dead_threshold == 0.001 + assert config.filter_dead_stat == "mean" assert config.module_name_filter == "model.layers" def test_config_serialization(self): diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 8492300de..47e1801f6 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -2,9 +2,16 @@ import torch -from spd.clustering.consts import ComponentLabels +from spd.clustering.compute_costs import recompute_coacts_merge_pair_memberships +from spd.clustering.consts import ComponentLabels, MergePair +from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig +from spd.clustering.sample_membership import ( + 
CompressedMembership, + compute_coactivation_matrix, + memberships_to_sample_component_csr, +) class TestMergeIntegration: @@ -149,3 +156,33 @@ def test_merge_with_small_components(self): # Early stopping may occur at 2 groups, so final count could be 2 or 3 assert history.merges.k_groups[-1].item() >= 2 assert history.merges.k_groups[-1].item() <= 3 + + def test_membership_recompute_matches_row_oriented_path(self): + """Row-oriented overlap recompute should match the direct membership path exactly.""" + memberships = [ + CompressedMembership.from_sample_indices(torch.tensor(indices).numpy(), n_samples=8) + for indices in ([0, 2, 5], [1, 2], [0, 3], [4, 5, 6]) + ] + coact = compute_coactivation_matrix(memberships) + merges = GroupMerge.identity(n_components=len(memberships)) + component_activity_csr = memberships_to_sample_component_csr(memberships) + + merge_row, coact_row, memberships_row = recompute_coacts_merge_pair_memberships( + coact=coact, + merges=merges, + merge_pair=MergePair((0, 1)), + memberships=memberships, + component_activity_csr=component_activity_csr, + ) + + expected_group_idxs = torch.tensor([0, 0, 1, 2], dtype=torch.int64) + expected_coact = torch.tensor( + [ + [4.0, 1.0, 1.0], + [1.0, 2.0, 0.0], + [1.0, 0.0, 3.0], + ] + ) + assert torch.equal(merge_row.group_idxs, expected_group_idxs) + assert torch.equal(coact_row, expected_coact) + assert [membership.count() for membership in memberships_row] == [4, 2, 3] diff --git a/uv.lock b/uv.lock index becdf51ec..be4f3d735 100644 --- a/uv.lock +++ b/uv.lock @@ -870,6 +870,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, ] +[[package]] +name = "llvmlite" +version = "0.46.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url 
= "https://files.pythonhosted.org/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb", size = 193456, upload-time = "2025-12-08T18:15:36.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ff/3eba7eb0aed4b6fca37125387cd417e8c458e750621fce56d2c541f67fa8/llvmlite-0.46.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:30b60892d034bc560e0ec6654737aaa74e5ca327bd8114d82136aa071d611172", size = 37232767, upload-time = "2025-12-08T18:15:13.22Z" }, + { url = "https://files.pythonhosted.org/packages/0e/54/737755c0a91558364b9200702c3c9c15d70ed63f9b98a2c32f1c2aa1f3ba/llvmlite-0.46.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6cc19b051753368a9c9f31dc041299059ee91aceec81bd57b0e385e5d5bf1a54", size = 56275176, upload-time = "2025-12-08T18:15:16.339Z" }, + { url = "https://files.pythonhosted.org/packages/e6/91/14f32e1d70905c1c0aa4e6609ab5d705c3183116ca02ac6df2091868413a/llvmlite-0.46.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bca185892908f9ede48c0acd547fe4dc1bafefb8a4967d47db6cf664f9332d12", size = 55128629, upload-time = "2025-12-08T18:15:19.493Z" }, + { url = "https://files.pythonhosted.org/packages/4a/a7/d526ae86708cea531935ae777b6dbcabe7db52718e6401e0fb9c5edea80e/llvmlite-0.46.0-cp313-cp313-win_amd64.whl", hash = "sha256:67438fd30e12349ebb054d86a5a1a57fd5e87d264d2451bcfafbbbaa25b82a35", size = 38138941, upload-time = "2025-12-08T18:15:22.536Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -1068,6 +1080,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/af/cd3290a647df567645353feed451ef4feaf5844496ced69c4dcb84295ff4/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_arm64.whl", hash = "sha256:d0c2273b667dd7e3f55e369c0085957b702144b1b04bfceb7ce2411e58333757", size = 39048104, upload-time = "2025-12-11T21:12:23.495Z" }, ] 
+[[package]] +name = "numba" +version = "0.64.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c9/a0fb41787d01d621046138da30f6c2100d80857bf34b3390dd68040f27a3/numba-0.64.0.tar.gz", hash = "sha256:95e7300af648baa3308127b1955b52ce6d11889d16e8cfe637b4f85d2fca52b1", size = 2765679, upload-time = "2026-02-18T18:41:20.974Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/80/2734de90f9300a6e2503b35ee50d9599926b90cbb7ac54f9e40074cd07f1/numba-0.64.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3bab2c872194dcd985f1153b70782ec0fbbe348fffef340264eacd3a76d59fd6", size = 2683392, upload-time = "2026-02-18T18:41:06.563Z" }, + { url = "https://files.pythonhosted.org/packages/42/e8/14b5853ebefd5b37723ef365c5318a30ce0702d39057eaa8d7d76392859d/numba-0.64.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:703a246c60832cad231d2e73c1182f25bf3cc8b699759ec8fe58a2dbc689a70c", size = 3812245, upload-time = "2026-02-18T18:41:07.963Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a2/f60dc6c96d19b7185144265a5fbf01c14993d37ff4cd324b09d0212aa7ce/numba-0.64.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e2e49a7900ee971d32af7609adc0cfe6aa7477c6f6cccdf6d8138538cf7756f", size = 3511328, upload-time = "2026-02-18T18:41:09.504Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2a/fe7003ea7e7237ee7014f8eaeeb7b0d228a2db22572ca85bab2648cf52cb/numba-0.64.0-cp313-cp313-win_amd64.whl", hash = "sha256:396f43c3f77e78d7ec84cdfc6b04969c78f8f169351b3c4db814b97e7acf4245", size = 2752668, upload-time = "2026-02-18T18:41:11.455Z" }, +] + [[package]] name = "numpy" version = "2.4.0" @@ -1965,6 +1993,7 @@ dependencies = [ { name = "jaxtyping" }, { name = "kaleido" }, { name = "matplotlib" }, + { name = "numba" }, { name = "numpy" }, { name = "openrouter" }, { name = "orjson" }, @@ 
-2007,6 +2036,7 @@ requires-dist = [ { name = "jaxtyping" }, { name = "kaleido", specifier = "==0.2.1" }, { name = "matplotlib" }, + { name = "numba", specifier = ">=0.64.0" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, { name = "orjson" },