Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
e17f510
clustering (#43) squash
mivanit Oct 10, 2025
26c2957
timing cluster_ss.py
mivanit Oct 10, 2025
55be0d8
Revert "timing cluster_ss.py"
mivanit Oct 10, 2025
00db8dd
[clustering] `cluster_ss.py` speedup ci (#199)
mivanit Oct 10, 2025
13db0be
add num_nonsingleton_groups stat from PR170
mivanit Oct 10, 2025
8e9ddd0
Merge branch 'main' into clustering/main
mivanit Oct 13, 2025
e79ccb8
Merge branch 'main' into clustering/main
mivanit Oct 13, 2025
0c74b5d
switch BaseModel to BaseConfig, get rid of old save/read logic (#209)
mivanit Oct 13, 2025
3f07f42
switch to new run
mivanit Oct 13, 2025
37175cb
Merge branch 'main' into clustering/main
mivanit Oct 14, 2025
ed667c8
Merge branch 'main' into clustering/main
mivanit Oct 14, 2025
0cdd754
Merge branch 'main' into clustering/main
mivanit Oct 16, 2025
c06cffe
wip sigmoid issues
mivanit Oct 16, 2025
fddb323
that worked...?
mivanit Oct 16, 2025
8a56e12
dont assert positive coacts?
mivanit Oct 16, 2025
1f18690
Merge branch 'main' into clustering/main
mivanit Oct 16, 2025
d18470b
get rid of long-running merge pair sampler on GPU test
mivanit Oct 16, 2025
35b7423
Merge branch 'main' into clustering/main
mivanit Oct 20, 2025
83d4288
wip CI decision trees w/ random data
mivanit Oct 20, 2025
e2e8c5c
wip
mivanit Oct 20, 2025
17d3a13
wip
mivanit Oct 20, 2025
8dfea5e
[clustering] Refactor to two-stage process (#203)
danbraunai-goodfire Oct 20, 2025
8f6462c
split up ci_dt
mivanit Oct 20, 2025
e35a7b1
wip
mivanit Oct 20, 2025
869ce94
Merge branch 'clustering/main' into clustering/ci-decision-trees
mivanit Oct 20, 2025
c6cadd4
wip
mivanit Oct 20, 2025
2b74497
wip
mivanit Oct 20, 2025
73a4835
wip
mivanit Oct 20, 2025
d80ba3f
[clustering] distance computation (#213)
mivanit Oct 21, 2025
06ab8cc
Merge branch 'main' into clustering/main
mivanit Oct 21, 2025
3a206ed
allow specifying either config path or mrc cfg in pipeline cfg
mivanit Oct 21, 2025
eb831c0
[wip] reorg configs
mivanit Oct 21, 2025
89e5c36
added default `None` for slurm partition and job name prefix
mivanit Oct 21, 2025
8910bb4
refactor configs, add config tests
mivanit Oct 21, 2025
0b957f5
fix tests
mivanit Oct 21, 2025
7de545b
allow `None` or `-1` idx_in_ensemble
mivanit Oct 21, 2025
3d45ac4
whoops, wrong name on fixture
mivanit Oct 21, 2025
4adde10
fix idx passed in tests when not needed
mivanit Oct 21, 2025
189b64a
rename "mrc" -> "crc" in paths
mivanit Oct 21, 2025
57f445a
rename merge_run_config.py -> clustering_run_config.py
mivanit Oct 21, 2025
91f5348
fix pyright
mivanit Oct 21, 2025
11e5501
fix idx_in_ensemble being passed in tests
mivanit Oct 21, 2025
1d96054
rename cache dir 'merge_run_configs' -> 'clustering_run_configs'
mivanit Oct 21, 2025
a1f1146
remove component popping
mivanit Oct 21, 2025
1e3fbb2
dont pass batch size, change not brought in here
mivanit Oct 21, 2025
866e28c
wip
mivanit Oct 21, 2025
12e54e0
add some js from feature/clustering-dashboard branch
mivanit Oct 21, 2025
b2aa6bc
claude plan
mivanit Oct 21, 2025
2837492
junk??
mivanit Oct 21, 2025
e86adc9
fix history_path extension and storage usage
mivanit Oct 22, 2025
b8bbb08
dev pipeline
mivanit Oct 22, 2025
733d47f
better config validation tests
mivanit Oct 22, 2025
2fa1f21
set default base output dir
mivanit Oct 22, 2025
eec80fb
wandb use run id for clustering, TODO for spd decomp
mivanit Oct 22, 2025
6098536
basedpyright 1.32.0 causes issues, esp w/ wandb
mivanit Oct 22, 2025
a24361a
Merge branch 'main' into clustering/main
mivanit Oct 23, 2025
bd60f2c
Merge branch 'clustering/main' into clustering/add-mrc-inline-in-pipe…
mivanit Oct 23, 2025
af59b36
Merge branch 'clustering/add-mrc-inline-in-pipeline-cfg' into cluster…
mivanit Oct 23, 2025
7bac507
wip
mivanit Oct 23, 2025
7d43ac1
big plan
mivanit Oct 23, 2025
12cac51
wip
mivanit Oct 23, 2025
0661120
wip
mivanit Oct 23, 2025
f75fd9f
wip
mivanit Oct 23, 2025
23ba055
wip
mivanit Oct 23, 2025
dc15ee4
wip
mivanit Oct 23, 2025
9325b1b
format
mivanit Oct 23, 2025
d860bfa
wip
mivanit Oct 24, 2025
55c9206
fix bug where we were avg across seq
mivanit Oct 24, 2025
2db397e
wip
mivanit Oct 24, 2025
79c9f44
wip
mivanit Oct 24, 2025
6c70327
Merge branch 'main' into clustering/main
mivanit Oct 24, 2025
40df505
remove idx_in_ensemble, always auto-assigned now
mivanit Oct 24, 2025
bd8a442
Merge branch 'clustering/main' into clustering/add-mrc-inline-in-pipe…
mivanit Oct 24, 2025
cf64a79
only allow passing clustering run config path, not inline
mivanit Oct 24, 2025
2a9f731
rename run_clustering_config_path -> clustering_run_config_path
mivanit Oct 24, 2025
26c6520
[clustering] config refactor (#227)
mivanit Oct 24, 2025
1f0725c
deps
mivanit Oct 24, 2025
d8cc0e7
Merge branch 'main' into clustering/main
mivanit Oct 27, 2025
60790a4
Merge branch 'clustering/add-mrc-inline-in-pipeline-cfg' into cluster…
mivanit Oct 27, 2025
e721c54
Merge branch 'clustering/main' into clustering/ci-decision-trees
mivanit Oct 27, 2025
39dda42
format
mivanit Oct 27, 2025
5e764da
some type fixes
mivanit Oct 27, 2025
4b52539
delete old script
mivanit Oct 27, 2025
bee64e3
more type fixes
mivanit Oct 27, 2025
fd27dbc
wip
mivanit Oct 27, 2025
010fbaa
wip type stuff
mivanit Oct 27, 2025
2f56924
type fixes
mivanit Oct 27, 2025
964f42e
print device
mivanit Oct 27, 2025
b6620ba
wip serialization of trees
mivanit Oct 27, 2025
cd2df15
wip
mivanit Oct 27, 2025
59b299a
make format
mivanit Oct 27, 2025
315e953
Merge branch 'main' into clustering/main
mivanit Oct 29, 2025
d67b1d1
Merge branch 'clustering/main' into clustering/ci-decision-trees
mivanit Oct 29, 2025
06deb8c
uv lock
mivanit Oct 29, 2025
60f54d2
Merge branch 'clustering/main' into clustering/ci-decision-trees
mivanit Oct 29, 2025
8481695
format and type fixes
mivanit Oct 29, 2025
b7b6f4b
more type fixes
mivanit Oct 29, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
spd/scripts/sweep_params.yaml
spd/scripts/sweep_params.yaml
docs/coverage/**
artifacts/**
docs/dep_graph/**
tests/.temp/**

**/out/
neuronpedia_outputs/
Expand Down
32 changes: 32 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,38 @@
"--model_path",
"wandb:goodfire/spd/runs/ioprgffh"
]
},
{
"name": "run_clustering example",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/spd/clustering/scripts/run_clustering.py",
"args": [
"--config",
"${workspaceFolder}/spd/clustering/configs/example.yaml",
],
"python": "${command:python.interpreterPath}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
},
{
"name": "clustering pipeline",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/spd/clustering/scripts/run_pipeline.py",
"args": [
"--config",
"${workspaceFolder}/spd/clustering/configs/pipeline_config.yaml",
],
"python": "${command:python.interpreterPath}",
"console": "integratedTerminal",
"justMyCode": true,
"env": {
"PYDEVD_DISABLE_FILE_VALIDATION": "1"
}
}
]
}
15 changes: 14 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,23 @@ coverage:
uv run python -m coverage report -m > $(COVERAGE_DIR)/coverage.txt
uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/


.PHONY: clean
clean:
@echo "Cleaning Python cache and build artifacts..."
find . -type d -name "__pycache__" -exec rm -rf {} +
find . -type d -name "*.egg-info" -exec rm -rf {} +
rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage


.PHONY: clustering-dev
clustering-dev:
uv run spd-cluster --local --config spd/clustering/configs/pipeline-dev-simplestories.yaml

.PHONY: app
app:
@uv run python app/run_app.py

.PHONY: install-app
install-app:
(cd app/frontend && npm install)
(cd app/frontend && npm install)
73 changes: 73 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# TODO: Cluster Coactivation Matrix Implementation

## What Was Changed

### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`)
- New dataclass to hold vectorized cluster activations for all clusters
- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list

### 2. Added `compute_all_cluster_activations()` function
- Vectorized computation of all cluster activations at once
- Replaces the per-cluster loop for better performance
- Returns `ClusterActivations` object

### 3. Added `compute_cluster_coactivations()` function
- Computes coactivation matrix from list of `ClusterActivations` across batches
- Binarizes activations (`acts > 0`) and computes the coactivation counts as the matrix product `activation_mask.T @ activation_mask`
- Follows the pattern from `spd/clustering/merge.py:69`
- Returns tuple of (coactivation_matrix, cluster_indices)

### 4. Modified `compute_max_activations()` function
- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list
- Calls `compute_cluster_coactivations()` to compute the matrix
- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]`
- Added coactivation matrix and cluster_indices to return value

### 5. Modified `spd/clustering/dashboard/run.py`
- Updated to handle new return value from `compute_max_activations()`
- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory
- NPZ file contains:
- `coactivations`: the [n_clusters, n_clusters] matrix
- `cluster_indices`: array mapping matrix positions to cluster IDs

## What Needs to be Checked

### Testing
- [ ] **Run the dashboard pipeline** on a real clustering run to verify:
- Coactivation computation doesn't crash
- Coactivations are saved correctly to NPZ file
- Matrix dimensions are correct
- `cluster_indices` mapping is correct

### Type Checking
- [ ] Run `make type` to ensure no type errors were introduced
- [ ] Verify jaxtyping annotations are correct

### Verification
- [ ] Load a saved `coactivations.npz` file and verify:
```python
data = np.load("coactivations.npz")
coact = data["coactivations"]
cluster_indices = data["cluster_indices"]
# Check: coact should be symmetric
# Check: each diagonal entry coact[i, i] should be >= every other entry in row/column i (a cluster is active at least as often as it is co-active with any other cluster)
# Check: cluster_indices length should match coact.shape[0]
```

### Performance
- [ ] Check if vectorization actually improved performance
- [ ] Monitor memory usage with large numbers of clusters

### Edge Cases
- [ ] Test with clusters that have zero activations
- [ ] Test with single-batch runs
- [ ] Test with very large number of clusters

### Integration
- [ ] Verify the coactivation matrix can be used in downstream analysis
- [ ] Consider if visualization of coactivations should be added to dashboard

## Notes
- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples)
- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable
- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ dependencies = [
# see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 )
"datasets>=2.21.0",
"simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev",
"scipy>=1.14.1",
"muutils",
"scikit-learn",
"fastapi",
"uvicorn",
]
Expand All @@ -40,10 +43,12 @@ dev = [
"ruff",
"basedpyright<1.32.0", # pyright and wandb issues, see https://github.com/goodfire-ai/spd/pull/232
"pre-commit",
"nbconvert",
]

[project.scripts]
spd-run = "spd.scripts.run:cli"
spd-cluster = "spd.clustering.scripts.run_pipeline:cli"

[build-system]
requires = ["setuptools", "wheel"]
Expand Down
23 changes: 20 additions & 3 deletions spd/base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@
from pydantic import BaseModel, ConfigDict


class FileTypeError(ValueError):
    """Error raised when a file has an unsupported type/extension.

    Raised by ``BaseConfig.from_file`` and ``BaseConfig.to_file`` when the given
    path is not a ``.json``, ``.yaml`` or ``.yml`` file. It subclasses
    ``ValueError`` so callers that catch ``ValueError`` still catch it.
    """


class ConfigValidationError(ValueError):
    """Error raised when a config file fails pydantic validation.

    Raised by ``BaseConfig.from_file`` when ``cls.model_validate`` raises on the
    data loaded from the file. The message names the config class, the file path
    and the loaded data; the original exception is chained as ``__cause__``.
    """


class BaseConfig(BaseModel):
"""Pydantic BaseModel suited for configs.

Expand All @@ -15,6 +23,8 @@ class BaseConfig(BaseModel):

model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True)

# TODO: add a "config_type" field, which is set to the class name, so that when loading a config we can check whether the config type matches the expected class

@classmethod
def from_file(cls, path: Path | str) -> Self:
"""Load config from path to a JSON or YAML file."""
Expand All @@ -27,9 +37,16 @@ def from_file(cls, path: Path | str) -> Self:
case Path() if path.suffix in [".yaml", ".yml"]:
data = yaml.safe_load(path.read_text())
case _:
raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}")
raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}")

try:
cfg = cls.model_validate(data)
except Exception as e:
raise ConfigValidationError(
f"Error validating config {cls=} from path `{path.as_posix()}`\n{data = }"
) from e

return cls.model_validate(data)
return cfg

def to_file(self, path: Path | str) -> None:
"""Save config to file (format inferred from extension)."""
Expand All @@ -43,4 +60,4 @@ def to_file(self, path: Path | str) -> None:
case ".yaml" | ".yml":
path.write_text(yaml.dump(self.model_dump(mode="json")))
case _:
raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}")
raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}")
Empty file added spd/clustering/__init__.py
Empty file.
Loading