
Refactor clustering to use compressed memberships#456

Open
ocg-goodfire wants to merge 4 commits into dev from feature/hybrid-clustering-memberships

Conversation

Collaborator

@ocg-goodfire ocg-goodfire commented Mar 18, 2026

[entirely by codex 5.4]

Summary

  • replace the dense sample-by-group thresholded clustering path with exact compressed memberships
  • keep memberships hybrid sparse/bitset so low-density groups stay sparse and denser merged groups upgrade naturally
  • route thresholded clustering runs through the compressed path, add multi-batch non-LM collection via n_samples, and keep the old dense path as a fallback when activation_threshold is None
  • speed up exact merge compute by moving dense MDL cost evaluation to GPU and replacing Python membership-intersection overlap updates with a row-oriented numba kernel over a cached sample-by-component CSR matrix
  • add a snapshot/benchmark scaffold so activations can be harvested once and merge kernels can be benchmarked repeatedly without recollecting data

What Changed

The original exact thresholded clustering path materialized dense sample-by-group state and then recomputed merged-group overlaps with Python-level set operations. That made memory scale badly with token count, and once memory was fixed, the per-merge overlap update became the main runtime bottleneck.

This PR changes the exact thresholded path in two stages:

  1. Representation changes
  • introduce exact compressed memberships for thresholded clustering
  • store each group as sparse sample indices when that is cheaper than a bitset, otherwise as a packed bitset
  • build exact coactivation matrices from the compressed memberships with sparse X.T @ X
  • add multi-batch collection for non-LM clustering with n_samples
  2. Exact runtime changes
  • keep the dense coactivation matrix on GPU when possible so compute_merge_costs() runs there
  • build the original sample-by-component CSR matrix once at merge start
  • after each merge, recompute merged-group overlaps by scanning only the merged rows through a numba-compiled row-oriented kernel instead of doing Python-level membership intersections against every current group

Semantics are still exact for the thresholded path. The new row-oriented overlap update computes the same counts as the old direct membership intersection path; it just does it with compiled array code over the original component activity matrix and current group assignments.
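The hybrid sparse/bitset choice can be sketched in miniature. This is an illustrative sketch, not the PR's actual `CompressedMembership` API; the function name and the byte-cost heuristic are assumptions:

```python
import numpy as np

def pack_membership(sample_idxs: np.ndarray, n_samples: int):
    """Pick the cheaper encoding for one group's membership.

    Sparse costs ~8 bytes per active sample (int64 indices); a bitset
    costs n_samples / 8 bytes regardless of density, so sparse wins for
    low-density groups. Names and heuristic are illustrative only.
    """
    sparse_bytes = sample_idxs.size * 8
    bitset_bytes = (n_samples + 7) // 8
    if sparse_bytes <= bitset_bytes:
        return ("sparse", np.sort(sample_idxs).astype(np.int64))
    # Denser groups upgrade to a packed bitset, one bit per sample
    bits = np.zeros(bitset_bytes, dtype=np.uint8)
    np.bitwise_or.at(bits, sample_idxs // 8,
                     (1 << (sample_idxs % 8)).astype(np.uint8))
    return ("bitset", bits)

# A 3-sample group out of 1M samples stays sparse (24 bytes vs 125 KB)
kind, data = pack_membership(np.array([3, 17, 900]), n_samples=1_000_000)
```

Merging two sparse groups can push the combined cost past the bitset break-even point, which is how "denser merged groups upgrade naturally" falls out of the representation choice.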

Performance Notes

Jose-like synthetic benchmark matching the old real-run scale after dead-component filtering:

  • 500,000 sampled tokens
  • 9,981 alive components
  • token-wise l0 ~= 200

Measured progression:

  • original compressed merge loop: about 1.90s/iter
  • after GPU cost-matrix move: about 0.66s/iter
  • after row-oriented numba overlap update: about 0.095s/iter

Current Jose-like exact runtime at that scale:

  • CSR build: about 2.5s
  • coactivation build: about 75s
  • merge loop: about 0.095s/iter
  • projected ~9100 iterations: about 16 minutes merge-side total at 500k sampled tokens

That means the exact merge kernel is now roughly:

  • ~20x faster than the original compressed path
  • ~6.9x faster than the previous branch state before the row-oriented update
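The quoted ratios follow directly from the per-iteration timings above:

```python
# Per-iteration timings from the benchmark progression (seconds/iter)
original = 1.90    # compressed merge loop before this PR's kernels
gpu_cost = 0.66    # after moving MDL cost evaluation to GPU
row_numba = 0.095  # after the row-oriented numba overlap update

print(f"{original / row_numba:.1f}x vs original")  # → 20.0x
print(f"{gpu_cost / row_numba:.1f}x vs GPU-only")  # → 6.9x

# ~9100 iterations of pure kernel time; the "about 16 minutes" quoted
# above plausibly includes per-iteration loop overhead on top of this
print(f"{9100 * row_numba / 60:.0f} min")  # → 14 min
```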

On the memory side, the compressed membership representation already reduced token-scaling memory dramatically versus the old dense thresholded path, so the current branch materially improves both:

  • memory headroom for larger token counts
  • merge throughput at Jose-like clustering scale

Validation

  • python3 -m compileall spd/clustering
  • uv run pytest tests/clustering/test_merge_integration.py tests/clustering/test_filter_dead_components.py tests/clustering/test_merge_pair_samplers.py
  • uv run pytest tests/clustering/test_merge_integration.py -q
  • targeted dense-vs-compressed merge sanity check with synchronized sampler state
  • targeted sanity check that low-density memberships stay sparse
  • targeted exactness check that the row-oriented overlap recompute matches the previous direct membership recompute
  • Jose-like synthetic benchmarks for coactivation build, overlap-update subproblem, and full merge-loop throughput

@ocg-goodfire
Collaborator Author

@claude Please review this. mostly laser-focus on correctness - this should be an interface-preserving efficiency improvement.

@claude
Contributor

claude bot commented Mar 18, 2026



Code Review Completed

Todo List:

  • Set up environment and read CLAUDE.md guidelines
  • Review new sample_membership.py - core compressed membership logic
  • Review changes to activations.py - activation collection updates
  • Review changes to merge.py - compressed merge implementation
  • Review changes to compute_costs.py - cost computation updates
  • Review changes to run_clustering.py - main orchestration logic
  • Review config changes in clustering_run_config.py
  • Verify interface preservation across all changes
  • Check for potential correctness issues and edge cases
  • Provide final review summary

Code Review Summary

This is a solid efficiency improvement that successfully introduces hybrid sparse/bitset compressed memberships while preserving the existing dense clustering interface. The refactor follows good engineering practices with clean separation of concerns.

Strengths

Well-designed core data structure (sample_membership.py):

  • Clean hybrid representation that automatically chooses sparse vs bitset based on memory efficiency
  • Mathematically sound intersection/union operations with proper edge case handling
  • Good use of numpy optimizations (popcount lookup table, vectorized operations)

Excellent interface preservation:

  • Original merge_iteration() function completely unchanged
  • New merge_iteration_memberships() mirrors original structure exactly
  • Existing functions like collect_activations() and process_activations() remain intact
  • Fallback path when activation_threshold is None maintains full backward compatibility

Smart orchestration (run_clustering.py lines 303-334):

  • Routing logic properly chooses compressed vs dense path based on activation_threshold
  • Preserves all WandB logging, progress tracking, and error handling
  • Clean preview generation for plotting while using compressed storage for computation

Efficient memory management:

  • MembershipBuilder provides streaming construction without holding full dense matrices
  • Preview is limited to 256 samples while full membership tracking scales to larger datasets
  • Proper cleanup and garbage collection maintained

🔍 Areas to Verify

Bitwise operations correctness:

  • In _sample_indices_to_bits() line 35: np.bitwise_or.at(bits, byte_indices, (1 << bit_offsets).astype(np.uint8))
  • This appears correct - bitwise_or.at() performs in-place OR operations to set the appropriate bits
  • Similar pattern in union() method lines 199-202 looks sound
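The bit-setting idiom flagged above is easy to check in isolation. A minimal reproduction (constants assumed, not copied from the PR), including the duplicate-index case that `bitwise_or.at` handles safely:

```python
import numpy as np

sample_idxs = np.array([0, 3, 8, 8, 15])  # duplicate index on purpose
n_bytes = 2                               # covers sample ids 0..15
bits = np.zeros(n_bytes, dtype=np.uint8)

byte_indices = sample_idxs // 8
bit_offsets = sample_idxs % 8
# In-place OR: duplicates are harmless, the same bit is just set again.
# Plain fancy-index assignment (bits[byte_indices] |= ...) would drop
# repeated writes to the same byte; ufunc.at applies every one.
np.bitwise_or.at(bits, byte_indices, (1 << bit_offsets).astype(np.uint8))

# Popcount via unpackbits recovers the number of distinct samples
assert int(np.unpackbits(bits).sum()) == 4
```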

Intersection counting across representations:

  • Mixed sparse/bitset intersection logic in intersection_count() (lines 158-171) handles all combinations
  • The _count_sparse_bitset_intersection() function correctly maps sparse indices to bit positions

Union operations:

  • _union_sparse_sparse() implements proper sorted merge (lines 59-86)
  • Mixed sparse/bitset union correctly converts to bitset when needed (lines 189-208)

📊 Performance Characteristics

Memory efficiency gains:

  • Sparse representation: O(active_samples) vs dense O(total_samples)
  • Bitset representation: O(total_samples/8) vs dense O(total_samples * sizeof(bool))
  • Automatic hybrid selection optimizes for both sparse and dense cases

Computational complexity:

  • Sparse-sparse operations: O(min(|A|, |B|)) instead of O(total_samples)
  • Bitset operations: Vectorized with popcount optimization
  • Coactivation matrix computation: O(n_groups² * avg_membership_size) instead of O(n_groups² * total_samples)

🎯 Correctness Assessment

Mathematics: All set operations (intersection, union) are implemented correctly with proper sorted array merging and bit manipulation.

Edge cases: Code handles empty sets, single elements, and memory-optimized representation switching properly.

Interface contracts: All function signatures and return types preserved. Error conditions and assertions maintained.

Multi-batch handling: The MembershipBuilder.add_batch() correctly handles sample offset tracking for streaming construction.

🔬 Minor Suggestions

  1. Line 97 clarity: The bitwise intersection count could benefit from a comment explaining the bit manipulation
  2. Memory estimates: The _prefer_sparse() heuristic could document the 8x compression assumption for bitsets
  3. Validation: Consider adding runtime assertions to verify intersection/union results in debug mode

Final Assessment

This PR successfully delivers what it promises - an interface-preserving efficiency improvement. The compressed membership approach is mathematically sound, well-implemented, and properly integrated. The hybrid sparse/bitset design is clever and should provide significant memory savings for typical clustering workloads.

Recommendation: ✅ Approve - This is high-quality code that maintains correctness while delivering meaningful performance improvements.


@ocg-goodfire
Collaborator Author

@claude thanks for your review. Can you please review the 2 new commits since? again - mostly concerned with correctness.

@claude
Contributor

claude bot commented Mar 18, 2026



Review of 2 New Commits Since Previous Review

Todo List:

  • Set up environment and identify the 2 new commits
  • Review changes in the new commits for correctness
  • Analyze mathematical operations and edge cases
  • Check interface preservation and logic correctness
  • Provide focused feedback on any correctness issues found

Summary

Both commits look mathematically sound and correctly implemented. The key optimization introduces a numba-compiled row-oriented overlap counting kernel that maintains exactness while providing significant performance improvements. Good correctness validation is included.

Commit Analysis

1. 8483070 - "Speed up clustering merge benchmarks" ✅

Adds snapshot infrastructure:

  • MembershipSnapshot class for disk serialization/deserialization
  • Benchmark scaffolding for repeatable performance testing
  • Clean separation between data collection and benchmarking phases

Correctness Assessment:

  • Proper sparse matrix format conversions (CSC ↔ CSR)
  • Correct membership reconstruction from sparse matrices
  • Safe dtype handling throughout

2. ac99c0e - "Speed up exact clustering overlap updates" ✅

Core optimization - numba kernel:

The _count_group_overlaps_rows_numba() kernel implements a stamp-based deduplication algorithm that's mathematically equivalent to set intersections:

@njit(cache=True)
def _count_group_overlaps_rows_numba(merged_rows, indptr, indices, group_idxs, n_groups):
    counts = np.zeros(n_groups, dtype=np.int64)
    seen = np.full(n_groups, -1, dtype=np.int64)
    stamp = 0
    for row in merged_rows:
        stamp += 1  # Unique timestamp per sample
        start, end = indptr[row], indptr[row + 1]
        for pos in range(start, end):
            group_idx = group_idxs[indices[pos]]
            if seen[group_idx] == stamp:
                continue  # Already counted for this sample
            seen[group_idx] = stamp
            counts[group_idx] += 1
    return counts
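A quick way to convince yourself of the equivalence is to run the same logic on a tiny CSR matrix and compare it against naive per-group set intersections. This standalone sanity check (not one of the PR's tests) strips `@njit` so it runs without numba installed; the toy data is invented:

```python
import numpy as np
from scipy import sparse

def count_group_overlaps_rows(merged_rows, indptr, indices, group_idxs, n_groups):
    # Same logic as the kernel above, minus the numba decoration
    counts = np.zeros(n_groups, dtype=np.int64)
    seen = np.full(n_groups, -1, dtype=np.int64)
    stamp = 0
    for row in merged_rows:
        stamp += 1
        for pos in range(indptr[row], indptr[row + 1]):
            g = group_idxs[indices[pos]]
            if seen[g] != stamp:   # count each group once per sample
                seen[g] = stamp
                counts[g] += 1
    return counts

# 4 samples x 5 components; components 0,1 -> group 0; 2,3 -> group 1; 4 -> group 2
X = sparse.csr_matrix(np.array([
    [1, 0, 1, 0, 0],
    [0, 1, 0, 0, 1],
    [1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0],
], dtype=np.uint8))
group_idxs = np.array([0, 0, 1, 1, 2])
merged_rows = np.array([0, 1, 2])  # samples active in the freshly merged group

counts = count_group_overlaps_rows(merged_rows, X.indptr, X.indices, group_idxs, 3)

# Naive reference: |merged ∩ group| via Python sets over the dense matrix
dense = X.toarray()
ref = [len({r for r in merged_rows
            if dense[r, np.flatnonzero(group_idxs == g)].any()})
       for g in range(3)]
assert counts.tolist() == ref
```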

Correctness Verification

Algorithm Logic:

  • Stamp-based deduplication: Each sample gets a unique timestamp (stamp); groups are only counted once per sample
  • Equivalence: This correctly computes |merged_group ∩ other_group| for each group
  • CSR traversal: Proper indexing from indptr[row] to indptr[row + 1] to get the components active on each sample

Data Type Safety:

  • Consistent int64 casting for indices prevents overflow
  • uint8int32float32 progression in coactivation computation is safe
  • Proper handling of numpy/torch tensor conversions

Interface Preservation:

  • recompute_coacts_merge_pair_memberships() maintains exact signature
  • Fallback path when component_activity_csr=None preserves backward compatibility
  • Return types and semantics unchanged

Validation Evidence:

  • Direct equivalence test: test_membership_recompute_matches_row_oriented_path() verifies numba path produces identical results to direct membership intersections
  • Built-in consistency check: Benchmark script includes assert np.array_equal(current, row_counts) to verify correctness
  • Performance validation: Maintains exactness while achieving ~20x speedup

🔍 Edge Cases Handled Correctly

  1. Empty groups: np.zeros(n_groups) initialization handles zero counts correctly
  2. Single-component samples: Stamp mechanism works with any sparsity level
  3. Duplicate component indices: Deduplication via seen[group_idx] == stamp prevents double-counting
  4. Large indices: int64 casting prevents overflow for large sample counts
  5. Memory consistency: CSR format guarantees contiguous indices array access

📊 Performance Characteristics

The optimization maintains O(nnz) complexity where nnz is non-zero count in merged groups, but with much better constants:

  • Before: Python set intersections - interpreted loops with hash table overhead
  • After: Compiled array scanning with simple integer comparisons
  • Result: ~20x speedup while preserving exactness

🎯 Final Assessment

Recommendation: ✅ Approve - High Confidence

Both commits are mathematically correct and well-engineered:

  • Numba kernel implements sound deduplication algorithm
  • Proper dtype management prevents numerical issues
  • Comprehensive testing validates equivalence with original path
  • Clean interface preservation maintains backward compatibility
  • Performance improvements are substantial while maintaining correctness

The stamp-based deduplication approach is particularly elegant - it's essentially a compiled version of the membership intersection logic that avoids Python object overhead.


@ocg-goodfire
Collaborator Author

@claude sorry, one more: how's that last commit look?

@claude
Contributor

claude bot commented Mar 18, 2026



Review of Latest Commit

Todo List:

  • Set up environment and identify the latest commit
  • Review the changes in the latest commit for correctness
  • Analyze mathematical operations and edge cases
  • Check for interface preservation and logic issues
  • Provide focused feedback on any correctness concerns found

Analyzing commit dd72fab - "Simplify exact clustering merge path"


Summary

This commit performs a clean refactoring that simplifies the clustering merge path while maintaining correctness. The changes eliminate code duplication, remove unnecessary conditional paths, and improve API consistency.

Key Changes & Correctness Assessment

1. Type Alias Simplification (sample_membership.py:269)

  • Change: BitsetMembership = CompressedMembership → Removed alias, use CompressedMembership directly
  • Correctness: ✅ Pure renaming, no functional change

2. Matrix Construction Refactor (sample_membership.py:269-312)

  • New: memberships_to_sample_component_matrix() with format parameter (fmt: Literal["csr", "csc"])
  • Wrapper: memberships_to_sample_component_csr() calls the general function with fmt="csr"
  • Correctness: ✅ Clean abstraction that properly handles both CSR/CSC formats with correct shape validation

3. Coactivation Matrix Computation Split (sample_membership.py:339-366)

  • New: compute_coactivation_matrix_from_csr() - operates directly on CSR matrix
  • Updated: compute_coactivation_matrix() - delegates to CSR version after building matrix
  • Mathematical operations:
    activation_matrix = component_activity_csr.astype(np.int32, copy=False)
    coact = (activation_matrix.T @ activation_matrix).toarray() 
    return torch.from_numpy(coact.astype(np.float32, copy=False))
  • Correctness: ✅ Matrix multiplication X.T @ X correctly computes pairwise intersections, dtype progression is safe

4. Mandatory CSR Path (compute_costs.py:202, 228-239)

  • Removed: Optional component_activity_csr: sparse.csr_matrix | None = None parameter
  • Now required: component_activity_csr: sparse.csr_matrix
  • Eliminated fallback: Removed the slower direct membership intersection path
  • Correctness: ✅ Forces use of the faster, validated numba path exclusively

5. CSR Reuse in Merge Loop (merge.py:238-240)

  • Before: compute_coactivation_matrix(memberships) - rebuilt CSR internally
  • After: compute_coactivation_matrix_from_csr(component_activity_csr) - reuses pre-built CSR
  • Correctness: ✅ Eliminates redundant CSR construction, more efficient

6. Test Simplification (test_merge_integration.py:170-188)

  • Removed: Test comparing old fallback path vs new numba path
  • Simplified: Direct validation against expected results
  • Correctness: ✅ Still validates the core merge operation, just more focused

🔍 Mathematical Verification

CSR → Coactivation Matrix:

  • activation_matrix.T @ activation_matrix computes (n_components, n_samples) @ (n_samples, n_components) = (n_components, n_components)
  • Entry [i,j] = sum(activation_matrix[:,i] * activation_matrix[:,j]) = intersection count of groups i,j
  • Result: Identical to pairwise membership.intersection_count() calls but much faster
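The `X.T @ X` identity is easy to verify on a toy membership matrix (a standalone check with invented data, not code from the PR):

```python
import numpy as np
from scipy import sparse

# 5 samples x 3 groups; entry [s, g] = 1 if sample s is in group g
X = sparse.csr_matrix(np.array([
    [1, 1, 0],
    [1, 0, 1],
    [0, 1, 1],
    [1, 1, 1],
    [0, 0, 0],
], dtype=np.uint8))

# Cast before multiplying so counts cannot overflow uint8
coact = (X.T.astype(np.int32) @ X.astype(np.int32)).toarray()

# Entry [i, j] is the number of samples active in both group i and group j,
# i.e. exactly the pairwise membership intersection count
memberships = [set(np.flatnonzero(X.toarray()[:, g])) for g in range(3)]
for i in range(3):
    for j in range(3):
        assert coact[i, j] == len(memberships[i] & memberships[j])
```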

Data Type Safety:

  • uint8 (CSR) → int32 (computation) → float32 (torch): Safe progression prevents overflow
  • CSR matrix maintains (n_samples, n_groups) shape consistency

🎯 Interface & Behavioral Changes

Breaking change (intentional):

  • recompute_coacts_merge_pair_memberships() now requires component_activity_csr parameter
  • Impact: Eliminates the slower fallback path, enforces use of optimized code

Preserved interfaces:

  • compute_coactivation_matrix() signature unchanged, still works for existing callers
  • Core merge semantics remain identical
  • All mathematical results unchanged

📊 Performance Impact

Eliminates redundancy:

  • Before: CSR built in merge_iteration_memberships(), then rebuilt in compute_coactivation_matrix()
  • After: CSR built once, reused for coactivation computation
  • Benefit: Reduces memory allocation and computation overhead

🔬 Edge Cases & Validation

Empty groups:

  • memberships_to_sample_component_matrix() correctly handles n_groups == 0 case
  • Returns proper empty matrices for both CSR and CSC formats

Shape consistency:

  • Assertions maintain sample space consistency: all(membership.n_samples == n_samples for membership in memberships)
  • Matrix shape validation throughout the pipeline

Type safety:

  • Added isinstance(matrix, sparse.csr_matrix) assertion in wrapper function
  • Proper literal types for format parameter

Final Assessment

Recommendation: ✅ Approve - High Confidence

This commit represents excellent refactoring that:

  • ✅ Maintains mathematical correctness while eliminating code duplication
  • ✅ Simplifies the API by removing unnecessary conditional branches
  • ✅ Improves performance by eliminating redundant CSR construction
  • ✅ Enhances code clarity with better separation of concerns
  • ✅ Properly handles all edge cases and maintains type safety

The mandatory CSR requirement is a good design decision - it forces use of the optimized path and eliminates potential performance pitfalls from the slower fallback.

