From 124316f388ba05c33af807d78e203474c99563ce Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Wed, 18 Mar 2026 13:27:27 +0000 Subject: [PATCH 1/5] Refactor clustering to use compressed memberships --- spd/clustering/activations.py | 286 +++++++++++++++++++++++ spd/clustering/clustering_run_config.py | 4 + spd/clustering/compute_costs.py | 49 ++++ spd/clustering/merge.py | 99 ++++++++ spd/clustering/sample_membership.py | 238 +++++++++++++++++++ spd/clustering/scripts/run_clustering.py | 144 ++++++++---- 6 files changed, 778 insertions(+), 42 deletions(-) create mode 100644 spd/clustering/sample_membership.py diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 2738daf3e..3b6c52b9d 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -2,6 +2,7 @@ from functools import cached_property from typing import Any, Literal, NamedTuple +import numpy as np import torch from jaxtyping import Bool, Float, Float16 from torch import Tensor @@ -14,6 +15,7 @@ ClusterCoactivationShaped, ComponentLabels, ) +from spd.clustering.sample_membership import BitsetMembership from spd.clustering.util import ModuleFilterFunc from spd.log import logger from spd.models.component_model import ComponentModel, OutputWithCache @@ -266,6 +268,290 @@ def get_module_activations(self) -> dict[str, ActivationsTensor]: return result +@dataclass(frozen=True) +class ProcessedMemberships: + """Processed, compressed sample memberships for exact merge iteration.""" + + module_component_counts: dict[str, int] + module_alive_counts: dict[str, int] + labels: ComponentLabels + dead_components_lst: ComponentLabels | None + memberships: list[BitsetMembership] + n_samples: int + preview: ProcessedActivations | None = None + + @property + def n_components_original(self) -> int: + return sum(self.module_component_counts.values()) + + @property + def n_components_alive(self) -> int: + return len(self.labels) + + @property + def n_components_dead(self) -> int: + return len(self.dead_components_lst) if self.dead_components_lst else 0 + + def validate(self) -> None: + assert self.n_components_alive == len(self.memberships), ( + f"{self.n_components_alive = } != {len(self.memberships) = }" + ) + assert self.n_components_alive + self.n_components_dead == self.n_components_original, ( + f"{self.n_components_alive = } + {self.n_components_dead = } != {self.n_components_original = }" + ) + + +class MembershipBuilder: + """Streaming builder for compressed sample memberships. + + This stores only active sample ids per component plus a small dense preview + for plots/logging. It assumes thresholded boolean merge semantics. + """ + + def __init__( + self, + *, + activation_threshold: float, + filter_dead_threshold: float, + filter_modules: ModuleFilterFunc | None, + preview_n_samples: int = 256, + ) -> None: + self.activation_threshold = activation_threshold + self.filter_dead_threshold = filter_dead_threshold + self.filter_modules = filter_modules + self.preview_n_samples = preview_n_samples + + self.n_samples = 0 + self.module_component_counts: dict[str, int] = {} + self.max_activations: dict[str, Float[Tensor, " c"]] = {} + self.sample_idx_chunks: dict[str, list[list[np.ndarray]]] = {} + self.preview_chunks: dict[str, list[Tensor]] = {} + self.module_order: list[str] = [] + self._preview_rows = 0 + + def _ensure_module(self, key: str, n_components: int) -> None: + if key in self.module_component_counts: + assert self.module_component_counts[key] == n_components, ( + f"Inconsistent component count for module '{key}': " + f"{self.module_component_counts[key]} vs {n_components}" + ) + return + + self.module_component_counts[key] = n_components + self.max_activations[key] = torch.full((n_components,), float("-inf")) + self.sample_idx_chunks[key] = [[] for _ in range(n_components)] + self.preview_chunks[key] = [] + self.module_order.append(key) + + def add_batch( + self, + activations: dict[str, Float[Tensor, "samples C"]], + ) -> None: + """Add a batch of per-module activations shaped [samples, components].""" + filtered = ( + {key: act for key, act in activations.items() if self.filter_modules(key)} + if self.filter_modules is not None + else activations + ) + if not filtered: + return + + batch_n_samples = next(iter(filtered.values())).shape[0] + sample_offset = self.n_samples + + for key, act in filtered.items(): + act_cpu = act.detach().cpu() + assert act_cpu.ndim == 2, f"Expected 2D activations, got shape {tuple(act_cpu.shape)}" + self._ensure_module(key, act_cpu.shape[1]) + + self.max_activations[key] = torch.maximum( + self.max_activations[key], act_cpu.max(dim=0).values + ) + + if self._preview_rows < self.preview_n_samples: + remaining = self.preview_n_samples - self._preview_rows + self.preview_chunks[key].append(act_cpu[:remaining].clone()) + + mask_np = (act_cpu.numpy() > self.activation_threshold).T + comp_indices, row_indices = np.nonzero(mask_np) + if comp_indices.size > 0: + row_indices = row_indices.astype(np.int32, copy=False) + sample_offset + split_points = np.flatnonzero(np.diff(comp_indices)) + 1 + row_groups = np.split(row_indices, split_points) + comp_groups = np.split(comp_indices, split_points) + for comp_group, row_group in zip(comp_groups, row_groups, strict=True): + self.sample_idx_chunks[key][int(comp_group[0])].append(row_group) + + self.n_samples += batch_n_samples + self._preview_rows = min(self.n_samples, self.preview_n_samples) + + def finalize(self) -> ProcessedMemberships: + module_alive_counts: dict[str, int] = {} + alive_labels = ComponentLabels(list()) + dead_labels = ComponentLabels(list()) + memberships: list[BitsetMembership] = [] + + preview_module_component_counts: dict[str, int] = {} + preview_module_alive_counts: dict[str, int] = {} + preview_chunks_alive: list[Tensor] = [] + + for key in self.module_order: + max_act = self.max_activations[key] + n_components = self.module_component_counts[key] + alive = ( + max_act >= self.filter_dead_threshold + if self.filter_dead_threshold > 0 + else torch.ones(n_components, dtype=torch.bool) + ) + n_alive = int(alive.sum().item()) + module_alive_counts[key] = n_alive + preview_module_component_counts[key] = n_components + preview_module_alive_counts[key] = n_alive + + preview_tensor = ( + torch.cat(self.preview_chunks[key], dim=0) + if self.preview_chunks[key] + else torch.empty((0, n_components), dtype=max_act.dtype) + ) + + for comp_idx in range(n_components): + label = f"{key}:{comp_idx}" + if alive[comp_idx]: + alive_labels.append(label) + chunks = self.sample_idx_chunks[key][comp_idx] + sample_ids = ( + np.concatenate(chunks).astype(np.int64, copy=False) + if chunks + else np.empty((0,), dtype=np.int64) + ) + memberships.append( + BitsetMembership.from_sample_indices( + sample_indices=sample_ids, + n_samples=self.n_samples, + ) + ) + else: + dead_labels.append(label) + + if n_alive > 0: + preview_chunks_alive.append(preview_tensor[:, alive]) + + preview: ProcessedActivations | None = None + if preview_chunks_alive: + preview = ProcessedActivations( + module_component_counts=preview_module_component_counts, + module_alive_counts=preview_module_alive_counts, + activations=torch.cat(preview_chunks_alive, dim=1), + labels=ComponentLabels(alive_labels.copy()), + dead_components_lst=ComponentLabels(dead_labels.copy()) if dead_labels else None, + ) + + result = ProcessedMemberships( + module_component_counts=self.module_component_counts, + module_alive_counts=module_alive_counts, + labels=alive_labels, + dead_components_lst=dead_labels if dead_labels else None, + memberships=memberships, + n_samples=self.n_samples, + preview=preview, + ) + result.validate() + return result + + +def collect_memberships_lm( + model: ComponentModel, + dataloader: DataLoader[Any], + n_tokens: int, + n_tokens_per_seq: int, + device: torch.device | str, + seed: int, + activation_threshold: float, + filter_dead_threshold: float, + filter_modules: ModuleFilterFunc | None = None, + preview_n_samples: int = 256, +) -> ProcessedMemberships: + """Collect LM activations across batches into compressed memberships.""" + rng = torch.Generator().manual_seed(seed) + builder = MembershipBuilder( + activation_threshold=activation_threshold, + filter_dead_threshold=filter_dead_threshold, + filter_modules=filter_modules, + preview_n_samples=preview_n_samples, + ) + n_collected = 0 + + pbar = tqdm(dataloader, desc="Collecting activations", unit="batch") + for batch_data in pbar: + input_ids = batch_data["input_ids"] + batch_size, n_ctx = input_ids.shape + + activations = component_activations(model=model, batch=input_ids, device=device) + + positions = torch.randint(0, n_ctx, (batch_size, n_tokens_per_seq), generator=rng) + batch_indices = torch.arange(batch_size).unsqueeze(1).expand_as(positions) + + sampled_activations: dict[str, Float[Tensor, "samples C"]] = {} + n_remaining = n_tokens - n_collected + batch_take = min(batch_size * n_tokens_per_seq, n_remaining) + for key, act in activations.items(): + sampled = act[batch_indices, positions].reshape(batch_size * n_tokens_per_seq, -1) + sampled_activations[key] = sampled[:batch_take] + + builder.add_batch(sampled_activations) + + n_collected += batch_take + pbar.set_postfix(tokens=f"{n_collected}/{n_tokens}") + if n_collected >= n_tokens: + break + + assert n_collected >= n_tokens, ( + f"Dataloader exhausted: collected {n_collected} tokens but needed {n_tokens}" + ) + logger.info(f"Collected {n_collected} token activations (requested {n_tokens})") + return builder.finalize() + + +def collect_memberships_resid_mlp( + model: ComponentModel, + dataloader: DataLoader[Any], + n_samples: int, + device: torch.device | str, + activation_threshold: float, + filter_dead_threshold: float, + filter_modules: ModuleFilterFunc | None = None, + preview_n_samples: int = 256, +) -> ProcessedMemberships: + """Collect ResidMLP activations across batches into compressed memberships.""" + builder = MembershipBuilder( + activation_threshold=activation_threshold, + filter_dead_threshold=filter_dead_threshold, + filter_modules=filter_modules, + preview_n_samples=preview_n_samples, + ) + n_collected = 0 + + pbar = tqdm(dataloader, desc="Collecting activations", unit="batch") + for batch_data in pbar: + batch, _ = batch_data + activations = component_activations(model=model, batch=batch, device=device) + + n_remaining = n_samples - n_collected + batch_take = min(batch.shape[0], n_remaining) + builder.add_batch({key: act[:batch_take] for key, act in activations.items()}) + + n_collected += batch_take + pbar.set_postfix(samples=f"{n_collected}/{n_samples}") + if n_collected >= n_samples: + break + + assert n_collected >= n_samples, ( + f"Dataloader exhausted: collected {n_collected} samples but needed {n_samples}" + ) + logger.info(f"Collected {n_collected} resid_mlp activations (requested {n_samples})") + return builder.finalize() + + def process_activations( activations: dict[ str, # module name to diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index 4ab830319..cac0dbfa6 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -43,6 +43,10 @@ class ClusteringRunConfig(BaseConfig): ) batch_size: PositiveInt = Field(..., description="Batch size for processing") + n_samples: PositiveInt | None = Field( + default=None, + description="Number of activation samples to collect for non-LM tasks. Defaults to one batch if unset.", + ) n_tokens: PositiveInt | None = Field( default=None, description="Number of token activation samples to collect (LM only)", diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index f1b3425d1..96db03531 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -6,6 +6,7 @@ from spd.clustering.consts import ClusterCoactivationShaped, MergePair from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.sample_membership import BitsetMembership def compute_mdl_cost( @@ -187,3 +188,51 @@ def recompute_coacts_merge_pair( coact_new, activation_mask_new, ) + + +def recompute_coacts_merge_pair_memberships( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + merge_pair: MergePair, + memberships: list[BitsetMembership], +) -> tuple[ + GroupMerge, + Float[Tensor, "k_groups-1 k_groups-1"], + list[BitsetMembership], +]: + """Recompute coactivations after a merge using compressed memberships.""" + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + assert len(memberships) == k_groups, "Memberships must match coactivation matrix shape" + + new_group_idx: int = min(merge_pair) + remove_idx: int = max(merge_pair) + merged_membership = memberships[merge_pair[0]].union(memberships[merge_pair[1]]) + + coact_with_merge = torch.tensor( + [float(merged_membership.intersection_count(membership)) for membership in memberships], + dtype=coact.dtype, + device=coact.device, + ) + + merge_new: GroupMerge = merges.merge_groups( + merge_pair[0], + merge_pair[1], + ) + + coact_temp: ClusterCoactivationShaped = coact.clone() + coact_temp[new_group_idx, :] = coact_with_merge + coact_temp[:, new_group_idx] = coact_with_merge + + mask: Bool[Tensor, " k_groups"] = torch.ones( + coact_temp.shape[0], dtype=torch.bool, device=coact_temp.device + ) + mask[remove_idx] = False + coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact_temp[mask, :][:, mask] + coact_new[new_group_idx, new_group_idx] = float(merged_membership.count()) + + memberships_new = memberships.copy() + memberships_new[new_group_idx] = merged_membership + memberships_new.pop(remove_idx) + + return merge_new, coact_new, memberships_new diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index dba55c878..28acc9d9e 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -16,6 +16,7 @@ compute_mdl_cost, compute_merge_costs, recompute_coacts_merge_pair, + recompute_coacts_merge_pair_memberships, ) from spd.clustering.consts import ( ActivationsTensor, @@ -27,6 +28,7 @@ from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory +from spd.clustering.sample_membership import BitsetMembership, compute_coactivation_matrix class LogCallback(Protocol): @@ -186,3 +188,100 @@ def merge_iteration( # finish up # ================================================== return merge_history + + +def merge_iteration_memberships( + merge_config: MergeConfig, + memberships: list[BitsetMembership], + n_samples: int, + component_labels: ComponentLabels, + log_callback: LogCallback | None = None, +) -> MergeHistory: + """Exact merge iteration using compressed sample memberships.""" + current_coact: ClusterCoactivationShaped = compute_coactivation_matrix(memberships) + + c_components: int = current_coact.shape[0] + assert current_coact.shape[1] == c_components, "Coactivation matrix must be square" + + num_iters: int = merge_config.get_num_iters(c_components) + current_merge: GroupMerge = GroupMerge.identity(n_components=c_components) + current_memberships = memberships.copy() + k_groups: int = c_components + + merge_history: MergeHistory = MergeHistory.from_config( + merge_config=merge_config, + labels=component_labels, + ) + + pbar: tqdm[int] = tqdm( + range(num_iters), + unit="iter", + total=num_iters, + ) + for iter_idx in pbar: + costs: ClusterCoactivationShaped = compute_merge_costs( + coact=current_coact / n_samples, + merges=current_merge, + alpha=merge_config.alpha, + ) + + merge_pair: MergePair = merge_config.merge_pair_sample(costs) + + current_merge, current_coact, current_memberships = recompute_coacts_merge_pair_memberships( + coact=current_coact, + merges=current_merge, + merge_pair=merge_pair, + memberships=current_memberships, + ) + + merge_history.add_iteration( + idx=iter_idx, + selected_pair=merge_pair, + current_merge=current_merge, + ) + + diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) + mdl_loss: float = compute_mdl_cost( + acts=diag_acts, + merges=current_merge, + alpha=merge_config.alpha, + ) + mdl_loss_norm: float = mdl_loss / n_samples + merge_pair_cost: float = float(costs[merge_pair].item()) + + pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") + + if log_callback is not None: + log_callback( + iter_idx=iter_idx, + current_coact=current_coact, + component_labels=component_labels, + current_merge=current_merge, + costs=costs, + merge_history=merge_history, + k_groups=k_groups, + merge_pair_cost=merge_pair_cost, + mdl_loss=mdl_loss, + mdl_loss_norm=mdl_loss_norm, + diag_acts=diag_acts, + ) + + k_groups -= 1 + assert current_coact.shape[0] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_coact.shape[1] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert len(current_memberships) == k_groups, ( + "Membership count should match number of groups" + ) + + if k_groups <= 3: + warnings.warn( + f"Stopping early at iteration {iter_idx} as only {k_groups} groups left", + stacklevel=2, + ) + break + + return merge_history diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py new file mode 100644 index 000000000..9e216d135 --- /dev/null +++ b/spd/clustering/sample_membership.py @@ -0,0 +1,238 @@ +from dataclasses import dataclass + +import numpy as np +import torch + +from spd.clustering.consts import ClusterCoactivationShaped + +_POPCOUNT_TABLE = np.unpackbits(np.arange(256, dtype=np.uint8)[:, None], axis=1).sum(axis=1) + + +def _index_dtype_for(n_samples: int) -> np.dtype[np.unsignedinteger]: + return np.dtype(np.uint32) if n_samples <= np.iinfo(np.uint32).max else np.dtype(np.uint64) + + +def _bytes_for_bitset(n_samples: int) -> int: + return (n_samples + 7) // 8 + + +def _prefer_sparse(n_samples: int, active_count: int, dtype: np.dtype[np.generic]) -> bool: + return active_count * np.dtype(dtype).itemsize <= _bytes_for_bitset(n_samples) + + +def _sample_indices_to_bits( + sample_indices: np.ndarray, + n_samples: int, +) -> np.ndarray: + n_bytes = _bytes_for_bitset(n_samples) + bits = np.zeros(n_bytes, dtype=np.uint8) + if sample_indices.size == 0: + return bits + + sample_indices = sample_indices.astype(np.int64, copy=False) + byte_indices = sample_indices // 8 + bit_offsets = sample_indices % 8 + np.bitwise_or.at(bits, byte_indices, (1 << bit_offsets).astype(np.uint8)) + return bits + + +def _count_sparse_sparse_intersection(a: np.ndarray, b: np.ndarray) -> int: + i = 0 + j = 0 + count = 0 + len_a = len(a) + len_b = len(b) + while i < len_a and j < len_b: + a_i = int(a[i]) + b_j = int(b[j]) + if a_i == b_j: + count += 1 + i += 1 + j += 1 + elif a_i < b_j: + i += 1 + else: + j += 1 + return count + + +def _union_sparse_sparse(a: np.ndarray, b: np.ndarray) -> np.ndarray: + i = 0 + j = 0 + out: list[int] = [] + len_a = len(a) + len_b = len(b) + while i < len_a and j < len_b: + a_i = int(a[i]) + b_j = int(b[j]) + if a_i == b_j: + out.append(a_i) + i += 1 + j += 1 + elif a_i < b_j: + out.append(a_i) + i += 1 + else: + out.append(b_j) + j += 1 + while i < len_a: + out.append(int(a[i])) + i += 1 + while j < len_b: + out.append(int(b[j])) + j += 1 + dtype = a.dtype if a.size > 0 else b.dtype + return np.asarray(out, dtype=dtype) + + +def _count_sparse_bitset_intersection( + sample_indices: np.ndarray, + bits: np.ndarray, +) -> int: + if sample_indices.size == 0: + return 0 + sample_indices = sample_indices.astype(np.int64, copy=False) + byte_indices = sample_indices // 8 + bit_offsets = sample_indices % 8 + return int(np.count_nonzero(bits[byte_indices] & ((1 << bit_offsets).astype(np.uint8)))) + + +@dataclass(frozen=True, slots=True) +class CompressedMembership: + """Exact sample memberships stored sparsely when cheaper, otherwise as a bitset.""" + + n_samples: int + active_count: int + sample_indices: np.ndarray | None = None + bits: np.ndarray | None = None + + def __post_init__(self) -> None: + has_sparse = self.sample_indices is not None + has_bits = self.bits is not None + assert has_sparse != has_bits, "Membership must use exactly one representation" + + @classmethod + def empty(cls, n_samples: int) -> "CompressedMembership": + dtype = _index_dtype_for(n_samples) + return cls(n_samples=n_samples, active_count=0, sample_indices=np.empty((0,), dtype=dtype)) + + @classmethod + def from_sample_indices( + cls, + sample_indices: np.ndarray, + n_samples: int, + ) -> "CompressedMembership": + dtype = _index_dtype_for(n_samples) + sample_indices = sample_indices.astype(dtype, copy=False) + active_count = int(sample_indices.size) + if _prefer_sparse(n_samples, active_count, dtype): + return cls( + n_samples=n_samples, active_count=active_count, sample_indices=sample_indices + ) + return cls( + n_samples=n_samples, + active_count=active_count, + bits=_sample_indices_to_bits(sample_indices, n_samples), + ) + + @classmethod + def from_bits( + cls, + bits: np.ndarray, + n_samples: int, + active_count: int, + ) -> "CompressedMembership": + return cls( + n_samples=n_samples, + active_count=active_count, + bits=bits.astype(np.uint8, copy=False), + ) + + @property + def is_sparse(self) -> bool: + return self.sample_indices is not None + + def count(self) -> int: + return self.active_count + + def intersection_count(self, other: "CompressedMembership") -> int: + assert self.n_samples == other.n_samples, "Memberships must share sample space" + if self.sample_indices is not None and other.sample_indices is not None: + return _count_sparse_sparse_intersection(self.sample_indices, other.sample_indices) + if self.sample_indices is not None: + assert other.bits is not None + return _count_sparse_bitset_intersection(self.sample_indices, other.bits) + if other.sample_indices is not None: + assert self.bits is not None + return _count_sparse_bitset_intersection(other.sample_indices, self.bits) + + assert self.bits is not None and other.bits is not None + overlap = np.bitwise_and(self.bits, other.bits) + return int(_POPCOUNT_TABLE[overlap].sum(dtype=np.uint64)) + + def union(self, other: "CompressedMembership") -> "CompressedMembership": + assert self.n_samples == other.n_samples, "Memberships must share sample space" + overlap_count = self.intersection_count(other) + union_count = self.active_count + other.active_count - overlap_count + + if self.sample_indices is not None and other.sample_indices is not None: + union_indices = _union_sparse_sparse(self.sample_indices, other.sample_indices) + return CompressedMembership.from_sample_indices(union_indices, self.n_samples) + + if self.bits is not None and other.bits is not None: + return CompressedMembership.from_bits( + bits=np.bitwise_or(self.bits, other.bits), + n_samples=self.n_samples, + active_count=union_count, + ) + + if self.bits is not None: + base_bits = self.bits.copy() + sparse_indices = other.sample_indices + else: + assert other.bits is not None + base_bits = other.bits.copy() + sparse_indices = self.sample_indices + + assert sparse_indices is not None + if sparse_indices.size > 0: + sparse_indices_i64 = sparse_indices.astype(np.int64, copy=False) + byte_indices = sparse_indices_i64 // 8 + bit_offsets = sparse_indices_i64 % 8 + np.bitwise_or.at(base_bits, byte_indices, (1 << bit_offsets).astype(np.uint8)) + + return CompressedMembership.from_bits( + bits=base_bits, + n_samples=self.n_samples, + active_count=union_count, + ) + + def to_bool_array(self) -> np.ndarray: + if self.sample_indices is not None: + result = np.zeros(self.n_samples, dtype=bool) + result[self.sample_indices.astype(np.int64, copy=False)] = True + return result + + assert self.bits is not None + unpacked = np.unpackbits(self.bits, bitorder="little") + return unpacked[: self.n_samples].astype(bool, copy=False) + + +BitsetMembership = CompressedMembership + + +def compute_coactivation_matrix( + memberships: list[CompressedMembership], +) -> ClusterCoactivationShaped: + """Compute the full coactivation matrix from compressed memberships.""" + n_groups = len(memberships) + coact = np.zeros((n_groups, n_groups), dtype=np.float32) + + for i, membership_i in enumerate(memberships): + coact[i, i] = membership_i.count() + for j in range(i + 1, n_groups): + overlap = membership_i.intersection_count(memberships[j]) + coact[i, j] = overlap + coact[j, i] = overlap + + return torch.from_numpy(coact) diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 844365681..f06e1c88a 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -29,7 +29,10 @@ from spd.clustering.activations import ( ProcessedActivations, + ProcessedMemberships, collect_activations, + collect_memberships_lm, + collect_memberships_resid_mlp, component_activations, process_activations, ) @@ -43,7 +46,7 @@ from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.math.semilog import semilog -from spd.clustering.merge import merge_iteration +from spd.clustering.merge import merge_iteration, merge_iteration_memberships from spd.clustering.merge_history import MergeHistory from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration @@ -290,44 +293,96 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info("Loading model") model = ComponentModel.from_run_info(spd_run).to(device) + processed_activations: ProcessedActivations | None = None + processed_memberships: ProcessedMemberships | None = None + activations: ActivationsTensor | None = None + component_labels: ComponentLabels + # 4. Compute activations logger.info("Computing activations") - if task_name == "lm": - assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" - assert run_config.n_tokens_per_seq is not None, "n_tokens_per_seq must be set for LM tasks" - activations_dict = collect_activations( - model=model, - dataloader=dataloader, - n_tokens=run_config.n_tokens, - n_tokens_per_seq=run_config.n_tokens_per_seq, - device=device, - seed=run_config.dataset_seed, - ) + use_compressed_merge = run_config.merge_config.activation_threshold is not None + if use_compressed_merge: + activation_threshold = run_config.merge_config.activation_threshold + assert activation_threshold is not None + if task_name == "lm": + assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" + assert run_config.n_tokens_per_seq is not None, ( + "n_tokens_per_seq must be set for LM tasks" + ) + processed_memberships = collect_memberships_lm( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + activation_threshold=activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + else: + processed_memberships = collect_memberships_resid_mlp( + model=model, + dataloader=dataloader, + n_samples=run_config.n_samples or run_config.batch_size, + device=device, + activation_threshold=activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + processed_activations = processed_memberships.preview + component_labels = ComponentLabels(processed_memberships.labels.copy()) else: - # resid_mlp: single batch, no sequence dimension - batch_data = next(iter(dataloader)) - batch, _ = batch_data # DatasetGeneratedDataLoader yields (batch, labels) - activations_dict = component_activations( - model=model, - batch=batch, - device=device, + if task_name == "lm": + assert run_config.n_tokens is not None, "n_tokens must be set for LM tasks" + assert run_config.n_tokens_per_seq is not None, ( + "n_tokens_per_seq must be set for LM tasks" + ) + activations_dict = collect_activations( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + ) + else: + n_samples_target = run_config.n_samples or run_config.batch_size + collected_batches: dict[str, list[Tensor]] = {} + n_collected = 0 + for batch_data in dataloader: + batch, _ = batch_data + batch_take = min(batch.shape[0], n_samples_target - n_collected) + acts = component_activations(model=model, batch=batch[:batch_take], device=device) + for key, act in acts.items(): + collected_batches.setdefault(key, []).append(act.cpu()) + n_collected += batch_take + if n_collected >= n_samples_target: + break + assert n_collected >= n_samples_target, ( + f"Dataloader exhausted: collected {n_collected} samples but needed {n_samples_target}" + ) + activations_dict = { + key: torch.cat(chunks, dim=0) for key, chunks in collected_batches.items() + } + + logger.info("Processing activations") + processed_activations = process_activations( + activations=activations_dict, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + seq_mode=None, + filter_modules=run_config.merge_config.filter_modules, ) + activations = processed_activations.activations.to(device) + component_labels = ComponentLabels(processed_activations.labels.copy()) + del activations_dict - # 5. Process activations - logger.info("Processing activations") - processed_activations: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=run_config.merge_config.filter_dead_threshold, - seq_mode=None, - filter_modules=run_config.merge_config.filter_modules, - ) - - # 6. Log activations (if WandB enabled) - if wandb_run is not None: + # 5. Log activations preview (if WandB enabled) + if wandb_run is not None and processed_activations is not None: logger.info("Plotting activations") plot_activations( processed_activations=processed_activations, - save_dir=None, # Don't save to disk, only WandB + save_dir=None, n_samples_max=256, wandb_run=wandb_run, ) @@ -339,11 +394,6 @@ def main(run_config: ClusteringRunConfig) -> Path: single=True, ) - # Extract what we need, then free the model and temporary objects - activations: ActivationsTensor = processed_activations.activations.to(device) - component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) - del processed_activations - del activations_dict del model gc.collect() torch.cuda.empty_cache() @@ -356,12 +406,22 @@ def main(run_config: ClusteringRunConfig) -> Path: else None ) - history: MergeHistory = merge_iteration( - merge_config=run_config.merge_config, - activations=activations, - component_labels=component_labels, - log_callback=log_callback, - ) + if processed_memberships is not None: + history = merge_iteration_memberships( + merge_config=run_config.merge_config, + memberships=processed_memberships.memberships, + n_samples=processed_memberships.n_samples, + component_labels=component_labels, + log_callback=log_callback, + ) + else: + assert activations is not None + history = merge_iteration( + merge_config=run_config.merge_config, + activations=activations, + component_labels=component_labels, + log_callback=log_callback, + ) # 8. Save merge history From 8483070c6c6268532c5d73a0cb11e1bcbeec7caa Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Wed, 18 Mar 2026 15:45:31 +0000 Subject: [PATCH 2/5] Speed up clustering merge benchmarks --- spd/clustering/membership_snapshot.py | 111 ++++++++++++++++++ spd/clustering/merge.py | 52 ++++++++ spd/clustering/sample_membership.py | 67 +++++++++-- .../scripts/benchmark_merge_snapshot.py | 79 +++++++++++++ .../scripts/cache_membership_snapshot.py | 93 +++++++++++++++ 5 files changed, 391 insertions(+), 11 deletions(-) create mode 100644 spd/clustering/membership_snapshot.py create mode 100644 spd/clustering/scripts/benchmark_merge_snapshot.py create mode 100644 spd/clustering/scripts/cache_membership_snapshot.py diff --git a/spd/clustering/membership_snapshot.py b/spd/clustering/membership_snapshot.py new file mode 100644 index 000000000..5868220cc --- /dev/null +++ b/spd/clustering/membership_snapshot.py @@ -0,0 +1,111 @@ +import json +from dataclasses import dataclass +from pathlib import Path + +import numpy as np +from scipy import sparse + +from spd.clustering.consts import ComponentLabels +from spd.clustering.sample_membership import CompressedMembership + + +@dataclass(frozen=True, slots=True) +class MembershipSnapshot: + """Disk-friendly sparse membership snapshot for repeatable merge benchmarks.""" + + matrix_csc: sparse.csc_matrix + labels: ComponentLabels + + @property + def n_samples(self) -> int: + shape = self.matrix_csc.shape + assert shape is not None + return int(shape[0]) + + @property + def n_components(self) -> int: + shape = self.matrix_csc.shape + assert shape is not None + return int(shape[1]) + + def to_memberships(self) -> list[CompressedMembership]: + memberships: list[CompressedMembership] = [] + for col_idx in range(self.n_components): + sample_indices = self.matrix_csc.indices[ + self.matrix_csc.indptr[col_idx] : self.matrix_csc.indptr[col_idx + 1] + ].astype(np.int64, copy=False) + memberships.append( + CompressedMembership.from_sample_indices( + sample_indices=sample_indices, + n_samples=self.n_samples, + ) + ) + return memberships + + def to_csr(self) -> sparse.sparray | sparse.spmatrix: + return self.matrix_csc.tocsr() + + +def memberships_to_csc( + memberships: list[CompressedMembership], + n_samples: int, +) -> sparse.csc_matrix: + n_components = len(memberships) + if n_components == 0: + return sparse.csc_matrix((n_samples, 0), dtype=np.uint8) + + nnz = sum(membership.count() for membership in memberships) + row_indices = np.empty(nnz, dtype=np.int64) + col_indices = np.empty(nnz, dtype=np.int32) + + offset = 0 + for col_idx, membership in enumerate(memberships): + sample_indices = membership.to_sample_indices().astype(np.int64, copy=False) + col_nnz = sample_indices.size + row_indices[offset : offset + col_nnz] = sample_indices + col_indices[offset : offset + col_nnz] = col_idx + offset += col_nnz + + values = np.ones(nnz, dtype=np.uint8) + return sparse.csc_matrix( + (values, (row_indices, col_indices)), + shape=(n_samples, n_components), + dtype=np.uint8, + ) + + +def save_membership_snapshot( + output_dir: Path, + *, + memberships: list[CompressedMembership], + labels: ComponentLabels, + n_samples: int, +) -> Path: + output_dir.mkdir(parents=True, exist_ok=True) + matrix_path = output_dir / "memberships.npz" + metadata_path = output_dir / "metadata.json" + + matrix_csc = memberships_to_csc(memberships, n_samples=n_samples) + sparse.save_npz(matrix_path, matrix_csc) + metadata_path.write_text( + json.dumps( + { + "n_samples": n_samples, + "n_components": len(labels), + "labels": list(labels), + }, + indent=2, + ) + ) + return output_dir + + +def load_membership_snapshot(path: Path) -> MembershipSnapshot: + matrix_path = path / "memberships.npz" + metadata_path = path / "metadata.json" + matrix_csc = sparse.load_npz(matrix_path).tocsc() + metadata = json.loads(metadata_path.read_text()) + labels = ComponentLabels(metadata["labels"]) + assert matrix_csc.shape[0] == metadata["n_samples"] + assert matrix_csc.shape[1] == metadata["n_components"] + return MembershipSnapshot(matrix_csc=matrix_csc, labels=labels) diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 28acc9d9e..828473baf 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -4,6 +4,7 @@ This wraps the pure merge_iteration_pure() function and adds WandB/plotting callbacks. """ +import time import warnings from typing import Protocol @@ -29,6 +30,25 @@ from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory from spd.clustering.sample_membership import BitsetMembership, compute_coactivation_matrix +from spd.log import logger + + +def _choose_coact_device(coact: ClusterCoactivationShaped) -> torch.device: + """Prefer GPU for dense cost-matrix math when enough memory is available.""" + if not torch.cuda.is_available(): + return coact.device + + if coact.device.type == "cuda": + return coact.device + + free_bytes, _ = torch.cuda.mem_get_info() + coact_bytes = coact.numel() * coact.element_size() + # Current coact, a temporary clone during recompute, and the costs tensor dominate. + required_bytes = coact_bytes * 3 + 512 * 1024**2 + if free_bytes >= required_bytes: + return torch.device("cuda") + + return coact.device class LogCallback(Protocol): @@ -198,7 +218,27 @@ def merge_iteration_memberships( log_callback: LogCallback | None = None, ) -> MergeHistory: """Exact merge iteration using compressed sample memberships.""" + coact_start = time.perf_counter() + logger.info( + "Building coactivation matrix from compressed memberships " + f"(n_groups={len(memberships)}, n_samples={n_samples})" + ) current_coact: ClusterCoactivationShaped = compute_coactivation_matrix(memberships) + logger.info( + "Built coactivation matrix in " + f"{time.perf_counter() - coact_start:.2f}s " + f"(shape={tuple(current_coact.shape)})" + ) + coact_device = _choose_coact_device(current_coact) + if coact_device != current_coact.device: + transfer_start = time.perf_counter() + current_coact = current_coact.to(device=coact_device) + logger.info( + "Moved compressed coactivation matrix to " + f"{coact_device} in {time.perf_counter() - transfer_start:.2f}s" + ) + else: + logger.info(f"Keeping compressed coactivation matrix on {current_coact.device}") c_components: int = current_coact.shape[0] assert current_coact.shape[1] == c_components, "Coactivation matrix must be square" @@ -218,6 +258,8 @@ def merge_iteration_memberships( unit="iter", total=num_iters, ) + merge_start = time.perf_counter() + log_every = min(10, num_iters) for iter_idx in pbar: costs: ClusterCoactivationShaped = compute_merge_costs( coact=current_coact / n_samples, @@ -266,6 +308,16 @@ def merge_iteration_memberships( diag_acts=diag_acts, ) + if (iter_idx + 1) % log_every == 0 or iter_idx == 0 or iter_idx + 1 == num_iters: + elapsed = time.perf_counter() - merge_start + logger.info( + "Compressed merge progress: " + f"iter={iter_idx + 1}/{num_iters}, " + f"elapsed={elapsed:.2f}s, " + f"sec_per_iter={elapsed / (iter_idx + 1):.4f}, " + f"k_groups={k_groups - 1}" + ) + k_groups -= 1 assert current_coact.shape[0] == k_groups, ( "Coactivation matrix shape should match number of groups" diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py index 9e216d135..f01543911 100644 --- a/spd/clustering/sample_membership.py +++ b/spd/clustering/sample_membership.py @@ -2,6 +2,7 @@ import numpy as np import torch +from scipy import sparse from spd.clustering.consts import ClusterCoactivationShaped @@ -97,6 +98,20 @@ def _count_sparse_bitset_intersection( return int(np.count_nonzero(bits[byte_indices] & ((1 << bit_offsets).astype(np.uint8)))) +def _bitset_to_sample_indices(bits: np.ndarray, n_samples: int) -> np.ndarray: + nonzero_byte_idxs = np.flatnonzero(bits) + if nonzero_byte_idxs.size == 0: + return np.empty((0,), dtype=_index_dtype_for(n_samples)) + + unpacked = np.unpackbits(bits[nonzero_byte_idxs], bitorder="little") + active_bit_positions = np.flatnonzero(unpacked) + sample_indices = nonzero_byte_idxs[active_bit_positions // 8].astype( + np.int64, copy=False + ) * 8 + (active_bit_positions % 8) + sample_indices = sample_indices[sample_indices < n_samples] + return sample_indices.astype(_index_dtype_for(n_samples), copy=False) + + @dataclass(frozen=True, slots=True) class CompressedMembership: """Exact sample memberships stored sparsely when cheaper, otherwise as a bitset.""" @@ -217,6 +232,13 @@ def to_bool_array(self) -> np.ndarray: unpacked = np.unpackbits(self.bits, bitorder="little") return unpacked[: self.n_samples].astype(bool, copy=False) + def to_sample_indices(self) -> np.ndarray: + if self.sample_indices is not None: + return self.sample_indices + + assert self.bits is not None + return _bitset_to_sample_indices(self.bits, self.n_samples) + BitsetMembership = CompressedMembership @@ -224,15 +246,38 @@ def to_bool_array(self) -> np.ndarray: def compute_coactivation_matrix( memberships: list[CompressedMembership], ) -> ClusterCoactivationShaped: - """Compute the full coactivation matrix from compressed memberships.""" - n_groups = len(memberships) - coact = np.zeros((n_groups, n_groups), dtype=np.float32) + """Compute the full coactivation matrix from compressed memberships. - for i, membership_i in enumerate(memberships): - coact[i, i] = membership_i.count() - for j in range(i + 1, n_groups): - overlap = membership_i.intersection_count(memberships[j]) - coact[i, j] = overlap - coact[j, i] = overlap - - return torch.from_numpy(coact) + This builds a sparse sample-by-component matrix and computes X.T @ X, + which is much faster than Python-level pairwise intersections in the + typical highly sparse regime. + """ + n_groups = len(memberships) + if n_groups == 0: + return torch.empty((0, 0), dtype=torch.float32) + + n_samples = memberships[0].n_samples + assert all(membership.n_samples == n_samples for membership in memberships), ( + "Memberships must share sample space" + ) + + nnz = sum(membership.count() for membership in memberships) + row_indices = np.empty(nnz, dtype=np.int64) + col_indices = np.empty(nnz, dtype=np.int32) + + offset = 0 + for group_idx, membership in enumerate(memberships): + sample_indices = membership.to_sample_indices().astype(np.int64, copy=False) + group_nnz = sample_indices.size + row_indices[offset : offset + group_nnz] = sample_indices + col_indices[offset : offset + group_nnz] = group_idx + offset += group_nnz + + values = np.ones(nnz, dtype=np.int32) + activation_matrix = sparse.csr_matrix( + (values, (row_indices, col_indices)), + shape=(n_samples, n_groups), + dtype=np.int32, + ) + coact = (activation_matrix.T @ activation_matrix).toarray() + return torch.from_numpy(coact.astype(np.float32, copy=False)) diff --git a/spd/clustering/scripts/benchmark_merge_snapshot.py b/spd/clustering/scripts/benchmark_merge_snapshot.py new file mode 100644 index 000000000..247572483 --- /dev/null +++ b/spd/clustering/scripts/benchmark_merge_snapshot.py @@ -0,0 +1,79 @@ +"""Benchmark exact merge performance from a cached membership snapshot.""" + +import argparse +import time +from pathlib import Path +from typing import Any + +import numpy as np + +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.membership_snapshot import load_membership_snapshot +from spd.clustering.merge import merge_iteration_memberships +from spd.utils.general_utils import replace_pydantic_model + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--snapshot-dir", type=Path, required=True) + parser.add_argument("--config", type=Path, required=True) + parser.add_argument("--iters", type=int, default=100) + parser.add_argument("--profile-overlap", action="store_true") + args = parser.parse_args() + + snapshot = load_membership_snapshot(args.snapshot_dir) + memberships = snapshot.to_memberships() + + if args.profile_overlap: + row_csr: Any = snapshot.to_csr() + merged = memberships[0].union(memberships[1]) + merged_indices = merged.to_sample_indices().astype(np.int64, copy=False) + + st = time.time() + current = np.array([merged.intersection_count(m) for m in memberships], dtype=np.int64) + current_s = time.time() - st + + st = time.time() + row_sum = np.asarray(row_csr[merged_indices].sum(axis=0)).ravel().astype(np.int64) + row_sum_s = time.time() - st + + assert np.array_equal(current, row_sum) + print( + { + "phase": "overlap_profile", + "merged_size": int(merged_indices.size), + "current_s": round(current_s, 4), + "row_sum_s": round(row_sum_s, 4), + "speedup": round(current_s / row_sum_s, 2), + } + ) + + run_config = ClusteringRunConfig.from_file(args.config) + merge_config = replace_pydantic_model(run_config.merge_config, {"iters": args.iters}) + + st = time.time() + history = merge_iteration_memberships( + merge_config=merge_config, + memberships=memberships, + n_samples=snapshot.n_samples, + component_labels=snapshot.labels, + log_callback=None, + ) + elapsed = time.time() - st + print( + { + "phase": "merge_benchmark", + "snapshot_dir": str(args.snapshot_dir), + "n_samples": snapshot.n_samples, + "n_components": snapshot.n_components, + "iters": history.n_iters_current, + "elapsed_s": round(elapsed, 2), + "sec_per_iter": round(elapsed / history.n_iters_current, 4) + if history.n_iters_current + else None, + } + ) + + +if __name__ == "__main__": + main() diff --git a/spd/clustering/scripts/cache_membership_snapshot.py b/spd/clustering/scripts/cache_membership_snapshot.py new file mode 100644 index 000000000..2d65caa08 --- /dev/null +++ b/spd/clustering/scripts/cache_membership_snapshot.py @@ -0,0 +1,93 @@ +"""Collect compressed memberships once and save them for repeated merge benchmarking.""" + +import argparse +from pathlib import Path + +from spd.clustering.activations import collect_memberships_lm, collect_memberships_resid_mlp +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.dataset import create_clustering_dataloader +from spd.clustering.membership_snapshot import save_membership_snapshot +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import replace_pydantic_model + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--config", type=Path, required=True) + parser.add_argument("--output-dir", type=Path, required=True) + parser.add_argument("--n-tokens", type=int, default=None) + parser.add_argument("--n-samples", type=int, default=None) + parser.add_argument("--batch-size", type=int, default=None) + args = parser.parse_args() + + run_config = ClusteringRunConfig.from_file(args.config) + overrides: dict[str, int] = {} + if args.n_tokens is not None: + overrides["n_tokens"] = args.n_tokens + if args.n_samples is not None: + overrides["n_samples"] = args.n_samples + if args.batch_size is not None: + overrides["batch_size"] = args.batch_size + if overrides: + run_config = replace_pydantic_model(run_config, overrides) + + assert run_config.merge_config.activation_threshold is not None, ( + "Snapshotting only supports thresholded compressed memberships" + ) + + spd_run = SPDRunInfo.from_path(run_config.model_path) + task_name: TaskName = spd_run.config.task_config.task_name + device = get_device() + model = ComponentModel.from_run_info(spd_run).to(device) + dataloader = create_clustering_dataloader( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + ) + + if task_name == "lm": + assert run_config.n_tokens is not None + assert run_config.n_tokens_per_seq is not None + processed = collect_memberships_lm( + model=model, + dataloader=dataloader, + n_tokens=run_config.n_tokens, + n_tokens_per_seq=run_config.n_tokens_per_seq, + device=device, + seed=run_config.dataset_seed, + activation_threshold=run_config.merge_config.activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + else: + n_samples = run_config.n_samples or run_config.batch_size + processed = collect_memberships_resid_mlp( + model=model, + dataloader=dataloader, + n_samples=n_samples, + device=device, + activation_threshold=run_config.merge_config.activation_threshold, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_modules=run_config.merge_config.filter_modules, + ) + + save_membership_snapshot( + args.output_dir, + memberships=processed.memberships, + labels=processed.labels, + n_samples=processed.n_samples, + ) + print( + { + "output_dir": str(args.output_dir), + "n_samples": processed.n_samples, + "n_components": len(processed.labels), + } + ) + + +if __name__ == "__main__": + main() From ac99c0ee247fe1878806ecc3e366b73c9916be9e Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Wed, 18 Mar 2026 16:00:55 +0000 Subject: [PATCH 3/5] Speed up exact clustering overlap updates --- pyproject.toml | 3 +- spd/clustering/compute_costs.py | 45 +++++++--- spd/clustering/merge.py | 16 +++- spd/clustering/sample_membership.py | 89 ++++++++++++++++--- .../scripts/benchmark_merge_snapshot.py | 17 ++-- tests/clustering/test_merge_integration.py | 38 ++++++++ uv.lock | 30 +++++++ 7 files changed, 206 insertions(+), 32 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index cef1fc201..8d48394eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,9 @@ dependencies = [ "aiolimiter>=1.2", "openrouter>=0.1.1", "httpx>=0.28.0", - "zstandard", # For streaming datasets + "zstandard", # For streaming datasets "kaleido==0.2.1", + "numba>=0.64.0", ] [dependency-groups] diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index 96db03531..291db2e77 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -2,11 +2,15 @@ import torch from jaxtyping import Bool, Float +from scipy import sparse from torch import Tensor from spd.clustering.consts import ClusterCoactivationShaped, MergePair from spd.clustering.math.merge_matrix import GroupMerge -from spd.clustering.sample_membership import BitsetMembership +from spd.clustering.sample_membership import ( + BitsetMembership, + count_group_overlaps_from_component_rows, +) def compute_mdl_cost( @@ -195,6 +199,7 @@ def recompute_coacts_merge_pair_memberships( merges: GroupMerge, merge_pair: MergePair, memberships: list[BitsetMembership], + component_activity_csr: sparse.csr_matrix | None = None, ) -> tuple[ GroupMerge, Float[Tensor, "k_groups-1 k_groups-1"], @@ -209,26 +214,40 @@ def recompute_coacts_merge_pair_memberships( remove_idx: int = max(merge_pair) merged_membership = memberships[merge_pair[0]].union(memberships[merge_pair[1]]) - coact_with_merge = torch.tensor( - [float(merged_membership.intersection_count(membership)) for membership in memberships], - dtype=coact.dtype, - device=coact.device, - ) - merge_new: GroupMerge = merges.merge_groups( merge_pair[0], merge_pair[1], ) - coact_temp: ClusterCoactivationShaped = coact.clone() - coact_temp[new_group_idx, :] = coact_with_merge - coact_temp[:, new_group_idx] = coact_with_merge - mask: Bool[Tensor, " k_groups"] = torch.ones( - coact_temp.shape[0], dtype=torch.bool, device=coact_temp.device + coact.shape[0], dtype=torch.bool, device=coact.device ) mask[remove_idx] = False - coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact_temp[mask, :][:, mask] + coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact[mask, :][:, mask].clone() + + if component_activity_csr is not None: + merged_rows = merged_membership.to_sample_indices() + coact_with_merge_np = count_group_overlaps_from_component_rows( + merged_rows=merged_rows, + component_activity_csr=component_activity_csr, + group_idxs=merge_new.group_idxs.cpu().numpy(), + n_groups=merge_new.k_groups, + ) + coact_with_merge = torch.tensor( + coact_with_merge_np, + dtype=coact.dtype, + device=coact.device, + ) + else: + coact_with_merge = torch.tensor( + [float(merged_membership.intersection_count(membership)) for membership in memberships], + dtype=coact.dtype, + device=coact.device, + ) + coact_with_merge = coact_with_merge[mask] + + coact_new[new_group_idx, :] = coact_with_merge + coact_new[:, new_group_idx] = coact_with_merge coact_new[new_group_idx, new_group_idx] = float(merged_membership.count()) memberships_new = memberships.copy() diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 828473baf..474d95630 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -10,6 +10,7 @@ import torch from jaxtyping import Bool, Float +from scipy import sparse from torch import Tensor from tqdm import tqdm @@ -29,7 +30,11 @@ from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory -from spd.clustering.sample_membership import BitsetMembership, compute_coactivation_matrix +from spd.clustering.sample_membership import ( + BitsetMembership, + compute_coactivation_matrix, + memberships_to_sample_component_csr, +) from spd.log import logger @@ -218,6 +223,14 @@ def merge_iteration_memberships( log_callback: LogCallback | None = None, ) -> MergeHistory: """Exact merge iteration using compressed sample memberships.""" + csr_start = time.perf_counter() + component_activity_csr: sparse.csr_matrix = memberships_to_sample_component_csr(memberships) + logger.info( + "Built component activity CSR in " + f"{time.perf_counter() - csr_start:.2f}s " + f"(shape={component_activity_csr.shape}, nnz={component_activity_csr.nnz})" + ) + coact_start = time.perf_counter() logger.info( "Building coactivation matrix from compressed memberships " @@ -274,6 +287,7 @@ def merge_iteration_memberships( merges=current_merge, merge_pair=merge_pair, memberships=current_memberships, + component_activity_csr=component_activity_csr, ) merge_history.add_iteration( diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py index f01543911..ea8ab0a19 100644 --- a/spd/clustering/sample_membership.py +++ b/spd/clustering/sample_membership.py @@ -2,6 +2,7 @@ import numpy as np import torch +from numba import njit from scipy import sparse from spd.clustering.consts import ClusterCoactivationShaped @@ -112,6 +113,30 @@ def _bitset_to_sample_indices(bits: np.ndarray, n_samples: int) -> np.ndarray: return sample_indices.astype(_index_dtype_for(n_samples), copy=False) +@njit(cache=True) +def _count_group_overlaps_rows_numba( + merged_rows: np.ndarray, + indptr: np.ndarray, + indices: np.ndarray, + group_idxs: np.ndarray, + n_groups: int, +) -> np.ndarray: + counts = np.zeros(n_groups, dtype=np.int64) + seen = np.full(n_groups, -1, dtype=np.int64) + stamp = 0 + for row in merged_rows: + stamp += 1 + start = indptr[row] + end = indptr[row + 1] + for pos in range(start, end): + group_idx = group_idxs[indices[pos]] + if seen[group_idx] == stamp: + continue + seen[group_idx] = stamp + counts[group_idx] += 1 + return counts + + @dataclass(frozen=True, slots=True) class CompressedMembership: """Exact sample memberships stored sparsely when cheaper, otherwise as a bitset.""" @@ -243,18 +268,13 @@ def to_sample_indices(self) -> np.ndarray: BitsetMembership = CompressedMembership -def compute_coactivation_matrix( +def memberships_to_sample_component_csr( memberships: list[CompressedMembership], -) -> ClusterCoactivationShaped: - """Compute the full coactivation matrix from compressed memberships. - - This builds a sparse sample-by-component matrix and computes X.T @ X, - which is much faster than Python-level pairwise intersections in the - typical highly sparse regime. - """ +) -> sparse.csr_matrix: + """Build a binary sample-by-component CSR matrix from memberships.""" n_groups = len(memberships) if n_groups == 0: - return torch.empty((0, 0), dtype=torch.float32) + return sparse.csr_matrix((0, 0), dtype=np.uint8) n_samples = memberships[0].n_samples assert all(membership.n_samples == n_samples for membership in memberships), ( @@ -273,11 +293,56 @@ def compute_coactivation_matrix( col_indices[offset : offset + group_nnz] = group_idx offset += group_nnz - values = np.ones(nnz, dtype=np.int32) - activation_matrix = sparse.csr_matrix( + values = np.ones(nnz, dtype=np.uint8) + return sparse.csr_matrix( (values, (row_indices, col_indices)), shape=(n_samples, n_groups), - dtype=np.int32, + dtype=np.uint8, + ) + + +def count_group_overlaps_from_component_rows( + merged_rows: np.ndarray, + component_activity_csr: sparse.csr_matrix, + group_idxs: np.ndarray, + n_groups: int, +) -> np.ndarray: + """Count exact merged-group overlaps by scanning original component rows. + + `group_idxs` maps original component indices to the current group index after + the candidate merge has been applied. + """ + merged_rows_i64 = merged_rows.astype(np.int64, copy=False) + indptr = component_activity_csr.indptr.astype(np.int64, copy=False) + indices = component_activity_csr.indices.astype(np.int32, copy=False) + group_idxs_i64 = group_idxs.astype(np.int64, copy=False) + return _count_group_overlaps_rows_numba( + merged_rows=merged_rows_i64, + indptr=indptr, + indices=indices, + group_idxs=group_idxs_i64, + n_groups=n_groups, + ) + + +def compute_coactivation_matrix( + memberships: list[CompressedMembership], +) -> ClusterCoactivationShaped: + """Compute the full coactivation matrix from compressed memberships. + + This builds a sparse sample-by-component matrix and computes X.T @ X, + which is much faster than Python-level pairwise intersections in the + typical highly sparse regime. + """ + n_groups = len(memberships) + if n_groups == 0: + return torch.empty((0, 0), dtype=torch.float32) + + n_samples = memberships[0].n_samples + assert all(membership.n_samples == n_samples for membership in memberships), ( + "Memberships must share sample space" ) + + activation_matrix = memberships_to_sample_component_csr(memberships).astype(np.int32, copy=False) coact = (activation_matrix.T @ activation_matrix).toarray() return torch.from_numpy(coact.astype(np.float32, copy=False)) diff --git a/spd/clustering/scripts/benchmark_merge_snapshot.py b/spd/clustering/scripts/benchmark_merge_snapshot.py index 247572483..bd253d807 100644 --- a/spd/clustering/scripts/benchmark_merge_snapshot.py +++ b/spd/clustering/scripts/benchmark_merge_snapshot.py @@ -10,6 +10,7 @@ from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.membership_snapshot import load_membership_snapshot from spd.clustering.merge import merge_iteration_memberships +from spd.clustering.sample_membership import count_group_overlaps_from_component_rows from spd.utils.general_utils import replace_pydantic_model @@ -28,23 +29,29 @@ def main() -> None: row_csr: Any = snapshot.to_csr() merged = memberships[0].union(memberships[1]) merged_indices = merged.to_sample_indices().astype(np.int64, copy=False) + group_idxs = np.arange(snapshot.n_components, dtype=np.int64) st = time.time() current = np.array([merged.intersection_count(m) for m in memberships], dtype=np.int64) current_s = time.time() - st st = time.time() - row_sum = np.asarray(row_csr[merged_indices].sum(axis=0)).ravel().astype(np.int64) - row_sum_s = time.time() - st + row_counts = count_group_overlaps_from_component_rows( + merged_rows=merged_indices, + component_activity_csr=row_csr, + group_idxs=group_idxs, + n_groups=snapshot.n_components, + ) + row_counts_s = time.time() - st - assert np.array_equal(current, row_sum) + assert np.array_equal(current, row_counts) print( { "phase": "overlap_profile", "merged_size": int(merged_indices.size), "current_s": round(current_s, 4), - "row_sum_s": round(row_sum_s, 4), - "speedup": round(current_s / row_sum_s, 2), + "row_counts_s": round(row_counts_s, 4), + "speedup": round(current_s / row_counts_s, 2), } ) diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 8492300de..5edeb5450 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -2,9 +2,16 @@ import torch +from spd.clustering.compute_costs import recompute_coacts_merge_pair_memberships from spd.clustering.consts import ComponentLabels from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.sample_membership import ( + CompressedMembership, + compute_coactivation_matrix, + memberships_to_sample_component_csr, +) class TestMergeIntegration: @@ -149,3 +156,34 @@ def test_merge_with_small_components(self): # Early stopping may occur at 2 groups, so final count could be 2 or 3 assert history.merges.k_groups[-1].item() >= 2 assert history.merges.k_groups[-1].item() <= 3 + + def test_membership_recompute_matches_row_oriented_path(self): + """Row-oriented overlap recompute should match the direct membership path exactly.""" + memberships = [ + CompressedMembership.from_sample_indices(torch.tensor(indices).numpy(), n_samples=8) + for indices in ([0, 2, 5], [1, 2], [0, 3], [4, 5, 6]) + ] + coact = compute_coactivation_matrix(memberships) + merges = GroupMerge.identity(n_components=len(memberships)) + component_activity_csr = memberships_to_sample_component_csr(memberships) + + merge_old, coact_old, memberships_old = recompute_coacts_merge_pair_memberships( + coact=coact, + merges=merges, + merge_pair=(0, 1), + memberships=memberships, + component_activity_csr=None, + ) + merge_row, coact_row, memberships_row = recompute_coacts_merge_pair_memberships( + coact=coact, + merges=merges, + merge_pair=(0, 1), + memberships=memberships, + component_activity_csr=component_activity_csr, + ) + + assert torch.equal(merge_old.group_idxs, merge_row.group_idxs) + assert torch.equal(coact_old, coact_row) + assert [membership.count() for membership in memberships_old] == [ + membership.count() for membership in memberships_row + ] diff --git a/uv.lock b/uv.lock index becdf51ec..be4f3d735 100644 --- a/uv.lock +++ b/uv.lock @@ -870,6 +870,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, ] +[[package]] +name = "llvmlite" +version = "0.46.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb", size = 193456, upload-time = "2025-12-08T18:15:36.295Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/ff/3eba7eb0aed4b6fca37125387cd417e8c458e750621fce56d2c541f67fa8/llvmlite-0.46.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:30b60892d034bc560e0ec6654737aaa74e5ca327bd8114d82136aa071d611172", size = 37232767, upload-time = "2025-12-08T18:15:13.22Z" }, + { url = "https://files.pythonhosted.org/packages/0e/54/737755c0a91558364b9200702c3c9c15d70ed63f9b98a2c32f1c2aa1f3ba/llvmlite-0.46.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6cc19b051753368a9c9f31dc041299059ee91aceec81bd57b0e385e5d5bf1a54", size = 56275176, upload-time = "2025-12-08T18:15:16.339Z" }, + { url = "https://files.pythonhosted.org/packages/e6/91/14f32e1d70905c1c0aa4e6609ab5d705c3183116ca02ac6df2091868413a/llvmlite-0.46.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bca185892908f9ede48c0acd547fe4dc1bafefb8a4967d47db6cf664f9332d12", size = 55128629, upload-time = "2025-12-08T18:15:19.493Z" }, + { url = "https://files.pythonhosted.org/packages/4a/a7/d526ae86708cea531935ae777b6dbcabe7db52718e6401e0fb9c5edea80e/llvmlite-0.46.0-cp313-cp313-win_amd64.whl", hash = "sha256:67438fd30e12349ebb054d86a5a1a57fd5e87d264d2451bcfafbbbaa25b82a35", size = 38138941, upload-time = "2025-12-08T18:15:22.536Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -1068,6 +1080,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/af/cd3290a647df567645353feed451ef4feaf5844496ced69c4dcb84295ff4/nodejs_wheel_binaries-24.12.0-py2.py3-none-win_arm64.whl", hash = "sha256:d0c2273b667dd7e3f55e369c0085957b702144b1b04bfceb7ce2411e58333757", size = 39048104, upload-time = "2025-12-11T21:12:23.495Z" }, ] +[[package]] +name = "numba" +version = "0.64.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "llvmlite" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/c9/a0fb41787d01d621046138da30f6c2100d80857bf34b3390dd68040f27a3/numba-0.64.0.tar.gz", hash = "sha256:95e7300af648baa3308127b1955b52ce6d11889d16e8cfe637b4f85d2fca52b1", size = 2765679, upload-time = "2026-02-18T18:41:20.974Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/80/2734de90f9300a6e2503b35ee50d9599926b90cbb7ac54f9e40074cd07f1/numba-0.64.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3bab2c872194dcd985f1153b70782ec0fbbe348fffef340264eacd3a76d59fd6", size = 2683392, upload-time = "2026-02-18T18:41:06.563Z" }, + { url = "https://files.pythonhosted.org/packages/42/e8/14b5853ebefd5b37723ef365c5318a30ce0702d39057eaa8d7d76392859d/numba-0.64.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:703a246c60832cad231d2e73c1182f25bf3cc8b699759ec8fe58a2dbc689a70c", size = 3812245, upload-time = "2026-02-18T18:41:07.963Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a2/f60dc6c96d19b7185144265a5fbf01c14993d37ff4cd324b09d0212aa7ce/numba-0.64.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e2e49a7900ee971d32af7609adc0cfe6aa7477c6f6cccdf6d8138538cf7756f", size = 3511328, upload-time = "2026-02-18T18:41:09.504Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2a/fe7003ea7e7237ee7014f8eaeeb7b0d228a2db22572ca85bab2648cf52cb/numba-0.64.0-cp313-cp313-win_amd64.whl", hash = "sha256:396f43c3f77e78d7ec84cdfc6b04969c78f8f169351b3c4db814b97e7acf4245", size = 2752668, upload-time = "2026-02-18T18:41:11.455Z" }, +] + [[package]] name = "numpy" version = "2.4.0" @@ -1965,6 +1993,7 @@ dependencies = [ { name = "jaxtyping" }, { name = "kaleido" }, { name = "matplotlib" }, + { name = "numba" }, { name = "numpy" }, { name = "openrouter" }, { name = "orjson" }, @@ -2007,6 +2036,7 @@ requires-dist = [ { name = "jaxtyping" }, { name = "kaleido", specifier = "==0.2.1" }, { name = "matplotlib" }, + { name = "numba", specifier = ">=0.64.0" }, { name = "numpy" }, { name = "openrouter", specifier = ">=0.1.1" }, { name = "orjson" }, From dd72fab0b903b950b0100e844eb8cea21037f1c6 Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Wed, 18 Mar 2026 16:27:16 +0000 Subject: [PATCH 4/5] Simplify exact clustering merge path --- spd/clustering/activations.py | 8 ++--- spd/clustering/compute_costs.py | 40 +++++++++------------- spd/clustering/membership_snapshot.py | 32 +++++------------ spd/clustering/merge.py | 13 +++---- spd/clustering/sample_membership.py | 40 ++++++++++++++++------ tests/clustering/test_merge_integration.py | 23 ++++++------- 6 files changed, 75 insertions(+), 81 deletions(-) diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 3b6c52b9d..99435b56f 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -15,7 +15,7 @@ ClusterCoactivationShaped, ComponentLabels, ) -from spd.clustering.sample_membership import BitsetMembership +from spd.clustering.sample_membership import CompressedMembership from spd.clustering.util import ModuleFilterFunc from spd.log import logger from spd.models.component_model import ComponentModel, OutputWithCache @@ -276,7 +276,7 @@ class ProcessedMemberships: module_alive_counts: dict[str, int] labels: ComponentLabels dead_components_lst: ComponentLabels | None - memberships: list[BitsetMembership] + memberships: list[CompressedMembership] n_samples: int preview: ProcessedActivations | None = None @@ -389,7 +389,7 @@ def finalize(self) -> ProcessedMemberships: module_alive_counts: dict[str, int] = {} alive_labels = ComponentLabels(list()) dead_labels = ComponentLabels(list()) - memberships: list[BitsetMembership] = [] + memberships: list[CompressedMembership] = [] preview_module_component_counts: dict[str, int] = {} preview_module_alive_counts: dict[str, int] = {} @@ -425,7 +425,7 @@ def finalize(self) -> ProcessedMemberships: else np.empty((0,), dtype=np.int64) ) memberships.append( - BitsetMembership.from_sample_indices( + CompressedMembership.from_sample_indices( sample_indices=sample_ids, n_samples=self.n_samples, ) diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index 291db2e77..b715fa681 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -8,7 +8,7 @@ from spd.clustering.consts import ClusterCoactivationShaped, MergePair from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.sample_membership import ( - BitsetMembership, + CompressedMembership, count_group_overlaps_from_component_rows, ) @@ -198,12 +198,12 @@ def recompute_coacts_merge_pair_memberships( coact: ClusterCoactivationShaped, merges: GroupMerge, merge_pair: MergePair, - memberships: list[BitsetMembership], - component_activity_csr: sparse.csr_matrix | None = None, + memberships: list[CompressedMembership], + component_activity_csr: sparse.csr_matrix, ) -> tuple[ GroupMerge, Float[Tensor, "k_groups-1 k_groups-1"], - list[BitsetMembership], + list[CompressedMembership], ]: """Recompute coactivations after a merge using compressed memberships.""" k_groups: int = coact.shape[0] @@ -225,26 +225,18 @@ def recompute_coacts_merge_pair_memberships( mask[remove_idx] = False coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact[mask, :][:, mask].clone() - if component_activity_csr is not None: - merged_rows = merged_membership.to_sample_indices() - coact_with_merge_np = count_group_overlaps_from_component_rows( - merged_rows=merged_rows, - component_activity_csr=component_activity_csr, - group_idxs=merge_new.group_idxs.cpu().numpy(), - n_groups=merge_new.k_groups, - ) - coact_with_merge = torch.tensor( - coact_with_merge_np, - dtype=coact.dtype, - device=coact.device, - ) - else: - coact_with_merge = torch.tensor( - [float(merged_membership.intersection_count(membership)) for membership in memberships], - dtype=coact.dtype, - device=coact.device, - ) - coact_with_merge = coact_with_merge[mask] + merged_rows = merged_membership.to_sample_indices() + coact_with_merge_np = count_group_overlaps_from_component_rows( + merged_rows=merged_rows, + component_activity_csr=component_activity_csr, + group_idxs=merge_new.group_idxs.cpu().numpy(), + n_groups=merge_new.k_groups, + ) + coact_with_merge = torch.tensor( + coact_with_merge_np, + dtype=coact.dtype, + device=coact.device, + ) coact_new[new_group_idx, :] = coact_with_merge coact_new[:, new_group_idx] = coact_with_merge diff --git a/spd/clustering/membership_snapshot.py b/spd/clustering/membership_snapshot.py index 5868220cc..6b401b705 100644 --- a/spd/clustering/membership_snapshot.py +++ b/spd/clustering/membership_snapshot.py @@ -2,11 +2,13 @@ from dataclasses import dataclass from pathlib import Path -import numpy as np from scipy import sparse from spd.clustering.consts import ComponentLabels -from spd.clustering.sample_membership import CompressedMembership +from spd.clustering.sample_membership import ( + CompressedMembership, + memberships_to_sample_component_matrix, +) @dataclass(frozen=True, slots=True) @@ -50,28 +52,10 @@ def memberships_to_csc( memberships: list[CompressedMembership], n_samples: int, ) -> sparse.csc_matrix: - n_components = len(memberships) - if n_components == 0: - return sparse.csc_matrix((n_samples, 0), dtype=np.uint8) - - nnz = sum(membership.count() for membership in memberships) - row_indices = np.empty(nnz, dtype=np.int64) - col_indices = np.empty(nnz, dtype=np.int32) - - offset = 0 - for col_idx, membership in enumerate(memberships): - sample_indices = membership.to_sample_indices().astype(np.int64, copy=False) - col_nnz = sample_indices.size - row_indices[offset : offset + col_nnz] = sample_indices - col_indices[offset : offset + col_nnz] = col_idx - offset += col_nnz - - values = np.ones(nnz, dtype=np.uint8) - return sparse.csc_matrix( - (values, (row_indices, col_indices)), - shape=(n_samples, n_components), - dtype=np.uint8, - ) + matrix = memberships_to_sample_component_matrix(memberships, fmt="csc") + assert isinstance(matrix, sparse.csc_matrix) + assert matrix.shape[0] == n_samples + return matrix def save_membership_snapshot( diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 474d95630..d95017aa1 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -10,7 +10,6 @@ import torch from jaxtyping import Bool, Float -from scipy import sparse from torch import Tensor from tqdm import tqdm @@ -31,8 +30,8 @@ from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory from spd.clustering.sample_membership import ( - BitsetMembership, - compute_coactivation_matrix, + CompressedMembership, + compute_coactivation_matrix_from_csr, memberships_to_sample_component_csr, ) from spd.log import logger @@ -217,14 +216,14 @@ def merge_iteration( def merge_iteration_memberships( merge_config: MergeConfig, - memberships: list[BitsetMembership], + memberships: list[CompressedMembership], n_samples: int, component_labels: ComponentLabels, log_callback: LogCallback | None = None, ) -> MergeHistory: """Exact merge iteration using compressed sample memberships.""" csr_start = time.perf_counter() - component_activity_csr: sparse.csr_matrix = memberships_to_sample_component_csr(memberships) + component_activity_csr = memberships_to_sample_component_csr(memberships) logger.info( "Built component activity CSR in " f"{time.perf_counter() - csr_start:.2f}s " @@ -236,7 +235,9 @@ def merge_iteration_memberships( "Building coactivation matrix from compressed memberships " f"(n_groups={len(memberships)}, n_samples={n_samples})" ) - current_coact: ClusterCoactivationShaped = compute_coactivation_matrix(memberships) + current_coact: ClusterCoactivationShaped = compute_coactivation_matrix_from_csr( + component_activity_csr + ) logger.info( "Built coactivation matrix in " f"{time.perf_counter() - coact_start:.2f}s " diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py index ea8ab0a19..01698745d 100644 --- a/spd/clustering/sample_membership.py +++ b/spd/clustering/sample_membership.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Literal import numpy as np import torch @@ -265,16 +266,16 @@ def to_sample_indices(self) -> np.ndarray: return _bitset_to_sample_indices(self.bits, self.n_samples) -BitsetMembership = CompressedMembership - - -def memberships_to_sample_component_csr( +def memberships_to_sample_component_matrix( memberships: list[CompressedMembership], -) -> sparse.csr_matrix: - """Build a binary sample-by-component CSR matrix from memberships.""" + *, + fmt: Literal["csr", "csc"] = "csr", +) -> sparse.csr_matrix | sparse.csc_matrix: + """Build a binary sample-by-component sparse matrix from memberships.""" n_groups = len(memberships) if n_groups == 0: - return sparse.csr_matrix((0, 0), dtype=np.uint8) + empty = sparse.csr_matrix((0, 0), dtype=np.uint8) + return empty if fmt == "csr" else empty.tocsc() n_samples = memberships[0].n_samples assert all(membership.n_samples == n_samples for membership in memberships), ( @@ -294,11 +295,21 @@ def memberships_to_sample_component_csr( offset += group_nnz values = np.ones(nnz, dtype=np.uint8) - return sparse.csr_matrix( + matrix = sparse.csr_matrix( (values, (row_indices, col_indices)), shape=(n_samples, n_groups), dtype=np.uint8, ) + return matrix if fmt == "csr" else matrix.tocsc() + + +def memberships_to_sample_component_csr( + memberships: list[CompressedMembership], +) -> sparse.csr_matrix: + """Build a binary sample-by-component CSR matrix from memberships.""" + matrix = memberships_to_sample_component_matrix(memberships, fmt="csr") + assert isinstance(matrix, sparse.csr_matrix) + return matrix def count_group_overlaps_from_component_rows( @@ -325,6 +336,15 @@ def count_group_overlaps_from_component_rows( ) +def compute_coactivation_matrix_from_csr( + component_activity_csr: sparse.csr_matrix, +) -> ClusterCoactivationShaped: + """Compute the full coactivation matrix from a sample-by-component CSR matrix.""" + activation_matrix = component_activity_csr.astype(np.int32, copy=False) + coact = (activation_matrix.T @ activation_matrix).toarray() + return torch.from_numpy(coact.astype(np.float32, copy=False)) + + def compute_coactivation_matrix( memberships: list[CompressedMembership], ) -> ClusterCoactivationShaped: @@ -343,6 +363,4 @@ def compute_coactivation_matrix( "Memberships must share sample space" ) - activation_matrix = memberships_to_sample_component_csr(memberships).astype(np.int32, copy=False) - coact = (activation_matrix.T @ activation_matrix).toarray() - return torch.from_numpy(coact.astype(np.float32, copy=False)) + return compute_coactivation_matrix_from_csr(memberships_to_sample_component_csr(memberships)) diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 5edeb5450..4d3723723 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -167,13 +167,6 @@ def test_membership_recompute_matches_row_oriented_path(self): merges = GroupMerge.identity(n_components=len(memberships)) component_activity_csr = memberships_to_sample_component_csr(memberships) - merge_old, coact_old, memberships_old = recompute_coacts_merge_pair_memberships( - coact=coact, - merges=merges, - merge_pair=(0, 1), - memberships=memberships, - component_activity_csr=None, - ) merge_row, coact_row, memberships_row = recompute_coacts_merge_pair_memberships( coact=coact, merges=merges, @@ -182,8 +175,14 @@ def test_membership_recompute_matches_row_oriented_path(self): component_activity_csr=component_activity_csr, ) - assert torch.equal(merge_old.group_idxs, merge_row.group_idxs) - assert torch.equal(coact_old, coact_row) - assert [membership.count() for membership in memberships_old] == [ - membership.count() for membership in memberships_row - ] + expected_group_idxs = torch.tensor([0, 0, 1, 2], dtype=torch.int64) + expected_coact = torch.tensor( + [ + [4.0, 1.0, 1.0], + [1.0, 2.0, 0.0], + [1.0, 0.0, 3.0], + ] + ) + assert torch.equal(merge_row.group_idxs, expected_group_idxs) + assert torch.equal(coact_row, expected_coact) + assert [membership.count() for membership in memberships_row] == [4, 2, 3] From a4db390259d89ae756dd9b803ef060537f76b89f Mon Sep 17 00:00:00 2001 From: Oliver Clive-Griffin Date: Thu, 19 Mar 2026 11:09:26 +0000 Subject: [PATCH 5/5] Add configurable clustering dead-filter stat --- spd/clustering/CLAUDE.md | 2 +- spd/clustering/activations.py | 55 +++++++++++++++---- spd/clustering/configs/crc/example.yaml | 3 +- spd/clustering/membership_snapshot.py | 5 +- spd/clustering/merge_config.py | 9 ++- spd/clustering/sample_membership.py | 6 +- spd/clustering/scripts/run_clustering.py | 3 + spd/clustering/util.py | 2 + .../clustering/test_filter_dead_components.py | 46 ++++++++++++++++ tests/clustering/test_merge_config.py | 3 + tests/clustering/test_merge_integration.py | 6 +- 11 files changed, 119 insertions(+), 21 deletions(-) diff --git a/spd/clustering/CLAUDE.md b/spd/clustering/CLAUDE.md index f502f8785..7bdbd64bf 100644 --- a/spd/clustering/CLAUDE.md +++ b/spd/clustering/CLAUDE.md @@ -80,7 +80,7 @@ Computes pairwise distances between clustering runs in an ensemble: ```python ClusteringPipelineConfig # Pipeline settings (n_runs, distances_methods, SLURM config) ClusteringRunConfig # Single run settings (model_path, batch_size, n_tokens, merge_config) -MergeConfig # Merge algorithm params (alpha, iters, activation_threshold) +MergeConfig # Merge algorithm params (alpha, iters, activation_threshold, filter_dead_stat) ``` ### Data Structures diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 99435b56f..521fe71b0 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -16,7 +16,7 @@ ComponentLabels, ) from spd.clustering.sample_membership import CompressedMembership -from spd.clustering.util import ModuleFilterFunc +from spd.clustering.util import DeadComponentFilterStat, ModuleFilterFunc from spd.log import logger from spd.models.component_model import ComponentModel, OutputWithCache @@ -117,6 +117,17 @@ def compute_coactivatons( return activations_f16.T @ activations_f16 +def _get_component_filter_values( + activations: ActivationsTensor, + filter_stat: DeadComponentFilterStat, +) -> Float[Tensor, " c"]: + if filter_stat == "max": + return activations.max(dim=0).values + + assert filter_stat == "mean", f"Unsupported dead component filter stat: {filter_stat}" + return activations.mean(dim=0) + + class FilteredActivations(NamedTuple): activations: ActivationsTensor "activations after filtering dead components" @@ -146,21 +157,26 @@ def filter_dead_components( activations: ActivationsTensor, labels: ComponentLabels, filter_dead_threshold: float = 0.01, + filter_dead_stat: DeadComponentFilterStat = "max", ) -> FilteredActivations: """Filter out dead components based on a threshold if `filter_dead_threshold` is 0, no filtering is applied. activations and labels are returned as is, `dead_components_labels` is `None`. - otherwise, components whose **maximum** activations across all samples is below the threshold - are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`. + otherwise, components whose aggregate activation statistic across all samples is below the + threshold are considered dead and filtered out. The statistic is selected by + `filter_dead_stat` and the labels of dead components are returned in `dead_components_labels`. `dead_components_labels` will also be `None` if no components were below the threshold. """ dead_components_lst: ComponentLabels | None = None if filter_dead_threshold > 0: dead_components_lst = ComponentLabels(list()) - max_act: Float[Tensor, " c"] = activations.max(dim=0).values - dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold + filter_values: Float[Tensor, " c"] = _get_component_filter_values( + activations=activations, + filter_stat=filter_dead_stat, + ) + dead_components: Bool[Tensor, " c"] = filter_values < filter_dead_threshold if dead_components.any(): activations = activations[:, ~dead_components] @@ -313,17 +329,20 @@ def __init__( *, activation_threshold: float, filter_dead_threshold: float, + filter_dead_stat: DeadComponentFilterStat, filter_modules: ModuleFilterFunc | None, preview_n_samples: int = 256, ) -> None: self.activation_threshold = activation_threshold self.filter_dead_threshold = filter_dead_threshold + self.filter_dead_stat = filter_dead_stat self.filter_modules = filter_modules self.preview_n_samples = preview_n_samples self.n_samples = 0 self.module_component_counts: dict[str, int] = {} self.max_activations: dict[str, Float[Tensor, " c"]] = {} + self.sum_activations: dict[str, Float[Tensor, " c"]] = {} self.sample_idx_chunks: dict[str, list[list[np.ndarray]]] = {} self.preview_chunks: dict[str, list[Tensor]] = {} self.module_order: list[str] = [] @@ -339,6 +358,7 @@ def _ensure_module(self, key: str, n_components: int) -> None: self.module_component_counts[key] = n_components self.max_activations[key] = torch.full((n_components,), float("-inf")) + self.sum_activations[key] = torch.zeros((n_components,), dtype=torch.float64) self.sample_idx_chunks[key] = [[] for _ in range(n_components)] self.preview_chunks[key] = [] self.module_order.append(key) @@ -367,6 +387,7 @@ def add_batch( self.max_activations[key] = torch.maximum( self.max_activations[key], act_cpu.max(dim=0).values ) + self.sum_activations[key] += act_cpu.sum(dim=0, dtype=torch.float64) if self._preview_rows < self.preview_n_samples: remaining = self.preview_n_samples - self._preview_rows @@ -396,10 +417,16 @@ def finalize(self) -> ProcessedMemberships: preview_chunks_alive: list[Tensor] = [] for key in self.module_order: - max_act = self.max_activations[key] + filter_values = ( + self.max_activations[key] + if self.filter_dead_stat == "max" + else (self.sum_activations[key] / self.n_samples).to( + self.max_activations[key].dtype + ) + ) n_components = self.module_component_counts[key] alive = ( - max_act >= self.filter_dead_threshold + filter_values >= self.filter_dead_threshold if self.filter_dead_threshold > 0 else torch.ones(n_components, dtype=torch.bool) ) @@ -411,7 +438,7 @@ def finalize(self) -> ProcessedMemberships: preview_tensor = ( torch.cat(self.preview_chunks[key], dim=0) if self.preview_chunks[key] - else torch.empty((0, n_components), dtype=max_act.dtype) + else torch.empty((0, n_components), dtype=filter_values.dtype) ) for comp_idx in range(n_components): @@ -468,6 +495,7 @@ def collect_memberships_lm( seed: int, activation_threshold: float, filter_dead_threshold: float, + filter_dead_stat: DeadComponentFilterStat = "max", filter_modules: ModuleFilterFunc | None = None, preview_n_samples: int = 256, ) -> ProcessedMemberships: @@ -476,6 +504,7 @@ def collect_memberships_lm( builder = MembershipBuilder( activation_threshold=activation_threshold, filter_dead_threshold=filter_dead_threshold, + filter_dead_stat=filter_dead_stat, filter_modules=filter_modules, preview_n_samples=preview_n_samples, ) @@ -519,6 +548,7 @@ def collect_memberships_resid_mlp( device: torch.device | str, activation_threshold: float, filter_dead_threshold: float, + filter_dead_stat: DeadComponentFilterStat = "max", filter_modules: ModuleFilterFunc | None = None, preview_n_samples: int = 256, ) -> ProcessedMemberships: @@ -526,6 +556,7 @@ def collect_memberships_resid_mlp( builder = MembershipBuilder( activation_threshold=activation_threshold, filter_dead_threshold=filter_dead_threshold, + filter_dead_stat=filter_dead_stat, filter_modules=filter_modules, preview_n_samples=preview_n_samples, ) @@ -559,6 +590,7 @@ def process_activations( | Float[Tensor, " n_sample n_ctx C"], # (sample x seq index x component gate activations) ], filter_dead_threshold: float, + filter_dead_stat: DeadComponentFilterStat = "max", seq_mode: Literal["concat", "seq_mean", None] = None, filter_modules: ModuleFilterFunc | None = None, ) -> ProcessedActivations: @@ -595,8 +627,11 @@ def process_activations( c = act.shape[-1] module_component_counts[key] = c if filter_dead_threshold > 0: - max_act: Float[Tensor, " c"] = act.max(dim=0).values - alive = max_act >= filter_dead_threshold + filter_values: Float[Tensor, " c"] = _get_component_filter_values( + activations=act, + filter_stat=filter_dead_stat, + ) + alive = filter_values >= filter_dead_threshold alive_masks[key] = alive total_alive += int(alive.sum().item()) else: diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml index 9af7106d8..e7031fca5 100644 --- a/spd/clustering/configs/crc/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -12,6 +12,7 @@ merge_config: merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from filter_dead_threshold: 0.001 # Threshold for filtering dead components + filter_dead_stat: "max" # Dead-component statistic: "max" or "mean" module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules wandb_project: spd @@ -20,4 +21,4 @@ logging_intervals: stat: 1 # for k_groups, merge_pair_cost, mdl_loss tensor: 100 # for wandb_log_tensor and fraction_* calculations plot: 100 # for calling the plotting callback - artifact: 100 # for calling the artifact callback \ No newline at end of file + artifact: 100 # for calling the artifact callback diff --git a/spd/clustering/membership_snapshot.py b/spd/clustering/membership_snapshot.py index 6b401b705..f8cd1445f 100644 --- a/spd/clustering/membership_snapshot.py +++ b/spd/clustering/membership_snapshot.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from pathlib import Path +import numpy as np from scipy import sparse from spd.clustering.consts import ComponentLabels @@ -54,7 +55,9 @@ def memberships_to_csc( ) -> sparse.csc_matrix: matrix = memberships_to_sample_component_matrix(memberships, fmt="csc") assert isinstance(matrix, sparse.csc_matrix) - assert matrix.shape[0] == n_samples + shape = matrix.shape + assert shape is not None + assert shape[0] == n_samples return matrix diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index f471879b2..c4e246946 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -14,7 +14,7 @@ MergePairSampler, MergePairSamplerKey, ) -from spd.clustering.util import ModuleFilterFunc, ModuleFilterSource +from spd.clustering.util import DeadComponentFilterStat, ModuleFilterFunc, ModuleFilterSource from spd.spd_types import Probability MergeConfigKey = Literal[ @@ -24,6 +24,7 @@ "merge_pair_sampling_method", "merge_pair_sampling_kwargs", "filter_dead_threshold", + "filter_dead_stat", ] @@ -66,7 +67,11 @@ class MergeConfig(BaseConfig): ) filter_dead_threshold: float = Field( default=0.001, - description="Threshold for filtering out dead components. If a component's activation is below this threshold, it is considered dead and not included in the merge.", + description="Threshold for filtering out dead components using the statistic selected by filter_dead_stat.", + ) + filter_dead_stat: DeadComponentFilterStat = Field( + default="max", + description="Statistic used to determine whether a component is dead before clustering.", ) module_name_filter: ModuleFilterSource = Field( default=None, diff --git a/spd/clustering/sample_membership.py b/spd/clustering/sample_membership.py index 01698745d..efdfd089e 100644 --- a/spd/clustering/sample_membership.py +++ b/spd/clustering/sample_membership.py @@ -114,7 +114,7 @@ def _bitset_to_sample_indices(bits: np.ndarray, n_samples: int) -> np.ndarray: return sample_indices.astype(_index_dtype_for(n_samples), copy=False) -@njit(cache=True) +@njit(cache=True) # pyright: ignore[reportUntypedFunctionDecorator] def _count_group_overlaps_rows_numba( merged_rows: np.ndarray, indptr: np.ndarray, @@ -275,7 +275,7 @@ def memberships_to_sample_component_matrix( n_groups = len(memberships) if n_groups == 0: empty = sparse.csr_matrix((0, 0), dtype=np.uint8) - return empty if fmt == "csr" else empty.tocsc() + return empty if fmt == "csr" else sparse.csc_matrix(empty) n_samples = memberships[0].n_samples assert all(membership.n_samples == n_samples for membership in memberships), ( @@ -300,7 +300,7 @@ def memberships_to_sample_component_matrix( shape=(n_samples, n_groups), dtype=np.uint8, ) - return matrix if fmt == "csr" else matrix.tocsc() + return matrix if fmt == "csr" else sparse.csc_matrix(matrix) def memberships_to_sample_component_csr( diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index f06e1c88a..eea1007dc 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -318,6 +318,7 @@ def main(run_config: ClusteringRunConfig) -> Path: seed=run_config.dataset_seed, activation_threshold=activation_threshold, filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, filter_modules=run_config.merge_config.filter_modules, ) else: @@ -328,6 +329,7 @@ def main(run_config: ClusteringRunConfig) -> Path: device=device, activation_threshold=activation_threshold, filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, filter_modules=run_config.merge_config.filter_modules, ) processed_activations = processed_memberships.preview @@ -370,6 +372,7 @@ def main(run_config: ClusteringRunConfig) -> Path: processed_activations = process_activations( activations=activations_dict, filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + filter_dead_stat=run_config.merge_config.filter_dead_stat, seq_mode=None, filter_modules=run_config.merge_config.filter_modules, ) diff --git a/spd/clustering/util.py b/spd/clustering/util.py index bd11e2fd4..a84e3d78f 100644 --- a/spd/clustering/util.py +++ b/spd/clustering/util.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from typing import Literal def format_scientific_latex(value: float) -> str: @@ -16,3 +17,4 @@ def format_scientific_latex(value: float) -> str: ModuleFilterSource = str | Callable[[str], bool] | set[str] | None ModuleFilterFunc = Callable[[str], bool] +DeadComponentFilterStat = Literal["max", "mean"] diff --git a/tests/clustering/test_filter_dead_components.py b/tests/clustering/test_filter_dead_components.py index 654631f37..456db84e3 100644 --- a/tests/clustering/test_filter_dead_components.py +++ b/tests/clustering/test_filter_dead_components.py @@ -129,3 +129,49 @@ def test_linear_gradient_thresholds(threshold: float) -> None: assert result.n_alive == expected_alive assert result.n_dead == n_components - expected_alive + + +def test_filter_dead_components_mean_stat() -> None: + """Mean-based filtering keeps components whose average activation clears the threshold.""" + activations = torch.tensor( + [ + [1e-5, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + ] + ) + labels = ComponentLabels(["spiky", "dead", "steady"]) + + result: FilteredActivations = filter_dead_components( + activations=activations, + labels=labels, + filter_dead_threshold=5e-6, + filter_dead_stat="mean", + ) + + assert result.labels == ["steady"] + assert result.dead_components_labels == ["spiky", "dead"] + + +def test_filter_dead_components_max_stat_preserves_spikes() -> None: + """Max-based filtering preserves components with a single large activation.""" + activations = torch.tensor( + [ + [1e-5, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + [0.0, 0.0, 1e-5], + ] + ) + labels = ComponentLabels(["spiky", "dead", "steady"]) + + result: FilteredActivations = filter_dead_components( + activations=activations, + labels=labels, + filter_dead_threshold=5e-6, + filter_dead_stat="max", + ) + + assert result.labels == ["spiky", "steady"] + assert result.dead_components_labels == ["dead"] diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 63f4e88f7..38ffa0fc1 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -15,6 +15,7 @@ def test_default_config(self): assert config.merge_pair_sampling_method == "range" assert config.merge_pair_sampling_kwargs == {"threshold": 0.05} + assert config.filter_dead_stat == "max" def test_range_sampler_config(self): """Test MergeConfig with range sampler.""" @@ -75,6 +76,7 @@ def test_config_with_all_parameters(self): merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.5}, filter_dead_threshold=0.001, + filter_dead_stat="mean", module_name_filter="model.layers", ) @@ -84,6 +86,7 @@ def test_config_with_all_parameters(self): assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} assert config.filter_dead_threshold == 0.001 + assert config.filter_dead_stat == "mean" assert config.module_name_filter == "model.layers" def test_config_serialization(self): diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 4d3723723..47e1801f6 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -3,10 +3,10 @@ import torch from spd.clustering.compute_costs import recompute_coacts_merge_pair_memberships -from spd.clustering.consts import ComponentLabels +from spd.clustering.consts import ComponentLabels, MergePair +from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig -from spd.clustering.math.merge_matrix import GroupMerge from spd.clustering.sample_membership import ( CompressedMembership, compute_coactivation_matrix, @@ -170,7 +170,7 @@ def test_membership_recompute_matches_row_oriented_path(self): merge_row, coact_row, memberships_row = recompute_coacts_merge_pair_memberships( coact=coact, merges=merges, - merge_pair=(0, 1), + merge_pair=MergePair((0, 1)), memberships=memberships, component_activity_csr=component_activity_csr, )