From e17f51079c27a9d680d48c90971dbae1017eff7d Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 10 Oct 2025 05:09:46 -0700 Subject: [PATCH 01/77] clustering (#43) squash MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Squashing #43 to a clean branch for clustering work https://github.com/goodfire-ai/spd/pull/43#issuecomment-3389173731 # Commits * refactor to use MergeRunConfig everywhere * Wip * wip * wip * some info to yaml * todos * LFGgsgsgs * wip * refactor configs. new system with ability to choose sampling mechanism. code will be pushed later * refactor makefile to use new configs * add tests. code still not comitted yet * format * fix tests * test configs * TESTS PASSIN LFGGGG * format * pyright fixes * fix some pyright issues * parallelizing tests * distributed tests in CI * fix action * try to make tests faster * remove experiment with no canonical run * try to debug issue with normalizing ensemble works on my machine :( * [important] remove old files * move the merge pair samplers code to math folder * fix import in tests * default to cpu if no cuda in spd-cluster * default to cpu if no cuda in spd-cluster * wip wandb logging for spd-cluster refactor * more wip wandb logging for spd-cluster refactor * format? * wandb log tensor info * wandb log tensor info wip * wandb log tensor info wip * wandb log tensor info wip * some figs on wandb * format * wip * [temp] ignore config rename/deprecated warns * wip * wip * wip * wip * wip * wip * wip * wip * wip * wip * [hack] lack of canonical runs causing error * wip * wip, cleaning up s2 * wip * nathan approved MDL cost computation * swap to log2 in mdl * log MDL cost to wandb, other minor fixes * wip * wip * wip * some logging stuff fixed. now tryna figure out why way too many merge hists are loaded * [important] got run loading to work right!!! * wip * format * [important] FEATURE COMPLETE?? * format * minor changes to wandb stuff * fix tests? 
* format * fix pyright errors, some warnings remain * pyright passing!!! * removed old ignores in pyproject.toml, fixed the resulting errors * ? * fix * fix registry * numprocess to auto for pytest with xdist * move the TaskName definition * remove pyright ignore * move clustering stuff into clustering folders * format * fix type error * no more dep graph * remove outdated TODO.md * remove old ruff exclude * remove unused function * remove `from __future__ import annotations` * factor out component filtering, add tests for it * tensor stats randomized for big matrix * wip * minor fix related to passing filtered labels * histograms as wandb Images instead of plotly * wip * factor out some logging stuff, lots of minor changes * wandb log semilog of merge pair cost * oops, fix semilog * log hists for stuff * close figures to avoid memory leaks, add some logging to help profile * wip, added some logging * logging stuff * pyright fixes * make the merge history more lightweight - get rid of stats dicts, we store everything in wandb anyway - get rid of old sweep stuff, we keep it in config anyway - some code used only in experiment notebooks is broken * wip * format * intervals dict * fix issue with zip file closed * wip * remove merge profiling code * get rid of some old comments * more explanation around popping https://github.com/goodfire-ai/spd/pull/43#discussion_r2280438405 * components_in_pop_grp -> n_components_in_pop_grp https://github.com/goodfire-ai/spd/pull/43#discussion_r2280439360 * remove ignore deprecated config warnings todo see https://github.com/goodfire-ai/spd/pull/43#discussion_r2278716045 this will probably cause a lot of warnings to be logged unless canonical runs updated? 
* some typing fixes * fix pyright issues by pinning `transformers` package see https://github.com/goodfire-ai/spd/issues/139 * fix annoying mock test issue * fix pyright issue * remove call to deprecated plotting function * remove dead test * give sigmoid type * update default ss decomp run * wip * give sigmoid type * wip * fix interface, store ComponentModel.module_paths maybe it can be a property that gets the keys? and asserts that keys in components and gates match? * patched model -> target model when using config * fix: no more patched_model, use target_model for config * comments * fix cuda_memory_used.py * removed `spd/clustering/math/dev.py` https://github.com/goodfire-ai/spd/pull/43#discussion_r2382460598 * `StatsKeys` -> `StatsKey` https://github.com/goodfire-ai/spd/pull/43#discussion_r2382467106 * remove `plot_merge_history_costs` https://github.com/goodfire-ai/spd/pull/43#discussion_r2382470818 * remove old js-embedding-vis dep * Oli clustering refactor (#172) # PR #172 Commit List https://github.com/goodfire-ai/spd/pull/172 ## Initial Work by oli-clive-griffin - [7c488fc](https://github.com/goodfire-ai/spd/commit/7c488fc) - wip - [b506574](https://github.com/goodfire-ai/spd/commit/b506574) - wip - [56f6b8a](https://github.com/goodfire-ai/spd/commit/56f6b8a) - remove srt - [8682bec](https://github.com/goodfire-ai/spd/commit/8682bec) - wip - [e3645c5](https://github.com/goodfire-ai/spd/commit/e3645c5) - wip - [b246b31](https://github.com/goodfire-ai/spd/commit/b246b31) - wip - [b44abe8](https://github.com/goodfire-ai/spd/commit/b44abe8) - it runs! 
## Merge and Formatting - [f83ed30](https://github.com/goodfire-ai/spd/commit/f83ed30) - Merge branch 'feature/clustering' into feature/clustering-refactor-v1 - [7700e6e](https://github.com/goodfire-ai/spd/commit/7700e6e) - format ## Continued Refactoring by oli-clive-griffin - [3ea89ae](https://github.com/goodfire-ai/spd/commit/3ea89ae) - wip - [bc37ae3](https://github.com/goodfire-ai/spd/commit/bc37ae3) - wip - [e30d680](https://github.com/goodfire-ai/spd/commit/e30d680) - wip - [d16cb96](https://github.com/goodfire-ai/spd/commit/d16cb96) - wip - [231af5a](https://github.com/goodfire-ai/spd/commit/231af5a) - wip - [988a9c8](https://github.com/goodfire-ai/spd/commit/988a9c8) - wip - [7665879](https://github.com/goodfire-ai/spd/commit/7665879) - wip - [b01130e](https://github.com/goodfire-ai/spd/commit/b01130e) - wip - [067573a](https://github.com/goodfire-ai/spd/commit/067573a) - wip ## Integration and Fixes by mivanit - [1311a54](https://github.com/goodfire-ai/spd/commit/1311a54) - Merge branch 'feature/oli-cluster' into feature/clustering-refactor-v1 - [28b6e01](https://github.com/goodfire-ai/spd/commit/28b6e01) - format fixes - [b0304d9](https://github.com/goodfire-ai/spd/commit/b0304d9) - remove old s2 script step - [ac10010](https://github.com/goodfire-ai/spd/commit/ac10010) - wip (oli-clive-griffin) - [fce4007](https://github.com/goodfire-ai/spd/commit/fce4007) - pyright passing - [238fa31](https://github.com/goodfire-ai/spd/commit/238fa31) - fix pih tests - [94f3f92](https://github.com/goodfire-ai/spd/commit/94f3f92) - better path handling - [2b4feee](https://github.com/goodfire-ai/spd/commit/2b4feee) - reorg of pipeline - [901cebd](https://github.com/goodfire-ai/spd/commit/901cebd) - fixing tests - [60ed9d9](https://github.com/goodfire-ai/spd/commit/60ed9d9) - wip (oli-clive-griffin) - [ce01a15](https://github.com/goodfire-ai/spd/commit/ce01a15) - wip ## Storage Implementation - [a04abba](https://github.com/goodfire-ai/spd/commit/a04abba) - Merge branch 
'demo/cluster-storage' into feature/clustering-refactor-v1 - [b829f3e](https://github.com/goodfire-ai/spd/commit/b829f3e) - only need storage stuff - [cc202ce](https://github.com/goodfire-ai/spd/commit/cc202ce) - wip - [f2ecda1](https://github.com/goodfire-ai/spd/commit/f2ecda1) - wip storage - [eb7a30d](https://github.com/goodfire-ai/spd/commit/eb7a30d) - move storage.py to pipeline - [b4ae5de](https://github.com/goodfire-ai/spd/commit/b4ae5de) - docstring with tree - [8970478](https://github.com/goodfire-ai/spd/commit/8970478) - make some stuff private in storage - [55ba126](https://github.com/goodfire-ai/spd/commit/55ba126) - use ClusteringStorage everywhere - [c85e357](https://github.com/goodfire-ai/spd/commit/c85e357) - Remove unused old path logic - [7e66a3a](https://github.com/goodfire-ai/spd/commit/7e66a3a) - add back model_dump_with_properties - [62dfcdf](https://github.com/goodfire-ai/spd/commit/62dfcdf) - fix tests - [a0a3b2e](https://github.com/goodfire-ai/spd/commit/a0a3b2e) - Revert config change - [1f95caf](https://github.com/goodfire-ai/spd/commit/1f95caf) - add storage tests - [d860f01](https://github.com/goodfire-ai/spd/commit/d860f01) - simplify imports ## Final Cleanup and Testing - [0b94ef1](https://github.com/goodfire-ai/spd/commit/0b94ef1) - format - [e850e41](https://github.com/goodfire-ai/spd/commit/e850e41) - format and type check fixes - [7e83333](https://github.com/goodfire-ai/spd/commit/7e83333) - re-add notebooks as tests - [1a1e521](https://github.com/goodfire-ai/spd/commit/1a1e521) - fix configs - [338b761](https://github.com/goodfire-ai/spd/commit/338b761) - allow toml config files - [e809810](https://github.com/goodfire-ai/spd/commit/e809810) - wip - [6633e69](https://github.com/goodfire-ai/spd/commit/6633e69) - wip - [987e1f0](https://github.com/goodfire-ai/spd/commit/987e1f0) - better logging - [0c9e52d](https://github.com/goodfire-ai/spd/commit/0c9e52d) - wip - [18b424d](https://github.com/goodfire-ai/spd/commit/18b424d) - no 
default for n samples - [9f7321d](https://github.com/goodfire-ai/spd/commit/9f7321d) - add a todo ## Final Refinements - [dc86bae](https://github.com/goodfire-ai/spd/commit/dc86bae) - add back BatchedGroupMerge util methods - [079b3f2](https://github.com/goodfire-ai/spd/commit/079b3f2) - add link in comment - [9f75588](https://github.com/goodfire-ai/spd/commit/9f75588) - inline alive_mask as ~dead_components - [73a3378](https://github.com/goodfire-ai/spd/commit/73a3378) - wip - [46d4498](https://github.com/goodfire-ai/spd/commit/46d4498) - from_file() -> read() on RunConfig - [227f6ba](https://github.com/goodfire-ai/spd/commit/227f6ba) - rename: RunConfig -> ClusteringRunConfig - [b6869c3](https://github.com/goodfire-ai/spd/commit/b6869c3) - add back comments - [22acf57](https://github.com/goodfire-ai/spd/commit/22acf57) - more comments in merge.py - [e83bc6a](https://github.com/goodfire-ai/spd/commit/e83bc6a) - comment edit - [67af4a6](https://github.com/goodfire-ai/spd/commit/67af4a6) - wip - [f9c702e](https://github.com/goodfire-ai/spd/commit/f9c702e) - fix clustering failing by reverting to shelling out inst. of pool - [7f1516a](https://github.com/goodfire-ai/spd/commit/7f1516a) - Revert "fix clustering failing by reverting to shelling out inst. of … - [0db0379](https://github.com/goodfire-ai/spd/commit/0db0379) - revert to shelling out, fix logging - [e935a78](https://github.com/goodfire-ai/spd/commit/e935a78) - delete old code - [69f0bc8](https://github.com/goodfire-ai/spd/commit/69f0bc8) - [!!!] tests passing # commit list * wip * wip * remove srt * wip * wip * wip * it runs! 
* format * wip * wip * wip * wip * wip * wip * wip * wip * wip * format fixes * remove old s2 script step * wip * pyright passing * fix pih tests * better path handling * reorg of pipeline * fixing tests * wip * wip * only need storage stuff * wip * wip storage * move storage.py to pipeline * docstring with tree * make some stuff private in storage * use ClusteringStorage everywhere * Remove unused old path logic * add back model_dump_with_properties * fix tests * Revert config change * add storage tests * simplify imports * format * format and type check fixes * re-add notebooks as tests * fix configs * allow toml config files * wip * wip * better logging * wip * no default for n samples * add a todo * add back BatchedGroupMerge util methods * add link in comment * inline `alive_mask` as `~dead_components` pretty sure claude changed this for no reason at some point * wip * from_file() -> read() on RunConfig * rename: RunConfig -> ClusteringRunConfig * add back comments * more comments in merge.py * comment edit * wip * fix clustering failing by reverting to shelling out inst. of pool - originally we would shell out for the different clustering processes, and use a separate fd passing json - then we switched to multiprocessing.pool - this would cause model loading to fail -- not sure why - this reverts it to the old style there are still some old files laying around, theyll be removed * Revert "fix clustering failing by reverting to shelling out inst. of pool" This reverts commit f9c702eca2494ddbb13eb5f197f8cae5cec09926. reverting temporarily * revert to shelling out, fix logging * delete old code * [!!!] tests passing --------- Co-authored-by: Oliver Clive-Griffin * fixes/improvements to dist_utils * type fixes? 
* wip * Sync everything from feature/clustering-dashboard except spd/clustering/dashboard/ * minimizing diff * minimize pyprojec.toml diff * minimizing diff, removed deps * uv sync * fixing state dict mappings * test parallelization * parallelize tests * device getting utils * add TaskName type * uv sync (pytest-xdist dep) * remove old junk from Makefile * globally unique ports in tests to allow parallelization * comments explaining port allocation in tests * add distributed marker, rull all distributed tests on same worker * Revert "add distributed marker, rull all distributed tests on same worker" This reverts commit 3f55ffa932dfa2c4d0e0c2bcb69ea6615528c049. * add distributed marker, rull all distributed tests on same worker * refactor: use general_utils methods for getting device everywhere * wip jaccard * wip jaccard * wip jaccard * wip jaccard (plotting) * found where to increase timeout * wip jaccard * make format * fixes * typing fixes * claude doing a bunch of type hinting * trying to get pyright passing? this doesnt cause any issues locally :/ * pyright works both locally and in CI * allow installing cpu-only torch in CI * figure out CI disk usage by tests on main * alternate strategy for install basically: - torch and torchvision in a `pytorch` dep group - `pytorch` dep group, along with `dev`, is installed by default (uv sync behavior unchanged) - `make install-ci` recipe which first manually installs torch stuff, then uv sync but NO `pytorch` dep group * fixes to the last commit * cleanup temp changes * make in CI * wip * wip * wip * uv sync * try to fix markupsafe? * pin markup safe with explanation * update lockfile?? * nope i think we need the index strategy * ? * markupsafe issue * remove disk usage printing * fix pyright issue * dependency hell * fix deps??? * oops, missing index strategy. moved to makefile * re-lock * make from /usr/bin/ ? 
* dependency hell * type checking hell * Update spd/utils/general_utils.py Co-authored-by: Dan Braun * wrap and fix Conv1D imports * minimize diff cleanup * try compile-bytecode for ci install * dont compile bytecode actually compare: - no compile: https://github.com/goodfire-ai/spd/actions/runs/18316145235/job/52156789078 - yes compile: https://github.com/goodfire-ai/spd/actions/runs/18316319784/job/52157420005 approximately no speedup in tests or basedpyright run, but extra 10s added to install step (where the compilation happens) * remove markupsafe constraint? * switched to use get_obj_device * remove device: torch.device type hints see https://github.com/goodfire-ai/spd/pull/186#discussion_r2410092522 * remove "distributed" test marker * fix another timeout * replace get_module_device -> get_obj_device * better comments on port uniqueness * remove old markers port uniqueness should resolve the issue, without causing slowdowns * remove timeout TODO comments * removed checks.yaml timeout todo, clustering tests pass in ~12min * [diff-min] transformers version issue from #139 resolved * fix comment * wip jaccard * pyright fixes to jaccard, wip * Update docs about grad syncing with DDP * Mention feature/memorization-experiments in README * Fix train and eval metrics and hidden_act_recon (#189) * Update canonical runs and change target model path (#197) * Avoid using too many processes in tests * fix wandb model paths to older runs `goodfire/spd/runs/{id}` -> `goodfire/spd-pre-Sep-2025/runs/{id}` see https://github.com/goodfire-ai/spd/commit/5ba1d24b554ca776d94f55e4631af3ffa07e841e --------- Co-authored-by: Oliver Clive-Griffin Co-authored-by: Dan Braun --- .gitignore | 4 +- TODO.md | 73 +++ pyproject.toml | 3 + spd/clustering/__init__.py | 0 spd/clustering/activations.py | 269 ++++++++++ spd/clustering/compute_costs.py | 297 +++++++++++ spd/clustering/configs/example.toml | 37 ++ spd/clustering/configs/example.yaml | 35 ++ spd/clustering/configs/resid_mlp1.json | 
24 + spd/clustering/configs/resid_mlp2.json | 23 + spd/clustering/configs/resid_mlp3.json | 23 + spd/clustering/configs/simplestories_dev.json | 25 + spd/clustering/configs/test-resid_mlp1.json | 24 + .../configs/test-simplestories.json | 24 + spd/clustering/consts.py | 48 ++ spd/clustering/math/__init__.py | 0 spd/clustering/math/jaccard.py | 71 +++ spd/clustering/math/jaccard_test.py | 194 ++++++++ spd/clustering/math/merge_distances.py | 53 ++ spd/clustering/math/merge_matrix.py | 283 +++++++++++ spd/clustering/math/merge_pair_samplers.py | 121 +++++ spd/clustering/math/perm_invariant_hamming.py | 70 +++ spd/clustering/math/semilog.py | 13 + spd/clustering/math/tensor_stats.py | 160 ++++++ spd/clustering/merge.py | 246 ++++++++++ spd/clustering/merge_config.py | 119 +++++ spd/clustering/merge_history.py | 463 ++++++++++++++++++ spd/clustering/merge_run_config.py | 281 +++++++++++ spd/clustering/pipeline/__init__.py | 0 .../pipeline/clustering_pipeline.py | 106 ++++ spd/clustering/pipeline/dist_utils.py | 313 ++++++++++++ spd/clustering/pipeline/s1_split_dataset.py | 151 ++++++ spd/clustering/pipeline/s2_clustering.py | 409 ++++++++++++++++ .../pipeline/s3_normalize_histories.py | 32 ++ .../pipeline/s4_compute_distances.py | 92 ++++ spd/clustering/pipeline/storage.py | 300 ++++++++++++ spd/clustering/plotting/__init__.py | 1 + spd/clustering/plotting/activations.py | 379 ++++++++++++++ spd/clustering/plotting/merge.py | 327 +++++++++++++ spd/clustering/scripts/main.py | 81 +++ spd/clustering/util.py | 18 + spd/clustering/wandb_tensor_info.py | 169 +++++++ spd/models/component_model.py | 1 + .../math/test_perm_invariant_hamming.py | 123 +++++ tests/clustering/scripts/cluster_resid_mlp.py | 196 ++++++++ tests/clustering/scripts/cluster_ss.py | 129 +++++ .../clustering/test_clustering_experiments.py | 99 ++++ .../clustering/test_filter_dead_components.py | 131 +++++ tests/clustering/test_merge_config.py | 181 +++++++ tests/clustering/test_merge_integration.py | 201 
++++++++ tests/clustering/test_merge_pair_samplers.py | 274 +++++++++++ tests/clustering/test_storage.py | 351 +++++++++++++ tests/clustering/test_wandb_integration.py | 153 ++++++ uv.lock | 106 +++- 54 files changed, 7279 insertions(+), 27 deletions(-) create mode 100644 TODO.md create mode 100644 spd/clustering/__init__.py create mode 100644 spd/clustering/activations.py create mode 100644 spd/clustering/compute_costs.py create mode 100644 spd/clustering/configs/example.toml create mode 100644 spd/clustering/configs/example.yaml create mode 100644 spd/clustering/configs/resid_mlp1.json create mode 100644 spd/clustering/configs/resid_mlp2.json create mode 100644 spd/clustering/configs/resid_mlp3.json create mode 100644 spd/clustering/configs/simplestories_dev.json create mode 100644 spd/clustering/configs/test-resid_mlp1.json create mode 100644 spd/clustering/configs/test-simplestories.json create mode 100644 spd/clustering/consts.py create mode 100644 spd/clustering/math/__init__.py create mode 100644 spd/clustering/math/jaccard.py create mode 100644 spd/clustering/math/jaccard_test.py create mode 100644 spd/clustering/math/merge_distances.py create mode 100644 spd/clustering/math/merge_matrix.py create mode 100644 spd/clustering/math/merge_pair_samplers.py create mode 100644 spd/clustering/math/perm_invariant_hamming.py create mode 100644 spd/clustering/math/semilog.py create mode 100644 spd/clustering/math/tensor_stats.py create mode 100644 spd/clustering/merge.py create mode 100644 spd/clustering/merge_config.py create mode 100644 spd/clustering/merge_history.py create mode 100644 spd/clustering/merge_run_config.py create mode 100644 spd/clustering/pipeline/__init__.py create mode 100644 spd/clustering/pipeline/clustering_pipeline.py create mode 100644 spd/clustering/pipeline/dist_utils.py create mode 100644 spd/clustering/pipeline/s1_split_dataset.py create mode 100644 spd/clustering/pipeline/s2_clustering.py create mode 100644 
spd/clustering/pipeline/s3_normalize_histories.py create mode 100644 spd/clustering/pipeline/s4_compute_distances.py create mode 100644 spd/clustering/pipeline/storage.py create mode 100644 spd/clustering/plotting/__init__.py create mode 100644 spd/clustering/plotting/activations.py create mode 100644 spd/clustering/plotting/merge.py create mode 100644 spd/clustering/scripts/main.py create mode 100644 spd/clustering/util.py create mode 100644 spd/clustering/wandb_tensor_info.py create mode 100644 tests/clustering/math/test_perm_invariant_hamming.py create mode 100644 tests/clustering/scripts/cluster_resid_mlp.py create mode 100644 tests/clustering/scripts/cluster_ss.py create mode 100644 tests/clustering/test_clustering_experiments.py create mode 100644 tests/clustering/test_filter_dead_components.py create mode 100644 tests/clustering/test_merge_config.py create mode 100644 tests/clustering/test_merge_integration.py create mode 100644 tests/clustering/test_merge_pair_samplers.py create mode 100644 tests/clustering/test_storage.py create mode 100644 tests/clustering/test_wandb_integration.py diff --git a/.gitignore b/.gitignore index c8e17fa45..bda91d8fa 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ spd/scripts/sweep_params.yaml -spd/scripts/sweep_params.yaml docs/coverage/** +artifacts/** +docs/dep_graph/** +tests/.temp/** **/out/ neuronpedia_outputs/ diff --git a/TODO.md b/TODO.md new file mode 100644 index 000000000..9e6f14815 --- /dev/null +++ b/TODO.md @@ -0,0 +1,73 @@ +# TODO: Cluster Coactivation Matrix Implementation + +## What Was Changed + +### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`) +- New dataclass to hold vectorized cluster activations for all clusters +- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list + +### 2. 
Added `compute_all_cluster_activations()` function +- Vectorized computation of all cluster activations at once +- Replaces the per-cluster loop for better performance +- Returns `ClusterActivations` object + +### 3. Added `compute_cluster_coactivations()` function +- Computes coactivation matrix from list of `ClusterActivations` across batches +- Binarizes activations (acts > 0) and computes matrix multiplication: `activation_mask.T @ activation_mask` +- Follows the pattern from `spd/clustering/merge.py:69` +- Returns tuple of (coactivation_matrix, cluster_indices) + +### 4. Modified `compute_max_activations()` function +- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list +- Calls `compute_cluster_coactivations()` to compute the matrix +- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]` + - Added coactivation matrix and cluster_indices to return value + +### 5. Modified `spd/clustering/dashboard/run.py` +- Updated to handle new return value from `compute_max_activations()` +- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory +- NPZ file contains: + - `coactivations`: the [n_clusters, n_clusters] matrix + - `cluster_indices`: array mapping matrix positions to cluster IDs + +## What Needs to be Checked + +### Testing +- [ ] **Run the dashboard pipeline** on a real clustering run to verify: + - Coactivation computation doesn't crash + - Coactivations are saved correctly to NPZ file + - Matrix dimensions are correct + - `cluster_indices` mapping is correct + +### Type Checking +- [ ] Run `make type` to ensure no type errors were introduced +- [ ] Verify jaxtyping annotations are correct + +### Verification +- [ ] Load a saved `coactivations.npz` file and verify: + ```python + data = np.load("coactivations.npz") + coact = data["coactivations"] + cluster_indices = data["cluster_indices"] + # Check: coact should be symmetric + # Check: diagonal should be >= 
off-diagonal (clusters coactivate with themselves most) + # Check: cluster_indices length should match coact.shape[0] + ``` + +### Performance +- [ ] Check if vectorization actually improved performance +- [ ] Monitor memory usage with large numbers of clusters + +### Edge Cases +- [ ] Test with clusters that have zero activations +- [ ] Test with single-batch runs +- [ ] Test with very large number of clusters + +### Integration +- [ ] Verify the coactivation matrix can be used in downstream analysis +- [ ] Consider if visualization of coactivations should be added to dashboard + +## Notes +- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples) +- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable +- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline diff --git a/pyproject.toml b/pyproject.toml index c91d1cad7..901e9de39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,8 @@ dependencies = [ # see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 ) "datasets>=2.21.0", "simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev", + "scipy>=1.14.1", + "muutils", ] [dependency-groups] @@ -42,6 +44,7 @@ dev = [ [project.scripts] spd-run = "spd.scripts.run:cli" +spd-cluster = "spd.clustering.scripts.main:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/clustering/__init__.py b/spd/clustering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py new file mode 100644 index 000000000..6b7c51abf --- /dev/null +++ b/spd/clustering/activations.py @@ -0,0 +1,269 @@ +from dataclasses import dataclass +from functools import cached_property +from typing import Literal, NamedTuple + 
+import torch +from jaxtyping import Bool, Float, Float16, Int +from torch import Tensor + +from spd.clustering.consts import ( + ActivationsTensor, + BoolActivationsTensor, + ClusterCoactivationShaped, + ComponentLabels, +) +from spd.clustering.util import ModuleFilterFunc +from spd.models.component_model import ComponentModel, OutputWithCache +from spd.models.sigmoids import SigmoidTypes + + +def component_activations( + model: ComponentModel, + device: torch.device | str, + batch: Int[Tensor, "batch_size n_ctx"], + sigmoid_type: SigmoidTypes, +) -> dict[str, ActivationsTensor]: + """Get the component activations over a **single** batch.""" + causal_importances: dict[str, ActivationsTensor] + with torch.no_grad(): + model_output: OutputWithCache = model( + batch.to(device), + cache_type="input", + ) + + causal_importances, _ = model.calc_causal_importances( + pre_weight_acts=model_output.cache, + sigmoid_type=sigmoid_type, + sampling="continuous", + detach_inputs=False, + ) + + return causal_importances + + +def compute_coactivatons( + activations: ActivationsTensor | BoolActivationsTensor, +) -> ClusterCoactivationShaped: + """Compute the coactivations matrix from the activations.""" + # TODO: this works for both boolean and continuous activations, + # but we could do better by just using OR for boolean activations + # and maybe even some bitshift hacks. 
but for now, we convert to float16 + activations_f16: Float16[Tensor, "samples C"] = activations.to(torch.float16) + return activations_f16.T @ activations_f16 + + +class FilteredActivations(NamedTuple): + activations: ActivationsTensor + "activations after filtering dead components" + + labels: ComponentLabels + "list of length c with labels for each preserved component" + + dead_components_labels: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + @property + def n_alive(self) -> int: + """Number of alive components after filtering.""" + n_alive: int = len(self.labels) + assert n_alive == self.activations.shape[1], ( + f"{n_alive = } != {self.activations.shape[1] = }" + ) + return n_alive + + @property + def n_dead(self) -> int: + """Number of dead components after filtering.""" + return len(self.dead_components_labels) if self.dead_components_labels else 0 + + +def filter_dead_components( + activations: ActivationsTensor, + labels: ComponentLabels, + filter_dead_threshold: float = 0.01, +) -> FilteredActivations: + """Filter out dead components based on a threshold + + if `filter_dead_threshold` is 0, no filtering is applied. + activations and labels are returned as is, `dead_components_labels` is `None`. + + otherwise, components whose **maximum** activations across all samples is below the threshold + are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`. + `dead_components_labels` will also be `None` if no components were below the threshold. 
+ """ + dead_components_lst: ComponentLabels | None = None + if filter_dead_threshold > 0: + dead_components_lst = ComponentLabels(list()) + max_act: Float[Tensor, " c"] = activations.max(dim=0).values + dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold + + if dead_components.any(): + activations = activations[:, ~dead_components] + alive_labels: list[tuple[str, bool]] = [ + (lbl, bool(keep.item())) + for lbl, keep in zip(labels, ~dead_components, strict=False) + ] + # re-assign labels only if we are filtering + labels = ComponentLabels([label for label, keep in alive_labels if keep]) + dead_components_lst = ComponentLabels( + [label for label, keep in alive_labels if not keep] + ) + + return FilteredActivations( + activations=activations, + labels=labels, + dead_components_labels=dead_components_lst if dead_components_lst else None, + ) + + +@dataclass(frozen=True) +class ProcessedActivations: + """Processed activations after filtering and concatenation""" + + activations_raw: dict[str, ActivationsTensor] + "activations after filtering, but prior to concatenation" + + activations: ActivationsTensor + "activations after filtering and concatenation" + + labels: ComponentLabels + "list of length c with labels for each preserved component, format `{module_name}:{component_index}`" + + dead_components_lst: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + def validate(self) -> None: + """Validate the processed activations""" + # getting this property will also perform a variety of other checks + assert self.n_components_alive > 0 + + @property + def n_components_original(self) -> int: + """Total number of components before filtering. 
equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`""" + return sum(act.shape[1] for act in self.activations_raw.values()) + + @property + def n_components_alive(self) -> int: + """Number of alive components after filtering. equal to the length of `labels`""" + n_alive: int = len(self.labels) + assert n_alive + self.n_components_dead == self.n_components_original, ( + f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })" + ) + assert n_alive == self.activations.shape[1], ( + f"{n_alive = } != {self.activations.shape[1] = }" + ) + + return n_alive + + @property + def n_components_dead(self) -> int: + """Number of dead components after filtering. equal to the length of `dead_components_lst` if it is not None, or 0 otherwise""" + return len(self.dead_components_lst) if self.dead_components_lst else 0 + + @cached_property + def label_index(self) -> dict[str, int | None]: + """Create a mapping from label to alive index (`None` if dead)""" + return { + **{label: i for i, label in enumerate(self.labels)}, + **( + {label: None for label in self.dead_components_lst} + if self.dead_components_lst + else {} + ), + } + + def get_label_index(self, label: str) -> int | None: + """Get the index of a label in the activations, or None if it is dead""" + return self.label_index[label] + + def get_label_index_alive(self, label: str) -> int: + """Get the index of a label in the activations, or raise if it is dead""" + idx: int | None = self.get_label_index(label) + if idx is None: + raise ValueError(f"Label '{label}' is dead and has no index in the activations.") + return idx + + @property + def module_keys(self) -> list[str]: + """Get the module keys from the activations_raw""" + return list(self.activations_raw.keys()) + + def get_module_indices(self, module_key: str) -> list[int | None]: + """given a module key, return a list len "num components in that moduel", with int index in alive components, 
or None if dead""" + num_components: int = self.activations_raw[module_key].shape[1] + return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)] + + +def process_activations( + activations: dict[ + str, # module name to + Float[Tensor, "samples C"] # (sample x component gate activations) + | Float[Tensor, " n_sample n_ctx C"], # (sample x seq index x component gate activations) + ], + filter_dead_threshold: float = 0.01, + seq_mode: Literal["concat", "seq_mean", None] = None, + filter_modules: ModuleFilterFunc | None = None, +) -> ProcessedActivations: + """get back a dict of coactivations, slices, and concated activations + + Args: + activations: Dictionary of activations by module + filter_dead_threshold: Threshold for filtering dead components + seq_mode: How to handle sequence dimension + filter_modules: Function to filter modules + sort_components: Whether to sort components by similarity within each module + """ + + # reshape -- special cases for llms + # ============================================================ + activations_: dict[str, ActivationsTensor] + if seq_mode == "concat": + # Concatenate the sequence dimension into the sample dimension + activations_ = { + key: act.reshape(act.shape[0] * act.shape[1], act.shape[2]) + for key, act in activations.items() + } + elif seq_mode == "seq_mean": + # Take the mean over the sequence dimension + activations_ = { + key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items() + } + else: + # Use the activations as they are + activations_ = activations + + # put the labelled activations into one big matrix and filter them + # ============================================================ + + # filter activations for only the modules we want + if filter_modules is not None: + activations_ = {key: act for key, act in activations_.items() if filter_modules(key)} + + # compute the labels and total component count + total_c: int = 0 + labels: ComponentLabels = 
ComponentLabels(list()) + for key, act in activations_.items(): + c: int = act.shape[-1] + labels.extend([f"{key}:{i}" for i in range(c)]) + total_c += c + + # concat the activations + act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1) + + # filter dead components + filtered_components: FilteredActivations = filter_dead_components( + activations=act_concat, + labels=labels, + filter_dead_threshold=filter_dead_threshold, + ) + + assert filtered_components.n_alive + filtered_components.n_dead == total_c, ( + f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })" + ) + + return ProcessedActivations( + activations_raw=activations_, + activations=filtered_components.activations, + labels=filtered_components.labels, + dead_components_lst=filtered_components.dead_components_labels, + ) diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py new file mode 100644 index 000000000..ba1ff274c --- /dev/null +++ b/spd/clustering/compute_costs.py @@ -0,0 +1,297 @@ +import math + +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, MergePair +from spd.clustering.math.merge_matrix import GroupMerge + + +def compute_mdl_cost( + acts: Float[Tensor, " k_groups"], + merges: GroupMerge, + alpha: float = 1.0, +) -> float: + r"""Compute MDL costs for merge matrices + + $$ + MDL = \sum_{i \in \N_k} s_i ( \log(k) + \alpha r(P_i) ) + $$ + + where: + - $s_i$ activation of component $i$, $s_j$ activation of component $j$ + - $r(P_i)$ rank of component $i$, $r(P_j)$ rank of component $j$ + - $k$ is the total number of components + """ + + k_groups: int = acts.shape[0] + assert k_groups == merges.k_groups, "Merges must match activation vector shape" + + return ( + (acts * (math.log2(k_groups) + alpha * merges.components_per_group.to(device=acts.device))) + .sum() + .item() + ) + + +def 
compute_merge_costs( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + alpha: float = 1.0, +) -> ClusterCoactivationShaped: + r"""Compute MDL costs for merge matrices + + $$ + F(P_i, P_j) + = \alpha |s_i| r(P_i) + \alpha |s_j| r(P_j) + - s_i s_j ( \alpha r(P_i) + \alpha r(P_j) + c ) + = \alpha ( + |s_i| r(P_i) + + |s_j| r(P_j) + - s_i s_j ( r(P_i) + r(P_j) + c/\alpha ) + ) + $$ + + new version from nathu 2025-08-11 16:48 + + $$ + (s_\Sigma - s_i - s_j) log((c-1)/c) + + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) ) + $$ + where: + - $s_\Sigma$ average activation of all components + - $s_i$ activation of component $i$, $s_j$ activation of component $j$ + - $s_{i,j}$ activation of the merged component $i,j$ + - $r(P_i)$ rank of component $i$, $r(P_j)$ rank of component $j$ + - $r(P_{i,j})$ rank of the merged component $i,j$ + + """ + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + assert merges.k_groups == k_groups, "Merges must match coactivation matrix shape" + + device: torch.device = coact.device + ranks: Float[Tensor, " k_groups"] = merges.components_per_group.to(device=device).float() + s_diag: Float[Tensor, " k_groups"] = torch.diag(coact).to(device=device) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * ranks.view(1, -1) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * (ranks.view(1, -1) + 1/alpha) + term_si_rpi: Float[Tensor, " k_groups"] = s_diag * ranks + # dbg_auto(term_si_rpi) + rank_sum: ClusterCoactivationShaped = ranks.view(-1, 1) + ranks.view(1, -1) + # TODO: use dynamic rank computation + # return alpha * ( + # term_si_rpj # |s_i| r(P_j) + # + term_si_rpj.T # |s_j| r(P_i) + # - coact * ( # s_i s_j + # rank_sum # r(P_i) + r(P_j) + # + (rank_cost(merges.k_groups) / alpha) # c / alpha + # ) + # ) + + coact_OR: ClusterCoactivationShaped = s_diag.view(-1, 1) + s_diag.view(1, -1) - 
coact + + # reduce penalty for sending dictionary by 1 + # (s_\Sigma - s_i - s_j) log((c-1)/c) + # delta of cost for sending index, in expectation + # + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + # delta of cost for sending ranks, in expectation + # + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) + + s_other: ClusterCoactivationShaped = ( + s_diag.sum() - s_diag.view(-1, 1) - s_diag.view(1, -1) + ) * math.log2((k_groups - 1) / k_groups) + + bits_local: ClusterCoactivationShaped = ( + coact_OR * math.log2(k_groups - 1) + - s_diag.view(-1, 1) * math.log2(k_groups) + - s_diag.view(1, -1) * math.log2(k_groups) + ) + + penalty: ClusterCoactivationShaped = ( + coact_OR * rank_sum # s_{i,j} r(P_{i,j}) + - term_si_rpi.view(-1, 1) # s_i r(P_i) + - term_si_rpi.view(1, -1) # s_j r(P_j) + ) + + output: ClusterCoactivationShaped = s_other + bits_local + alpha * penalty + return output + + +def recompute_coacts_merge_pair( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + merge_pair: MergePair, + activation_mask: Bool[Tensor, "samples k_groups"], +) -> tuple[ + GroupMerge, + Float[Tensor, "k_groups-1 k_groups-1"], + Bool[Tensor, "samples k_groups"], +]: + # check shape + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + + # activations of the new merged group + activation_mask_grp: Bool[Tensor, " samples"] = ( + activation_mask[:, merge_pair[0]] + activation_mask[:, merge_pair[1]] + ) + + # coactivations with the new merged group + coact_with_merge: Float[Tensor, " k_groups"] = ( + activation_mask_grp.float() @ activation_mask.float() + ) + new_group_idx: int = min(merge_pair) + remove_idx: int = max(merge_pair) + new_group_self_coact: float = activation_mask_grp.float().sum().item() + + # assemble the merge pair + merge_new: GroupMerge = merges.merge_groups( + merge_pair[0], + merge_pair[1], + ) + # TODO: we don't use this index for anything, and could reconstruct it from the merge pair if needed. 
get rid of it + # `merge_groups` will set `old_to_new_idx` to be an actual dict for `merge_new` + old_to_new_idx: dict[int | None, int | None] = merge_new.old_to_new_idx # pyright: ignore[reportAssignmentType] + assert old_to_new_idx[None] == new_group_idx, ( + "New group index should be the minimum of the merge pair" + ) + assert old_to_new_idx[new_group_idx] is None + assert old_to_new_idx[remove_idx] is None + # TODO: check that the rest are in order? probably not necessary + + # reindex coactivations + coact_temp: ClusterCoactivationShaped = coact.clone() + # add in the similarities with the new group + coact_temp[new_group_idx, :] = coact_with_merge + coact_temp[:, new_group_idx] = coact_with_merge + # delete the old group + mask: Bool[Tensor, " k_groups"] = torch.ones( + coact_temp.shape[0], dtype=torch.bool, device=coact_temp.device + ) + mask[remove_idx] = False + coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact_temp[mask, :][:, mask] + # add in the self-coactivation of the new group + coact_new[new_group_idx, new_group_idx] = new_group_self_coact + + # reindex mask + activation_mask_new: Float[Tensor, "samples ..."] = activation_mask.clone() + # add in the new group + activation_mask_new[:, new_group_idx] = activation_mask_grp + # remove the old group + activation_mask_new = activation_mask_new[:, mask] + + return ( + merge_new, + coact_new, + activation_mask_new, + ) + + +def recompute_coacts_pop_group( + coact: ClusterCoactivationShaped, + merges: GroupMerge, + component_idx: int, + activation_mask: Bool[Tensor, "n_samples k_groups"], + activation_mask_orig: Bool[Tensor, "n_samples n_components"], +) -> tuple[ + GroupMerge, + Float[Tensor, "k_groups+1 k_groups+1"], + Bool[Tensor, "n_samples k_groups+1"], +]: + # sanity check dims + # ================================================== + + k_groups: int = coact.shape[0] + n_samples: int = activation_mask.shape[0] + k_groups_new: int = k_groups + 1 + assert coact.shape[1] == k_groups, "Coactivation 
matrix must be square" + assert activation_mask.shape[1] == k_groups, ( + "Activation mask must match coactivation matrix shape" + ) + assert n_samples == activation_mask_orig.shape[0], ( + "Activation mask original must match number of samples" + ) + + # get the activations we need + # ================================================== + # which group does the component belong to? + group_idx: int = int(merges.group_idxs[component_idx].item()) + group_size_old: int = int(merges.components_per_group[group_idx].item()) + group_size_new: int = group_size_old - 1 + + # activations of component we are popping out + acts_pop: Bool[Tensor, " samples"] = activation_mask_orig[:, component_idx] + + # activations of the "remainder" -- everything other than the component we are popping out, + # in the group we're popping it out of + acts_remainder: Bool[Tensor, " samples"] = ( + activation_mask_orig[ + :, [i for i in merges.components_in_group(group_idx) if i != component_idx] + ] + .max(dim=-1) + .values + ) + + # assemble the new activation mask + # ================================================== + # first concat the popped-out component onto the end + activation_mask_new: Bool[Tensor, " samples k_groups+1"] = torch.cat( + [activation_mask, acts_pop.unsqueeze(1)], + dim=1, + ) + # then replace the group we are popping out of with the remainder + activation_mask_new[:, group_idx] = acts_remainder + + # assemble the new coactivation matrix + # ================================================== + coact_new: Float[Tensor, "k_groups+1 k_groups+1"] = torch.full( + (k_groups_new, k_groups_new), + fill_value=float("nan"), + dtype=coact.dtype, + device=coact.device, + ) + # copy in the old coactivation matrix + coact_new[:k_groups, :k_groups] = coact.clone() + # compute new coactivations we need + coact_pop: Float[Tensor, " k_groups"] = acts_pop.float() @ activation_mask_new.float() + coact_remainder: Float[Tensor, " k_groups"] = ( + acts_remainder.float() @ 
activation_mask_new.float() + ) + + # replace the relevant rows and columns + coact_new[group_idx, :] = coact_remainder + coact_new[:, group_idx] = coact_remainder + coact_new[-1, :] = coact_pop + coact_new[:, -1] = coact_pop + + # assemble the new group merge + # ================================================== + group_idxs_new: Int[Tensor, " k_groups+1"] = merges.group_idxs.clone() + # the popped-out component is now its own group + new_group_idx: int = k_groups_new - 1 + group_idxs_new[component_idx] = new_group_idx + merge_new: GroupMerge = GroupMerge( + group_idxs=group_idxs_new, + k_groups=k_groups_new, + ) + + # sanity check + assert merge_new.components_per_group.shape == (k_groups_new,), ( + "New merge must have k_groups+1 components" + ) + assert merge_new.components_per_group[new_group_idx] == 1, ( + "New group must have exactly one component" + ) + assert merge_new.components_per_group[group_idx] == group_size_new, ( + "Old group must have one less component" + ) + + # return + # ================================================== + return ( + merge_new, + coact_new, + activation_mask_new, + ) diff --git a/spd/clustering/configs/example.toml b/spd/clustering/configs/example.toml new file mode 100644 index 000000000..d5cfe46d6 --- /dev/null +++ b/spd/clustering/configs/example.toml @@ -0,0 +1,37 @@ +# Example MergeRunConfig in TOML format + +# Run configuration +model_path = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" # WandB path to the decomposed model +task_name = "lm" # Task name (must be explicit: tms, resid_mlp, lm, ih) +# experiment_key = "tms_5-2" # Alternative: use experiment key from EXPERIMENT_REGISTRY +n_batches = 10 # Ensemble size +batch_size = 64 # Batch size for processing -- number of samples for each run in the ensemble + +# WandB configuration +wandb_enabled = false # Enable WandB logging +wandb_project = "spd-cluster" # WandB project name + +[intervals] +stat = 1 # for k_groups, merge_pair_cost, mdl_loss +tensor = 100 # for 
wandb_log_tensor and fraction_* calculations +plot = 100 # for calling the plotting callback +artifact = 100 # for calling the artifact callback + +# Optional: Override defaults (typically set via CLI args) +# base_path = ".data/clustering/" # defaults to .data/clustering/ +# workers_per_device = 1 # defaults to 1 +# devices = ["cpu"] # defaults to ["cpu"], CLI will override with ["cuda"] if available + +# Merge algorithm parameters (wrapped in merge_config) +[merge_config] +activation_threshold = 0.01 # set to null to use scalar activations for cost calculation +alpha = 1.0 # rank penalty term +iters = 100 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? +pop_component_prob = 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway +filter_dead_threshold = 0.001 # Threshold for filtering dead components +module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." if you want to do only some modules +rank_cost_fn_name = "const_1" # Options: const_1, const_2, log, linear +merge_pair_sampling_method = "range" # Method for sampling merge pairs: 'range' or 'mcmc' + +[merge_config.merge_pair_sampling_kwargs] +threshold = 0.05 # For range sampler: fraction of the range of costs to sample from diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/example.yaml new file mode 100644 index 000000000..5f3cd5fa5 --- /dev/null +++ b/spd/clustering/configs/example.yaml @@ -0,0 +1,35 @@ +# Example MergeRunConfig in YAML format + +# Merge algorithm parameters (wrapped in merge_config) +merge_config: + activation_threshold: 0.01 # set to null to use scalar activations for cost calculation + alpha: 1.0 # rank penalty term + iters: 100 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? 
+ merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' + merge_pair_sampling_kwargs: + threshold: 0.05 # For range sampler: fraction of the range of costs to sample from + pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway + filter_dead_threshold: 0.001 # Threshold for filtering dead components + module_name_filter: null # Can be a string prefix like "model.layers.0." if you want to do only some modules + rank_cost_fn_name: const_1 # Options: const_1, const_2, log, linear + +# Run configuration +model_path: wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh # WandB path to the decomposed model +task_name: lm # Task name (must be explicit: tms, resid_mlp, lm, ih) +# experiment_key: tms_5-2 # Alternative: use experiment key from EXPERIMENT_REGISTRY +n_batches: 10 # Ensemble size +batch_size: 64 # Batch size for processing -- number of samples for each run in the ensemble + +# WandB configuration +wandb_enabled: false # Enable WandB logging +wandb_project: spd-cluster # WandB project name +intervals: + stat: 1 # for k_groups, merge_pair_cost, mdl_loss + tensor: 100 # for wandb_log_tensor and fraction_* calculations + plot: 100 # for calling the plotting callback + artifact: 100 # for calling the artifact callback + +# Optional: Override defaults (typically set via CLI args) +# base_path: .data/clustering/ # defaults to .data/clustering/ +# workers_per_device: 1 # defaults to 1 +# devices: ["cpu"] # defaults to ["cpu"], CLI will override with ["cuda"] if available \ No newline at end of file diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/resid_mlp1.json new file mode 100644 index 000000000..e825215ee --- /dev/null +++ b/spd/clustering/configs/resid_mlp1.json @@ -0,0 +1,24 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": null, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, 
+ "pop_component_prob": 0, + "filter_dead_threshold": 0, + "module_name_filter": null + }, + "experiment_key": "resid_mlp1", + "distances_method": "perm_invariant_hamming", + "n_batches": 8, + "batch_size": 1024, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 5 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/resid_mlp2.json b/spd/clustering/configs/resid_mlp2.json new file mode 100644 index 000000000..2be350979 --- /dev/null +++ b/spd/clustering/configs/resid_mlp2.json @@ -0,0 +1,23 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 100, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "pop_component_prob": 0, + "filter_dead_threshold": 0.01, + "module_name_filter": null + }, + "experiment_key": "resid_mlp2", + "n_batches": 16, + "batch_size": 1024, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 50 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/resid_mlp3.json b/spd/clustering/configs/resid_mlp3.json new file mode 100644 index 000000000..5d87e08d5 --- /dev/null +++ b/spd/clustering/configs/resid_mlp3.json @@ -0,0 +1,23 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 350, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "pop_component_prob": 0, + "filter_dead_threshold": 0.01, + "module_name_filter": null + }, + "experiment_key": "resid_mlp3", + "n_batches": 4, + "batch_size": 1024, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 32, + "plot": 32, + "artifact": 32 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/simplestories_dev.json new file mode 100644 index 
000000000..c82b11710 --- /dev/null +++ b/spd/clustering/configs/simplestories_dev.json @@ -0,0 +1,25 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1.0, + "iters": null, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "pop_component_prob": 0, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd-pre-Sep-2025/runs/rn9klzfs", + "task_name": "lm", + "distances_method": "jaccard", + "n_batches": 1, + "batch_size": 32, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json new file mode 100644 index 000000000..75877dd25 --- /dev/null +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -0,0 +1,24 @@ +{ + "merge_config": { + "activation_threshold": 0.1, + "alpha": 1, + "iters": 140, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "pop_component_prob": 0, + "filter_dead_threshold": 0.1, + "module_name_filter": null, + "rank_cost_fn_name": "const_1" + }, + "experiment_key": "resid_mlp1", + "n_batches": 2, + "batch_size": 100, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 5, + "plot": 10, + "artifact": 10 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/test-simplestories.json new file mode 100644 index 000000000..377eb6af1 --- /dev/null +++ b/spd/clustering/configs/test-simplestories.json @@ -0,0 +1,24 @@ +{ + "merge_config": { + "activation_threshold": 0.5, + "alpha": 1.0, + "iters": 5, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "pop_component_prob": 0, + "filter_dead_threshold": 0.9, + "module_name_filter": 
"model.layers.0" + }, + "model_path": "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh", + "task_name": "lm", + "n_batches": 1, + "batch_size": 1, + "wandb_enabled": true, + "wandb_project": "spd-cluster", + "intervals": { + "stat": 1, + "tensor": 2, + "plot": 3, + "artifact": 4 + } +} \ No newline at end of file diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py new file mode 100644 index 000000000..ab824b8d7 --- /dev/null +++ b/spd/clustering/consts.py @@ -0,0 +1,48 @@ +"""Constants and shared abstractions for clustering pipeline.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Literal, NewType + +import numpy as np +from jaxtyping import Bool, Float, Int +from torch import Tensor + +# Merge arrays and distances (numpy-based for storage/analysis) +MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] +MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] +DistancesMethod = Literal["perm_invariant_hamming", "jaccard"] +DistancesArray = Float[np.ndarray, "n_iters n_ens n_ens"] + +# Component and label types (NewType for stronger type safety) +ComponentLabel = NewType("ComponentLabel", str) # Format: "module_name:component_index" +ComponentLabels = NewType("ComponentLabels", list[str]) +BatchId = NewType("BatchId", str) + +# Path types +WandBPath = NewType("WandBPath", str) # Format: "wandb:entity/project/run_id" + +# Merge types +MergePair = NewType("MergePair", tuple[int, int]) + +# Tensor type aliases (torch-based for computation - TypeAlias for jaxtyping compatibility) +ActivationsTensor = Float[Tensor, "samples n_components"] +BoolActivationsTensor = Bool[Tensor, "samples n_components"] +ClusterCoactivationShaped = Float[Tensor, "k_groups k_groups"] +GroupIdxsTensor = Int[Tensor, " n_components"] +BatchTensor = Int[Tensor, "batch_size seq_len"] + + +class SaveableObject(ABC): + """Abstract base class for objects that can be saved to and loaded from disk.""" + + @abstractmethod + def save(self, path: 
Path) -> None: + """Save the object to disk at the given path.""" + ... + + @classmethod + @abstractmethod + def read(cls, path: Path) -> "SaveableObject": + """Load the object from disk at the given path.""" + ... diff --git a/spd/clustering/math/__init__.py b/spd/clustering/math/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/math/jaccard.py b/spd/clustering/math/jaccard.py new file mode 100644 index 000000000..0c0b1a284 --- /dev/null +++ b/spd/clustering/math/jaccard.py @@ -0,0 +1,71 @@ +"""jaccard index between clusterings + + +we start with a matrix X: Int[np.ndarray, "s n"] where each of the s rows is a label vector of length n +we want to compute a Float["s s"] matrix $J$ of pairwise jaccard indices between the rows of X + +jaccard index between two partitions A and B is defined as: +J(A, B) = M11 / (M11 + M10 + M01) + +where: +- M11 = number of pairs clustered together in both partitions +- M10 = number of pairs clustered together in A but not in B +- M01 = number of pairs clustered together in B but not in A + +""" + +# %% +import matplotlib.pyplot as plt +import torch +from jaxtyping import Bool, Float, Int +from muutils.dbg import dbg_auto +from torch import Tensor + + +def jaccard_index( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + """Compute the pairwise jaccard index between rows of X""" + + s_ensemble, _n_components = X.shape + dbg_auto(X) + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + dbg_auto(matches) + + _jaccard: Float[Tensor, "s s"] = torch.full((s_ensemble, s_ensemble), torch.nan) + + for i in range(s_ensemble): + plt.matshow(matches[i].cpu().numpy()) + plt.title(f"matches for row {i}") + plt.show() + + # for i in range(s_ensemble): + # for j in range(i, s_ensemble): + # M11: int = int((matches[i] & matches[j]).sum() - n_components) // 2 + # M10: int = int((matches[i] & ~matches[j]).sum()) // 2 + # M01: int = int((~matches[i] & matches[j]).sum()) // 2 + # if M11 + M10 + M01 
== 0: + # jaccard[i, j] = float("nan") + # else: + # jaccard[i, j] = M11 / (M11 + M10 + M01) + # jaccard[j, i] = jaccard[i, j] + # dbg_auto(i, j, M11, M10, M01, jaccard[i, j]) + + return _jaccard + + +jaccard_index( + torch.tensor( + [ + # [1, 2, 3, 3], + [0, 1, 1, 2, 3, 3], + [3, 0, 0, 1, 2, 2], + [0, 3, 1, 1, 2, 2], + [0, 3, 0, 0, 1, 1], + [0, 0, 0, 0, 0, 0], + # [0, 1, 2, 3], + ] + ) +) + +# dbg(X - z[0]) diff --git a/spd/clustering/math/jaccard_test.py b/spd/clustering/math/jaccard_test.py new file mode 100644 index 000000000..9322abe3c --- /dev/null +++ b/spd/clustering/math/jaccard_test.py @@ -0,0 +1,194 @@ +"""jaccard index between clusterings + + +we start with a matrix X: Int[np.ndarray, "k n"] where each of the k rows is a label vector of length n +we want to compute a Float["k k"] matrix $J$ of pairwise jaccard indices between the rows of X + +jaccard index between two partitions A and B is defined as: +J(A, B) = M11 / (M11 + M10 + M01) + +where: +- M11 = number of pairs clustered together in both partitions +- M10 = number of pairs clustered together in A but not in B +- M01 = number of pairs clustered together in B but not in A + +""" + +# %% +import matplotlib.pyplot as plt +import torch +from jaxtyping import Bool, Float, Int +from muutils.dbg import dbg +from torch import Tensor + +# def per_row_label_counts(X: Int[Tensor, "k n"]) -> list[Tensor]: +# """Return a list of 1D count arrays, one per row.""" +# return [ +# torch.bincount(x) +# for x in X +# ] + + +def process_singletons( + x: Int[Tensor, " n"], +) -> tuple[Int[Tensor, " n"], int]: + """relabel anything in a singleton cluster to -1, relabel other clusters to minimize labels""" + assert (x >= 0).all(), "input labels must be non-negative" + # figure out where the singletons are + counts: Int[Tensor, " k"] = torch.bincount(x) + singleton_mask: Bool[Tensor, " k"] = counts == 1 + + x_relabel: Int[Tensor, " n"] = x.clone() + dbg(x) + dbg(singleton_mask) + dbg(singleton_mask[x]) + dbg(x_relabel) 
+ dbg(x_relabel[singleton_mask[x]]) + + # map singletons to -1 + x_relabel[singleton_mask[x]] = -1 + dbg(x_relabel) + + # map every non `-1` label to a new label + non_singleton_labels: Int[Tensor, " m"] = x_relabel[~singleton_mask[x]].unique() + dbg(non_singleton_labels) + n_unique_nonsingleton_labels: int = non_singleton_labels.shape[0] + dbg(n_unique_nonsingleton_labels) + old_to_new: dict[int, int] = { + old: new for new, old in enumerate(sorted(non_singleton_labels.tolist())) + } + dbg(old_to_new) + + for old, new in old_to_new.items(): + x_relabel[x == old] = new + dbg(x_relabel) + + return x_relabel, n_unique_nonsingleton_labels + + +# X_1 = torch.tensor([0, 3, 3, 2, 4, 0, 5, 6, 7, 7, 7]) +# X_2 = torch.tensor([1, 1, 2, 3, 3, 1, 4, 5, 6, 6, 6]) +# dbg(X_1) +# process_singletons(X_1) + + +# def to_matrix( +# self, device: torch.device | None = None +# ) -> Bool[Tensor, "k_groups n_components"]: +# if device is None: +# device = self.group_idxs.device +# mat: Bool[Tensor, "k_groups n_components"] = torch.zeros( +# (self.k_groups, self._n_components), dtype=torch.bool, device=device +# ) +# idxs: Int[Tensor, " n_components"] = torch.arange( +# self._n_components, device=device, dtype=torch.int +# ) +# mat[self.group_idxs.to(dtype=torch.int), idxs] = True +# return mat + + +def expand_to_onehot( + x: Int[Tensor, " n"], + k_groups: int, +) -> Bool[Tensor, " k_groups+1 n_components"]: + """expand a label (possibly having -1s) vector to a one-hot matrix""" + n_components: int = x.shape[0] + + # add 1 as -1 will map to last index and be ignored + mat: Bool[Tensor, " k_groups n_components"] = torch.zeros( + (k_groups + 1, n_components), dtype=torch.bool + ) + idxs: Int[Tensor, " n_components"] = torch.arange(n_components, dtype=torch.int) + mat[x.to(dtype=torch.int), idxs] = True + return mat + + +def show_matrix(mat: Tensor, title: str = "", cmap: str = "viridis") -> None: + """Display a matrix with values annotated on each cell.""" + mat_np = mat.cpu().numpy() + 
_fig, ax = plt.subplots() + im = ax.matshow(mat_np, cmap=cmap) + + # Add text annotations + for i in range(mat_np.shape[0]): + for j in range(mat_np.shape[1]): + ax.text( + j, + i, + f"{mat_np[i, j]:.2f}", + ha="center", + va="center", + color="white" if mat_np[i, j] < mat_np.max() / 2 else "black", + ) + + if title: + plt.title(title) + plt.colorbar(im, ax=ax) + plt.show() + + +# plt.imshow(expand_to_onehot(*process_singletons(X_1))) +# plt.show() +# plt.imshow(expand_to_onehot(*process_singletons(X_2))) +# plt.show() + + +def jaccard_index( + X: Int[Tensor, " s n"], +) -> Float[Tensor, " s s"]: + """compute pairwise jaccard indices between rows of X""" + s: int + _n: int + s, _n = X.shape + + X_expanded_list: list[Int[Tensor, " k n"]] = [ + expand_to_onehot(*process_singletons(X[i])) for i in range(s) + ] + + # compute jaccard for each pair of rows + # jaccard: dict[ + # tuple[int, int], # key is (i, j) from the rows of X + # fl + # # Int[Tensor, " k_i k_j"], # value at (p, q) is jaccard index between two clusters + # ] = {} + jaccard: Float[Tensor, " s s"] = torch.full((s, s), fill_value=torch.nan) + for i in range(s): + for j in range(i, s): + X_i: Int[Tensor, " k_i n"] = X_expanded_list[i].to(torch.int16) + X_j: Int[Tensor, " k_j n"] = X_expanded_list[j].to(torch.int16) + intersects: Int[Tensor, " k_i k_j"] = X_i @ X_j.T + unions: Int[Tensor, " k_i k_j"] = ( + X_i.sum(dim=1, keepdim=True) + X_j.sum(dim=1, keepdim=True).T - intersects + ) + jaccard_mat: Int[Tensor, " k_i k_j"] = intersects / unions + + show_matrix( + X_i, title=f"One-hot matrix for row {i} of X\nshape={X_i.shape}", cmap="Blues" + ) + show_matrix( + X_j, title=f"One-hot matrix for row {j} of X\nshape={X_j.shape}", cmap="Blues" + ) + show_matrix( + jaccard_mat, + title=f"Gram matrix between row {i} and row {j}\n$[{jaccard_mat.min():.2f}, {jaccard_mat.max():.2f}]$", + ) + + # jaccard[i, j] = jaccard_mat.mean() + + return jaccard + + +jaccard_index( + torch.tensor( + [ + # [1, 2, 3, 3], + [0, 1, 1, 
jaccard_index(
    torch.tensor(
        [
            # [1, 2, 3, 3],
            [0, 1, 1, 2, 3, 3],
            [0, 1, 1, 1, 2, 2],
            # [0, 0, 0, 0, 1, 1],
            [0, 0, 0, 0, 0, 0],
            # [0, 1, 2, 3],
        ]
    )
)

# dbg(X - z[0])

from collections.abc import Callable

import numpy as np
from jaxtyping import Float, Int
from muutils.parallel import run_maybe_parallel

from spd.clustering.consts import (
    DistancesArray,
    DistancesMethod,
    MergesArray,
    MergesAtIterArray,
)

# from spd.clustering.math.jaccard import jaccard_partition_matrix
from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix

# registry mapping a distance-method name to the function computing the
# (n_ens x n_ens) distance matrix for one iteration's label assignments
DISTANCES_METHODS: dict[DistancesMethod, Callable[[MergesAtIterArray], DistancesArray]] = {
    "perm_invariant_hamming": perm_invariant_hamming_matrix,
    # "jaccard": jaccard_partition_matrix,
}

# pyright: reportUnnecessaryComparison=false, reportUnreachable=false


def compute_distances(
    normalized_merge_array: MergesArray,
    method: DistancesMethod = "perm_invariant_hamming",
) -> DistancesArray:
    """Compute per-iteration pairwise distances across the ensemble.

    Slices `normalized_merge_array` along its iteration axis (axis 1), computes
    the ensemble-vs-ensemble distance matrix for each slice (parallelized via
    `run_maybe_parallel`), and stacks the results along a new leading axis.

    # Parameters:
    - `normalized_merge_array : MergesArray`
        assumed shape (n_ens, n_iters, n_components) — each slice
        `[:, i, :]` holds one label vector per ensemble member. TODO confirm.
    - `method : DistancesMethod`
        which distance to use; only "perm_invariant_hamming" is implemented.

    # Returns:
    - `DistancesArray`
        stacked distance matrices, one per iteration.

    # Raises:
    - `NotImplementedError` : for the "jaccard" method.
    - `ValueError` : for any unknown method name.
    """
    n_iters: int = normalized_merge_array.shape[1]
    merges_array_list: list[Int[np.ndarray, "n_ens n_components"]]
    distances_list: list[Float[np.ndarray, "n_ens n_ens"]]
    match method:
        case "perm_invariant_hamming":
            merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)]

            distances_list = run_maybe_parallel(
                func=perm_invariant_hamming_matrix,
                iterable=merges_array_list,
                parallel=True,
            )

            return np.stack(distances_list, axis=0)
        case "jaccard":
            raise NotImplementedError("Jaccard distance computation is not implemented.")
            # merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)]
            # distances_list = run_maybe_parallel(
            #     func=jaccard_partition_matrix,
            #     iterable=merges_array_list,
            #     parallel=True,
            # )
            # return np.stack(distances_list, axis=0)
        case _:
            raise ValueError(f"Unknown distance method: {method}")

from dataclasses import dataclass

import torch
from jaxtyping import Bool, Int
from muutils.tensor_info import array_summary
from torch import Tensor

from spd.clustering.consts import GroupIdxsTensor

# pyright: reportUnnecessaryTypeIgnoreComment=false


@dataclass(kw_only=True, slots=True)
class GroupMerge:
    """Canonical component-to-group assignment.

    `group_idxs` is a length-`n_components` integer tensor; entry `c`
    gives the group index (0 to `k_groups-1`) that contains component `c`.
    """

    # group index per component (length n_components)
    group_idxs: GroupIdxsTensor
    # number of groups; valid indices are [0, k_groups)
    k_groups: int
    # optional old->new group index remapping recorded by `merge_groups`
    old_to_new_idx: dict[int | None, int | None] | None = None

    def summary(self) -> dict[str, int | str | None]:
        """Compact, JSON-friendly description of this assignment (for logging)."""
        return dict(
            group_idxs=array_summary(self.group_idxs, as_list=False),  # pyright: ignore[reportCallIssue]
            k_groups=self.k_groups,
            old_to_new_idx=f"len={len(self.old_to_new_idx)}"
            if self.old_to_new_idx is not None
            else None,
        )

    @property
    def _n_components(self) -> int:
        # number of components is the length of the assignment vector
        return int(self.group_idxs.shape[0])

    @property
    def components_per_group(self) -> Int[Tensor, " k_groups"]:
        # histogram of group sizes; minlength keeps empty trailing groups visible
        return torch.bincount(self.group_idxs, minlength=self.k_groups)

    def components_in_group_mask(self, group_idx: int) -> Bool[Tensor, " n_components"]:
        """Returns a boolean mask for components in the specified group."""
        if group_idx < 0 or group_idx >= self.k_groups:
            raise ValueError("group index out of range")
        return self.group_idxs == group_idx

    def components_in_group(self, group_idx: int) -> list[int]:
        """Returns a list of component indices in the specified group."""
        indices: Int[Tensor, " n_matches"] = (
            (self.group_idxs == group_idx).nonzero(as_tuple=False).squeeze(-1)
        )
        return indices.tolist()
    def validate(self, *, require_nonempty: bool = True) -> None:
        """Check that all group indices are in range; optionally that no group is empty.

        # Raises:
        - `ValueError` : if any index falls outside `[0, k_groups)`, or (when
          `require_nonempty`) if at least one group has no components.
        """
        v_min: int = int(self.group_idxs.min().item())
        v_max: int = int(self.group_idxs.max().item())
        if v_min < 0 or v_max >= self.k_groups:
            raise ValueError("group indices out of range")

        if require_nonempty:
            has_empty_groups: bool = bool(self.components_per_group.eq(0).any().item())
            if has_empty_groups:
                raise ValueError("one or more groups are empty")

    def to_matrix(
        self, device: torch.device | None = None
    ) -> Bool[Tensor, "k_groups n_components"]:
        """One-hot encode the assignment: `mat[g, c]` is True iff component `c` is in group `g`."""
        if device is None:
            device = self.group_idxs.device
        mat: Bool[Tensor, "k_groups n_components"] = torch.zeros(
            (self.k_groups, self._n_components), dtype=torch.bool, device=device
        )
        idxs: Int[Tensor, " n_components"] = torch.arange(
            self._n_components, device=device, dtype=torch.int
        )
        # scatter one True per column at (group of component, component)
        mat[self.group_idxs.to(dtype=torch.int), idxs] = True
        return mat

    @classmethod
    def from_matrix(cls, mat: Bool[Tensor, "k_groups n_components"]) -> "GroupMerge":
        """Inverse of `to_matrix`: build a GroupMerge from a one-hot bool matrix.

        # Raises:
        - `TypeError` : if `mat` is not bool
        - `ValueError` : if some column does not contain exactly one True
        """
        if mat.dtype is not torch.bool:
            raise TypeError("mat must have dtype bool")
        if not mat.sum(dim=0).eq(1).all():
            raise ValueError("each column must contain exactly one True")
        group_idxs: GroupIdxsTensor = mat.argmax(dim=0).to(torch.int64)
        inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=int(mat.shape[0]))
        # empty groups are allowed here: an all-False row is a legitimate input
        inst.validate(require_nonempty=False)
        return inst

    @classmethod
    def random(
        cls,
        n_components: int,
        k_groups: int,
        *,
        ensure_groups_nonempty: bool = False,
        device: torch.device | str = "cpu",
    ) -> "GroupMerge":
        """Sample a random assignment of `n_components` components to `k_groups` groups.

        When `ensure_groups_nonempty`, each group is first seeded with one
        component (requires `n_components >= k_groups`), the remainder are
        assigned uniformly at random, and the result is shuffled.
        """
        if ensure_groups_nonempty and n_components < k_groups:
            raise ValueError("n_components must be >= k_groups when ensure_groups_nonempty is True")

        group_idxs: GroupIdxsTensor

        if ensure_groups_nonempty:
            # `base` guarantees every group appears at least once
            base: Int[Tensor, " k_groups"] = torch.arange(k_groups, device=device)
            if n_components > k_groups:
                extra: Int[Tensor, " n_extra"] = torch.randint(
                    0, k_groups, (n_components - k_groups,), device=device
                )
                group_idxs = torch.cat((base, extra))
                # shuffle so the guaranteed members are not always the first components
                group_idxs = group_idxs[torch.randperm(n_components, device=device)]
            else:
                group_idxs = base
        else:
            group_idxs = torch.randint(0, k_groups, (n_components,), device=device)
        inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups)
        inst.validate(require_nonempty=ensure_groups_nonempty)
        return inst

    @classmethod
    def identity(cls, n_components: int) -> "GroupMerge":
        """Creates a GroupMerge where each component is its own group."""
        return cls(
            group_idxs=torch.arange(n_components, dtype=torch.int64),
            k_groups=n_components,
        )

    def merge_groups(self, group_a: int, group_b: int) -> "GroupMerge":
        """Merges two groups into one, returning a new GroupMerge.

        The merged group keeps the smaller of the two indices; every group
        index above the larger one shifts down by one so indices stay dense.
        An old->new index mapping is stashed on the result (`old_to_new_idx`).

        # Raises:
        - `ValueError` : if either index is out of range, or `group_a == group_b`
        """
        if group_a < 0 or group_b < 0 or group_a >= self.k_groups or group_b >= self.k_groups:
            raise ValueError("group indices out of range")
        if group_a == group_b:
            raise ValueError("Cannot merge a group with itself")

        # make sure group_a is the smaller index
        if group_a > group_b:
            group_a, group_b = group_b, group_a

        # make a copy
        new_idxs: GroupIdxsTensor = self.group_idxs.clone()
        # wherever it's currently b, change it to a
        new_idxs[new_idxs == group_b] = group_a
        # wherever it's currently above b, change it to one less
        new_idxs[new_idxs > group_b] -= 1
        # create a new GroupMerge instance
        merged: GroupMerge = GroupMerge(group_idxs=new_idxs, k_groups=self.k_groups - 1)

        # create a mapping from old to new group indices
        # `None` as a key is for the new group that contains both a and b
        # values of a and b are mapped to `None` since they are merged
        old_to_new_idx: dict[int | None, int | None] = dict()
        for i in range(self.k_groups):
            if i in {group_a, group_b}:
                old_to_new_idx[i] = None
            elif i <= group_b:
                old_to_new_idx[i] = i
            else:
                old_to_new_idx[i] = i - 1
        old_to_new_idx[None] = group_a  # the new group index for the merged group

        # HACK: store the mapping in the instance for later use
        merged.old_to_new_idx = old_to_new_idx  # type: ignore[assignment]

        # validate the new instance
        # merged.validate(require_nonempty=True)
        return merged

    def all_downstream_merged(self) -> "BatchedGroupMerge":
        """Enumerate every single-pair merge reachable from this assignment.

        Returns a batch with one entry per unordered pair `(i, j)`, i < j.
        """
        downstream: list[GroupMerge] = []
        idxs: list[tuple[int, int]] = []
        for i in range(self.k_groups):
            for j in range(i + 1, self.k_groups):
                downstream.append(self.merge_groups(i, j))
                idxs.append((i, j))

        return BatchedGroupMerge.from_list(merge_matrices=downstream)


@dataclass(slots=True)
class BatchedGroupMerge:
    """Batch of merge matrices.

    `group_idxs` has shape `(batch, n_components)`; each row holds the
    group index for every component in that matrix.
    """

    # per-batch-entry assignment vectors
    group_idxs: Int[Tensor, "batch n_components"]
    # per-batch-entry group counts (entries may differ across the batch)
    k_groups: Int[Tensor, " batch"]

    def summary(self) -> dict[str, int | str | None]:
        """Compact description of the batch (for logging)."""
        return dict(
            group_idxs=array_summary(self.group_idxs, as_list=False),  # pyright: ignore[reportCallIssue]
            k_groups=array_summary(self.k_groups, as_list=False),  # pyright: ignore[reportCallIssue]
            # TODO: re-add metadata (which pairs merged at each step)
            # meta=f"len={len(self.meta)}" if self.meta is not None else None,
        )

    @classmethod
    def init_empty(cls, batch_size: int, n_components: int) -> "BatchedGroupMerge":
        """Initialize an empty BatchedGroupMerge with the given batch size and number of components."""
        # NOTE(review): -1 is a "not yet filled" sentinel; `to_matrix` (one_hot)
        # would fail on unfilled rows — entries must be assigned before use.
        # dtype int16 here vs int64 in `from_matrix`/`from_list` — confirm intended.
        return cls(
            group_idxs=torch.full((batch_size, n_components), -1, dtype=torch.int16),
            k_groups=torch.zeros(batch_size, dtype=torch.int16),
        )

    @property
    def _batch_size(self) -> int:
        return int(self.group_idxs.shape[0])

    @property
    def _n_components(self) -> int:
        return int(self.group_idxs.shape[1])

    @property
    def k_groups_unique(self) -> int:
        """Returns the number of groups across all matrices, throws exception if they differ."""
        k_groups_set: set[int] = set(self.k_groups.tolist())
        if len(k_groups_set) != 1:
            raise ValueError("All matrices must have the same number of groups")
        return k_groups_set.pop()

    def to_matrix(
        self, device: torch.device | None = None
    ) -> Bool[Tensor, "batch k_groups n_components"]:
        """One-hot encode the whole batch; requires all entries share one k_groups."""
        if device is None:
            device = self.group_idxs.device
        k_groups_u: int = self.k_groups_unique
        mat = torch.nn.functional.one_hot(self.group_idxs, num_classes=k_groups_u)
        # one_hot yields (batch, n_components, k_groups); put groups on axis 1
        return mat.permute(0, 2, 1).to(device=device, dtype=torch.bool)

    @classmethod
    def from_matrix(cls, mat: Bool[Tensor, "batch k_groups n_components"]) -> "BatchedGroupMerge":
        """Inverse of `to_matrix`: build a batch from stacked one-hot matrices."""
        if mat.dtype is not torch.bool:
            raise TypeError("mat must have dtype bool")
        if not mat.sum(dim=1).eq(1).all():
            raise ValueError("each column must have exactly one True per matrix")
        group_idxs = mat.argmax(dim=1).to(torch.int64)
        batch_size: int = int(mat.shape[0])
        inst = cls(
            group_idxs=group_idxs,
            k_groups=torch.full((batch_size,), int(mat.shape[1]), dtype=torch.int64),
        )
        # inst.validate(require_nonempty=False)
        return inst

    @classmethod
    def from_list(
        cls,
        merge_matrices: list[GroupMerge],
    ) -> "BatchedGroupMerge":
        """Stack a list of GroupMerge instances (all with equal n_components) into a batch."""
        group_idxs: Int[Tensor, "batch n_components"] = torch.stack(
            [mm.group_idxs for mm in merge_matrices], dim=0
        )
        k_groups: Int[Tensor, " batch"] = torch.tensor(
            [mm.k_groups for mm in merge_matrices], dtype=torch.int64
        )
        inst: BatchedGroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups)
        # inst.validate(require_nonempty=False)
        return inst

    def __getitem__(self, idx: int) -> GroupMerge:
        """Return the idx-th entry as a standalone GroupMerge (shares storage)."""
        if not (0 <= idx < self._batch_size):
            raise IndexError("index out of range")
        group_idxs: GroupIdxsTensor = self.group_idxs[idx]
        k_groups: int = int(self.k_groups[idx].item())
        return GroupMerge(group_idxs=group_idxs, k_groups=k_groups)

    def __setitem__(self, idx: int, value: GroupMerge) -> None:
        """Overwrite the idx-th entry with `value` (component counts must match)."""
        if not (0 <= idx < self._batch_size):
            raise IndexError("index out of range")
        if value._n_components != self._n_components:
            raise ValueError("value must have the same number of components as the batch")
        self.group_idxs[idx] = value.group_idxs
        self.k_groups[idx] = value.k_groups

    def __iter__(self):
        """Iterate over the GroupMerge instances in the batch."""
        for i in range(self._batch_size):
            yield self[i]

    def __len__(self) -> int:
        return self._batch_size

import random
from typing import Any, Literal, Protocol

import torch
from jaxtyping import Bool, Float, Int
from torch import Tensor

from spd.clustering.consts import ClusterCoactivationShaped, MergePair

MergePairSamplerKey = Literal["range", "mcmc"]


class MergePairSamplerConfigurable(Protocol):
    # sampler signature before configuration: accepts extra keyword args
    def __call__(
        self,
        costs: ClusterCoactivationShaped,
        **kwargs: Any,
    ) -> MergePair: ...


class MergePairSampler(Protocol):
    # fully-configured sampler: cost matrix in, merge pair out
    def __call__(
        self,
        costs: ClusterCoactivationShaped,
    ) -> MergePair: ...
def range_sampler(
    costs: ClusterCoactivationShaped,
    threshold: float = 0.05,
    **kwargs: Any,
) -> MergePair:
    """Pick a merge pair uniformly among the cheapest candidates.

    The cutoff is `min + threshold * (max - min)` over the off-diagonal costs;
    every off-diagonal pair at or below the cutoff is a candidate, and one is
    drawn uniformly at random.

    Args:
        costs: Cost matrix for all possible merges
        threshold: Fraction of cost range to consider (0=min only, 1=all pairs)

    Returns:
        Tuple of (group_i, group_j) indices to merge
    """
    assert not kwargs
    n_groups: int = costs.shape[0]
    assert costs.shape[1] == n_groups, "Cost matrix must be square"

    # range of the off-diagonal costs (diagonal entries are self-merges)
    off_diag_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye(
        n_groups, dtype=torch.bool, device=costs.device
    )
    off_diag_vals: Float[Tensor, " k_groups_squared_minus_k"] = costs[off_diag_mask]
    lo: float = float(off_diag_vals.min().item())
    hi: float = float(off_diag_vals.max().item())

    # cost cutoff as a fraction of the observed range
    cutoff: float = lo + threshold * (hi - lo)

    # every index pair at or below the cutoff, diagonal removed
    candidate_idxs: Int[Tensor, "n_considered 2"] = torch.stack(
        torch.where(costs <= cutoff), dim=1
    )
    candidate_idxs = candidate_idxs[candidate_idxs[:, 0] != candidate_idxs[:, 1]]

    # uniform draw over the surviving candidates
    pick: int = random.randint(0, candidate_idxs.shape[0] - 1)
    chosen: tuple[int, int] = tuple(candidate_idxs[pick].tolist())  # type: ignore[assignment]
    return MergePair(chosen)


def mcmc_sampler(
    costs: ClusterCoactivationShaped,
    temperature: float = 1.0,
    **kwargs: Any,
) -> MergePair:
    """Sample a merge pair with probability proportional to exp(-cost/temperature).

    Args:
        costs: Cost matrix for all possible merges
        temperature: Softmax temperature (higher = closer to uniform)

    Returns:
        Tuple of (group_i, group_j) indices to merge
    """
    assert not kwargs
    n_groups: int = costs.shape[0]
    assert costs.shape[1] == n_groups, "Cost matrix must be square"

    # only off-diagonal pairs are valid merges
    off_diag: Bool[Tensor, "k_groups k_groups"] = ~torch.eye(
        n_groups, dtype=torch.bool, device=costs.device
    )

    # diagonal -> +inf so its exp weight is exactly zero
    shifted: ClusterCoactivationShaped = costs.clone()
    shifted[~off_diag] = float("inf")

    # subtract the minimum before exponentiating for numerical stability
    baseline: float = float(shifted[off_diag].min())
    weights: ClusterCoactivationShaped = (
        torch.exp((baseline - shifted) / temperature) * off_diag
    )  # mask re-zeroes the diagonal
    flat_weights: Float[Tensor, " k_groups_squared"] = weights.flatten()
    flat_weights = flat_weights / flat_weights.sum()

    # one categorical draw over the flattened matrix
    flat_idx: int = int(torch.multinomial(flat_weights, 1).item())
    return MergePair((flat_idx // n_groups, flat_idx % n_groups))


MERGE_PAIR_SAMPLERS: dict[MergePairSamplerKey, MergePairSamplerConfigurable] = {
    "range": range_sampler,
    "mcmc": mcmc_sampler,
}
distances. + + The strictly lower-triangular entries are filled with distances; + the diagonal and upper triangle are left as `np.nan`. + + # Parameters: + - `X : Int[np.ndarray, "n_ens n_components"]` + Matrix where each of the `n_ens` rows is a label vector of length `n_components`. + + # Returns: + - `Float[np.ndarray, "n_ens n_ens"]` + Distance matrix `D` with `D[i, j]` defined only for `i > j`; + all other positions are `np.nan`. + + # Usage: + ```python + >>> X = np.array([[0, 0, 1], + ... [1, 1, 0], + ... [0, 1, 0]]) + >>> D = perm_invariant_hamming_matrix(X) + >>> D + array([[nan, nan, nan], + [ 0., nan, nan], + [ 2., 2., nan]]) + ``` + """ + n_ens: int + n_components: int + n_ens, n_components = X.shape + D: Float[np.ndarray, "n_ens n_ens"] = np.full((n_ens, n_ens), np.nan, dtype=float) + + # Pre-compute max label in each row once. + row_max: Int[np.ndarray, " n_ens"] = X.max(axis=1) + + for i in range(1, n_ens): + a: Int[np.ndarray, " n_components"] = X[i] + for j in range(i): + b: Int[np.ndarray, " n_components"] = X[j] + + k_lbls: int = int(max(row_max[i], row_max[j]) + 1) + + # Handle case where all labels are -1 (no valid clustering) + if k_lbls <= 0: + warnings.warn( + f"All labels are -1 at rows {i} and {j}. 
Setting distance to 0.", + UserWarning, + stacklevel=2, + ) + D[i, j] = 0.0 + continue + + C: Int[np.ndarray, "k_lbls k_lbls"] = np.zeros((k_lbls, k_lbls), dtype=int) + np.add.at(C, (a, b), 1) + + row_ind, col_ind = linear_sum_assignment(-C) + matches: int = int(C[row_ind, col_ind].sum()) + + D[i, j] = n_components - matches # int is fine; array is float because of NaN + + return D diff --git a/spd/clustering/math/semilog.py b/spd/clustering/math/semilog.py new file mode 100644 index 000000000..a17ba63b5 --- /dev/null +++ b/spd/clustering/math/semilog.py @@ -0,0 +1,13 @@ +import math + + +def semilog( + value: float, + epsilon: float = 1e-3, +) -> float: + if abs(value) < epsilon: + return value + else: + sign: int = 1 if value >= 0 else -1 + # log10 here is safe, since we know the value is not close to zero + return sign * epsilon * math.log1p(abs(value) / epsilon) diff --git a/spd/clustering/math/tensor_stats.py b/spd/clustering/math/tensor_stats.py new file mode 100644 index 000000000..4080b9795 --- /dev/null +++ b/spd/clustering/math/tensor_stats.py @@ -0,0 +1,160 @@ +from typing import Literal + +import torch +from jaxtyping import Float +from torch import Tensor + +StatsKey = Literal[ + "mean", + "std", + "median", + "min", + "max", + "q01", + "q05", + "q10", + "q25", + "q50", + "q75", + "q90", + "q95", + "q99", + "chosen_pair", +] + + +def _flatten_if_needed(x: Tensor) -> Tensor: + """Make x 1D without copy when possible.""" + x_flat: Tensor = x.reshape(-1) + return x_flat + + +def _approx_quantile( + x: Tensor, + qs: Float[Tensor, " n_quantiles"], + *, + max_elems: int = 5_000_000, + generator: torch.Generator | None = None, +) -> Float[Tensor, " n_quantiles"]: + """Approximate quantiles by subsampling if needed, else exact. + + If x.numel() > max_elems, draws a random subset of size max_elems (with replacement) + on the same device as x, then computes torch.quantile once for all qs. 
+ """ + x1d: Tensor = _flatten_if_needed(x) + n: int = x1d.numel() + if n == 0: + raise ValueError("Empty tensor.") + if n > max_elems: + # sample with replacement to avoid materializing a giant permutation + g: torch.Generator | None = generator + idx: Tensor = torch.randint(0, n, (max_elems,), device=x1d.device, generator=g) + x_used: Tensor = x1d[idx] + else: + x_used = x1d + # Compute all quantiles in one shot to reuse the sort + q: Tensor = torch.quantile(x_used, qs, interpolation="linear") + return q + + +def _exact_quantile_all_at_once( + x: Tensor, qs: Float[Tensor, " n_quantiles"] +) -> Float[Tensor, " n_quantiles"]: + """Exact quantiles without repeated sorts.""" + x1d: Tensor = _flatten_if_needed(x) + q: Float[Tensor, " n_quantiles"] = torch.quantile(x1d, qs, interpolation="linear") + return q + + +def stats_dict( + data: Tensor, + *, + approx_if_large: bool = True, + max_elems_for_quantile: int = 5_000_000, + rng: torch.Generator | None = None, +) -> dict[StatsKey, float]: + """summary + + Compute common stats plus a set of quantiles. Uses a single quantile() call + for all requested quantiles; optionally switches to an approximate method + by subsampling when the input is very large to avoid RuntimeError. + + # Parameters: + - `data : Tensor` + Input tensor of any shape and dtype convertible to floating for stats. + - `approx_if_large : bool` + If True, use subsampling for quantiles when data is huge. (defaults to True) + - `max_elems_for_quantile : int` + Max elements before triggering approximate mode. (defaults to 5_000_000) + - `rng : torch.Generator | None` + Optional torch generator for reproducible subsampling. + + # Returns: + - `dict[StatsKey, float]` + Mapping from stat name to Python float. 
+ + # Modifies: + - None + + # Usage: + + ```python + >>> x = torch.randn(50_000_000, device="cuda") + >>> out = stats_dict(x, approx_if_large=True, max_elems_for_quantile=5_000_000) + >>> out["q95"] + 1.64 + ``` + + # Raises: + - `ValueError` : if `data` is empty + """ + x: Tensor = data + if x.numel() == 0: + raise ValueError("Empty tensor.") + # Work in float for numerics, but keep device + xf: Tensor = x.float() + + # Fast exact ops that do not need the full sort + # std_mean does mean and std in one pass; aminmax does min and max together + std: Tensor + mean: Tensor + std, mean = torch.std_mean(xf) + mn: Tensor + mx: Tensor + mn, mx = torch.aminmax(xf) + + # median is a quantile; we can either reuse below or do .median() directly. + # We will get it from the quantiles call to avoid extra work. + q_values: Float[Tensor, " 9"] = torch.tensor( + [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99], + device=xf.device, + dtype=xf.dtype, + ) + qs_all: Float[Tensor, " 9"] + if approx_if_large: + qs_all = _approx_quantile( + xf, + q_values, + max_elems=max_elems_for_quantile, + generator=rng, + ) + else: + qs_all = _exact_quantile_all_at_once(xf, q_values) + + out: dict[StatsKey, float] = { + "mean": float(mean.item()), + "std": float(std.item()), + "median": float(qs_all[4].item()), # median is at index 4 + "min": float(mn.item()), + "max": float(mx.item()), + "q01": float(qs_all[0].item()), + "q05": float(qs_all[1].item()), + "q10": float(qs_all[2].item()), + "q25": float(qs_all[3].item()), + "q50": float(qs_all[4].item()), # median again + "q75": float(qs_all[5].item()), + "q90": float(qs_all[6].item()), + "q95": float(qs_all[7].item()), + "q99": float(qs_all[8].item()), + } + return out diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py new file mode 100644 index 000000000..3692e1687 --- /dev/null +++ b/spd/clustering/merge.py @@ -0,0 +1,246 @@ +""" +Merge iteration with logging support. 
class LogCallback(Protocol):
    # callback invoked once per merge iteration with the full logging payload
    def __call__(
        self,
        current_coact: ClusterCoactivationShaped,
        component_labels: ComponentLabels,
        current_merge: GroupMerge,
        costs: ClusterCoactivationShaped,
        merge_history: MergeHistory,
        iter_idx: int,
        k_groups: int,
        merge_pair_cost: float,
        mdl_loss: float,
        mdl_loss_norm: float,
        diag_acts: Float[Tensor, " k_groups"],
    ) -> None: ...


def merge_iteration(
    merge_config: MergeConfig,
    activations: ActivationsTensor,
    component_labels: ComponentLabels,
    log_callback: LogCallback | None = None,
    batch_id: str = "unk",
) -> MergeHistory:
    """
    Merge iteration with optional logging/plotting callbacks.

    This wraps the pure computation with logging capabilities while maintaining
    the same core algorithm logic: greedily merge component groups based on a
    coactivation-derived cost matrix, optionally "popping" components out of
    their groups at random, recording each step in a MergeHistory.

    # Parameters:
    - `merge_config : MergeConfig`
        algorithm configuration (threshold, alpha, sampler, pop probability, iters)
    - `activations : ActivationsTensor`
        assumed shape (samples, n_components) — TODO confirm against callers
    - `component_labels : ComponentLabels`
        per-component labels, passed through to the history and callback
    - `log_callback : LogCallback | None`
        if given, called once per iteration with all metrics
    - `batch_id : str`
        tag shown (colorized) in the progress-bar prefix

    # Returns:
    - `MergeHistory`
        the recorded sequence of merges and selected pairs
    """

    # setup
    # ==================================================
    pbar_prefix: str = _BATCH_PREFIX_FMT.format(batch_id=batch_id)

    # compute coactivations
    # --------------------------------------------------
    # either a boolean active/inactive mask (thresholded) or the raw importances
    activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = (
        activations > merge_config.activation_threshold
        if merge_config.activation_threshold is not None
        else activations
    )
    coact: Float[Tensor, "c c"] = activation_mask_orig.float().T @ activation_mask_orig.float()

    # check shapes
    c_components: int = coact.shape[0]
    assert coact.shape[1] == c_components, "Coactivation matrix must be square"

    # determine number of iterations based on config and number of components
    num_iters: int = merge_config.get_num_iters(c_components)

    # pop logic setup
    # --------------------------------------------------
    # for speed, we precompute whether to pop components and which components to pop
    # if we are not popping, we don't need these variables and can also delete other things
    do_pop: bool = merge_config.pop_component_prob > 0.0
    if do_pop:
        # at each iteration, we will pop a component with probability `pop_component_prob`
        iter_pop: Bool[Tensor, " iters"] = (
            torch.rand(num_iters, device=coact.device) < merge_config.pop_component_prob
        )
        # we pick a subcomponent at random, and if we decide to pop, we pop that one out of its group
        # if the component is a singleton, nothing happens. this naturally biases towards popping
        # less at the start and more at the end, since the effective probability of popping a component
        # is actually something like `pop_component_prob * (c_components - k_groups) / c_components`
        pop_component_idx: Int[Tensor, " iters"] = torch.randint(
            0, c_components, (num_iters,), device=coact.device
        )

    # initialize vars
    # --------------------------------------------------
    # start with an identity merge
    current_merge: GroupMerge = GroupMerge.identity(n_components=c_components)

    # initialize variables for the merge process
    k_groups: int = c_components
    current_coact: ClusterCoactivationShaped = coact.clone()
    current_act_mask: Bool[Tensor, "samples k_groups"] = activation_mask_orig.clone()

    # variables we keep track of
    merge_history: MergeHistory = MergeHistory.from_config(
        merge_config=merge_config,
        labels=component_labels,
    )

    # free up memory
    # (the originals are only needed again when popping groups apart)
    if not do_pop:
        del coact
        del activation_mask_orig
        activation_mask_orig = None

    # merge iteration
    # ==================================================
    pbar: tqdm[int] = tqdm(
        range(num_iters),
        unit="iter",
        total=num_iters,
    )
    for iter_idx in pbar:
        # pop components
        # --------------------------------------------------
        if do_pop and iter_pop[iter_idx]:  # pyright: ignore[reportPossiblyUnboundVariable]
            # we split up the group which our chosen component belongs to
            pop_component_idx_i: int = int(pop_component_idx[iter_idx].item())  # pyright: ignore[reportPossiblyUnboundVariable]
            n_components_in_pop_grp: int = int(
                current_merge.components_per_group[  # pyright: ignore[reportArgumentType]
                    current_merge.group_idxs[pop_component_idx_i].item()
                ]
            )

            # but, if the component is the only one in its group, there is nothing to do
            if n_components_in_pop_grp > 1:
                current_merge, current_coact, current_act_mask = recompute_coacts_pop_group(
                    coact=current_coact,
                    merges=current_merge,
                    component_idx=pop_component_idx_i,
                    activation_mask=current_act_mask,
                    # this complains if `activation_mask_orig is None`, but this is only the case
                    # if `do_pop` is False, which it won't be here. we do this to save memory
                    activation_mask_orig=activation_mask_orig,  # pyright: ignore[reportArgumentType]
                )
                # popping split a group, so the group count grew — resync it
                k_groups = current_coact.shape[0]

        # compute costs, figure out what to merge
        # --------------------------------------------------
        # HACK: this is messy
        # normalize by sample count so costs are per-sample quantities
        costs: ClusterCoactivationShaped = compute_merge_costs(
            coact=current_coact / current_act_mask.shape[0],
            merges=current_merge,
            alpha=merge_config.alpha,
        )

        merge_pair: MergePair = merge_config.merge_pair_sample(costs)

        # merge the pair
        # --------------------------------------------------
        # we do this *before* logging, so we can see how the sampled pair cost compares
        # to the costs of all the other possible pairs
        current_merge, current_coact, current_act_mask = recompute_coacts_merge_pair(
            coact=current_coact,
            merges=current_merge,
            merge_pair=merge_pair,
            activation_mask=current_act_mask,
        )

        # metrics and logging
        # --------------------------------------------------
        # Store in history
        merge_history.add_iteration(
            idx=iter_idx,
            selected_pair=merge_pair,
            current_merge=current_merge,
        )

        # Compute metrics for logging
        # the MDL loss computed here is the *cost of the current merge*, a single scalar value
        # rather than the *delta in cost from merging a specific pair* (which is what `costs` matrix contains)
        diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact)
        mdl_loss: float = compute_mdl_cost(
            acts=diag_acts,
            merges=current_merge,
            alpha=merge_config.alpha,
        )
        mdl_loss_norm: float = mdl_loss / current_act_mask.shape[0]
        # this is the cost for the selected pair
        merge_pair_cost: float = float(costs[merge_pair].item())

        # Update progress bar
        pbar.set_description(
            f"{pbar_prefix} k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}"
        )

        if log_callback is not None:
            log_callback(
                iter_idx=iter_idx,
                current_coact=current_coact,
                component_labels=component_labels,
                current_merge=current_merge,
                costs=costs,
                merge_history=merge_history,
                k_groups=k_groups,
                merge_pair_cost=merge_pair_cost,
                mdl_loss=mdl_loss,
                mdl_loss_norm=mdl_loss_norm,
                diag_acts=diag_acts,
            )

        # iterate and sanity checks
        # --------------------------------------------------
        k_groups -= 1
        assert current_coact.shape[0] == k_groups, (
            "Coactivation matrix shape should match number of groups"
        )
        assert current_coact.shape[1] == k_groups, (
            "Coactivation matrix shape should match number of groups"
        )
        assert current_act_mask.shape[1] == k_groups, (
            "Activation mask shape should match number of groups"
        )

        # early stopping failsafe
        # --------------------------------------------------
        if k_groups <= 3:
            warnings.warn(
                f"Stopping early at iteration {iter_idx} as only {k_groups} groups left",
                stacklevel=2,
            )
            break

    # finish up
    # ==================================================
    return merge_history
import functools
import hashlib
from typing import Any, Literal

from pydantic import (
    BaseModel,
    Field,
    PositiveInt,
)

from spd.clustering.consts import ClusterCoactivationShaped, MergePair
from spd.clustering.math.merge_pair_samplers import (
    MERGE_PAIR_SAMPLERS,
    MergePairSampler,
    MergePairSamplerKey,
)
from spd.clustering.util import ModuleFilterFunc, ModuleFilterSource
from spd.spd_types import Probability

# the set of config field names, usable as a key type elsewhere
MergeConfigKey = Literal[
    "activation_threshold",
    "alpha",
    "iters",
    "merge_pair_sampling_method",
    "merge_pair_sampling_kwargs",
    "pop_component_prob",
    "filter_dead_threshold",
]


def _to_module_filter(
    filter_modules: ModuleFilterSource,
) -> ModuleFilterFunc:
    """Convert the filter_modules argument to a callable.

    Accepted forms: `None` (accept everything), a string (prefix match),
    a set (exact-name membership), or an arbitrary predicate callable.

    # Raises:
    - `TypeError` : for any other input type
    """
    if filter_modules is None:
        return lambda _: True
    elif isinstance(filter_modules, str):
        return lambda module_name: module_name.startswith(filter_modules)
    elif isinstance(filter_modules, set):
        return lambda module_name: module_name in filter_modules
    elif callable(filter_modules):
        return filter_modules
    else:
        raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}")  # pyright: ignore[reportUnreachable]


class MergeConfig(BaseModel):
    # Configuration for one merge run: thresholds, cost weighting, sampler
    # choice, pop probability, and module filtering.
    activation_threshold: Probability | None = Field(
        default=0.01,
        description="Threshold for considering a component active in a group. If None, use raw scalar causal importances",
    )
    alpha: float = Field(
        default=1.0,
        description="rank weight factor. Higher values mean a higher penalty on 'sending' the component weights",
    )
    iters: PositiveInt | None = Field(
        default=100,
        description="max number of iterations to run the merge algorithm for. If `None`, set to number of components (after filtering) minus one.",
    )
    merge_pair_sampling_method: MergePairSamplerKey = Field(
        default="range",
        description="Method for sampling merge pairs. Options: 'range', 'mcmc'.",
    )
    merge_pair_sampling_kwargs: dict[str, Any] = Field(
        default_factory=lambda: {"threshold": 0.05},
        description="Keyword arguments for the merge pair sampling method.",
    )
    pop_component_prob: Probability = Field(
        default=0,
        description="Probability of popping a component in each iteration. If 0, no components are popped.",
    )
    filter_dead_threshold: float = Field(
        default=0.001,
        description="Threshold for filtering out dead components. If a component's activation is below this threshold, it is considered dead and not included in the merge.",
    )
    module_name_filter: ModuleFilterSource = Field(
        default=None,
        description="Filter for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.",
    )

    @property
    def merge_pair_sample_func(self) -> MergePairSampler:
        # bind the configured kwargs onto the selected sampler once
        return functools.partial(
            MERGE_PAIR_SAMPLERS[self.merge_pair_sampling_method],
            **self.merge_pair_sampling_kwargs,
        )

    def merge_pair_sample(
        self,
        costs: ClusterCoactivationShaped,
    ) -> MergePair:
        """do merge sampling based on the configured method and kwargs

        has signature `MergePairSampler = Callable[[ClusterCoactivationShaped], MergePair]`
        """
        return self.merge_pair_sample_func(costs=costs)

    @property
    def filter_modules(self) -> ModuleFilterFunc:
        """Get the module filter function based on the provided source."""
        return _to_module_filter(self.module_name_filter)

    def get_num_iters(self, n_components: int) -> PositiveInt:
        """Get the number of iterations to run the merge algorithm for.

        Args:
            n_components: Number of components (after filtering)

        Returns:
            Number of iterations to run (n_components - 1 when `iters` is None)
        """
        if self.iters is None:
            return n_components - 1
        else:
            return self.iters

    @property
    def stable_hash(self) -> str:
        # short content hash of the serialized config; used to tag runs.
        # NOTE(review): md5 here is a fingerprint, not a security measure.
        return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6]

import io
import json
import zipfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, override

import numpy as np
import torch
from jaxtyping import Float, Int
from muutils.dbg import dbg_tensor

from spd.clustering.consts import (
    ComponentLabels,
    DistancesArray,
    DistancesMethod,
    MergePair,
    MergesArray,
    SaveableObject,
)
from spd.clustering.math.merge_distances import compute_distances
from spd.clustering.math.merge_matrix import BatchedGroupMerge, GroupMerge
from spd.clustering.merge_config import MergeConfig
@dataclass(frozen=True)
class IterationInfo:
    """Information about a single merge iteration."""

    idx: int
    selected_pair: list[int]
    merges: GroupMerge


def _zip_save_arr(zf: zipfile.ZipFile, name: str, arr: np.ndarray) -> None:
    """Save a numpy array to a zip file."""
    buf: io.BytesIO = io.BytesIO()
    np.save(buf, arr)
    zf.writestr(name, buf.getvalue())


def _zip_save_arr_dict(zf: zipfile.ZipFile, data: dict[str, np.ndarray]) -> None:
    """Save a dictionary of numpy arrays to a zip file, {key}.npy used as path"""
    key: str
    arr: np.ndarray
    for key, arr in data.items():
        _zip_save_arr(zf, f"{key}.npy", arr)


@dataclass(kw_only=True)
class MergeHistory(SaveableObject):
    """Track merge iteration history"""

    merges: BatchedGroupMerge
    selected_pairs: Int[np.ndarray, " n_iters 2"]
    labels: ComponentLabels
    merge_config: MergeConfig
    n_iters_current: int

    meta: dict[str, Any] | None = None

    @property
    def c_components(self) -> int:
        # one label per component, so label count == component count
        return len(self.labels)

    @classmethod
    def from_config(
        cls,
        merge_config: MergeConfig,
        labels: ComponentLabels,
    ) -> "MergeHistory":
        """Allocate an empty history sized for the configured iteration count."""
        n_components: int = len(labels)
        n_iters_target: int = merge_config.get_num_iters(n_components)
        return MergeHistory(
            labels=labels,
            n_iters_current=0,
            # -1 marks "not yet recorded" slots
            selected_pairs=np.full((n_iters_target, 2), -1, dtype=np.int16),
            merges=BatchedGroupMerge.init_empty(
                batch_size=n_iters_target, n_components=n_components
            ),
            merge_config=merge_config,
        )

    def summary(self) -> dict[str, str | int | None | dict[str, int | str | None]]:
        """Compact summary dict for logging / __str__."""
        return dict(
            c_components=self.c_components,
            n_iters_current=self.n_iters_current,
            total_iters=len(self.merges.k_groups),
            len_labels=len(self.labels),
            # wandb_url=self.wandb_url,
            merge_config=self.merge_config.model_dump(mode="json"),
            merges_summary=self.merges.summary(),
        )

    @override
    def __str__(self) -> str:
        out: list[str] = [f"  {key} = {value}" for key, value in self.summary().items()]
        return "MergeHistory(\n" + "\n".join(out) + "\n)"

    @override
    def __repr__(self) -> str:
        return self.__str__()

    def add_iteration(
        self,
        idx: int,
        selected_pair: MergePair,
        current_merge: GroupMerge,
    ) -> None:
        """Add data for one iteration.

        Iterations must be appended strictly in order (idx == n_iters_current).
        """
        self.selected_pairs[idx] = np.array(selected_pair, dtype=np.int16)
        self.merges[idx] = current_merge

        assert self.n_iters_current == idx
        self.n_iters_current += 1

    def __getitem__(self, idx: int) -> IterationInfo:
        """Get data for a specific iteration.

        Raises:
            IndexError: if idx is negative or >= n_iters_current
        """
        if idx < 0 or idx >= self.n_iters_current:
            raise IndexError(
                f"Index {idx} out of range for history with {self.n_iters_current} iterations"
            )

        return IterationInfo(
            idx=idx,
            selected_pair=self.selected_pairs[idx].tolist(),
            merges=self.merges[idx],
        )

    def __len__(self) -> int:
        """Get the number of iterations in the history."""
        return self.n_iters_current

    def latest(self) -> IterationInfo:
        """Get the latest values.

        Raises:
            ValueError: if no iterations have been recorded yet
        """
        if self.n_iters_current == 0:
            raise ValueError("No history available")
        latest_idx: int = self.n_iters_current - 1
        return self[latest_idx]

    def get_unique_clusters(self, iteration: int) -> list[int]:
        """Get unique cluster IDs at a given iteration.

        Args:
            iteration: Iteration index (negative indexes from end)

        Returns:
            List of unique cluster IDs
        """
        if iteration < 0:
            iteration = self.n_iters_current + iteration
        assert 0 <= iteration < self.n_iters_current, (
            f"Invalid iteration: {iteration = }, {self.n_iters_current = }"
        )
        merge: GroupMerge = self.merges[iteration]
        return torch.unique(merge.group_idxs).tolist()

    def get_cluster_component_labels(self, iteration: int, cluster_id: int) -> ComponentLabels:
        """Get component labels for a specific cluster at a given iteration.

        Args:
            iteration: Iteration index (negative indexes from end)
            cluster_id: Cluster ID to query

        Returns:
            List of component labels in the cluster
        """
        if iteration < 0:
            iteration = self.n_iters_current + iteration
        assert 0 <= iteration < self.n_iters_current, (
            f"Invalid iteration: {iteration = }, {self.n_iters_current = }"
        )
        merge: GroupMerge = self.merges[iteration]
        component_indices: list[int] = merge.components_in_group(cluster_id)
        return ComponentLabels([self.labels[idx] for idx in component_indices])

    def get_cluster_components_info(self, iteration: int, cluster_id: int) -> list[dict[str, Any]]:
        """Get detailed component information for a cluster.

        Args:
            iteration: Iteration index (negative indexes from end)
            cluster_id: Cluster ID to query

        Returns:
            List of dicts with keys: module, index, label
        """
        component_labels: list[str] = self.get_cluster_component_labels(iteration, cluster_id)
        result: list[dict[str, Any]] = []
        for label in component_labels:
            module: str
            idx_str: str
            # labels are "<module>:<index>"; rsplit tolerates ':' in module names
            module, idx_str = label.rsplit(":", 1)
            result.append({"module": module, "index": int(idx_str), "label": label})
        return result

    # Convenience properties for sweep analysis
    @property
    def total_iterations(self) -> int:
        """Total number of iterations performed."""
        return self.n_iters_current

    @property
    def final_k_groups(self) -> int:
        """Final number of groups after merging."""
        if self.n_iters_current == 0:
            return self.c_components
        return int(self.merges.k_groups[self.n_iters_current - 1].item())

    @property
    def initial_k_groups(self) -> int:
        """Initial number of groups before merging."""
        if self.n_iters_current == 0:
            return self.c_components
        return int(self.merges.k_groups[0].item())

    @override
    def save(self, path: Path, wandb_url: str | None = None) -> None:
        """Serialize history to a zip: arrays as .npy, labels as text, config as JSON."""
        zf: zipfile.ZipFile
        with zipfile.ZipFile(path, "w") as zf:
            # save arrays
            _zip_save_arr_dict(
                zf=zf,
                data={
                    "merge.group_idxs": self.merges.group_idxs.cpu().numpy(),
                    "merge.k_groups": self.merges.k_groups.cpu().numpy(),
                    "selected_pairs": self.selected_pairs,
                },
            )
            # Save labels
            zf.writestr("labels.txt", "\n".join(self.labels))
            # Save metadata
            zf.writestr(
                "metadata.json",
                json.dumps(
                    dict(
                        merge_config=self.merge_config.model_dump(mode="json"),
                        wandb_url=wandb_url,
                        c_components=self.c_components,
                        n_iters_current=self.n_iters_current,
                        labels=self.labels,
                    )
                ),
            )

    @override
    @classmethod
    def read(cls, path: Path) -> "MergeHistory":
        """Inverse of `save`: reconstruct a MergeHistory from a zip file."""
        zf: zipfile.ZipFile
        with zipfile.ZipFile(path, "r") as zf:
            group_idxs: np.ndarray = np.load(io.BytesIO(zf.read("merge.group_idxs.npy")))
            k_groups: np.ndarray = np.load(io.BytesIO(zf.read("merge.k_groups.npy")))
            selected_pairs: np.ndarray = np.load(io.BytesIO(zf.read("selected_pairs.npy")))
            merges: BatchedGroupMerge = BatchedGroupMerge(
                group_idxs=torch.from_numpy(group_idxs),
                k_groups=torch.from_numpy(k_groups),
            )
            labels_raw: list[str] = zf.read("labels.txt").decode("utf-8").splitlines()
            labels: ComponentLabels = ComponentLabels(labels_raw)
            metadata: dict[str, Any] = json.loads(zf.read("metadata.json").decode("utf-8"))
            merge_config: MergeConfig = MergeConfig.model_validate(metadata["merge_config"])

            # remember where this history came from (Path; stringified later
            # in `MergeHistoryEnsemble.normalized` before JSON serialization)
            metadata["origin_path"] = path

            return cls(
                merges=merges,
                selected_pairs=selected_pairs,
                labels=labels,
                merge_config=merge_config,
                n_iters_current=metadata["n_iters_current"],
                meta=metadata,
            )


@dataclass
class MergeHistoryEnsemble:
    """A collection of MergeHistory runs over the same config (different batches)."""

    data: list[MergeHistory]

    def __iter__(self):
        return iter(self.data)

    def __getitem__(self, idx: int) -> MergeHistory:
        return self.data[idx]

    def _validate_configs_match(self) -> None:
        """Ensure all histories have the same merge config."""
        if not self.data:
            return
        first_config: MergeConfig = self.data[0].merge_config
        for history in self.data[1:]:
            if history.merge_config != first_config:
                raise ValueError("All histories must have the same merge config")

    @property
    def config(self) -> MergeConfig:
        """Get the merge config used in the ensemble."""
        self._validate_configs_match()
        return self.data[0].merge_config

    @property
    def n_iters_min(self) -> int:
        """Minimum number of iterations across all histories in the ensemble."""
        return min(len(history.merges.k_groups) for history in self.data)

    @property
    def n_iters_max(self) -> int:
        """Maximum number of iterations across all histories in the ensemble."""
        return max(len(history.merges.k_groups) for history in self.data)

    @property
    def n_iters_range(self) -> tuple[int, int]:
        """Range of iterations (min, max) across all histories in the ensemble."""
        iter_counts = [len(history.merges.k_groups) for history in self.data]
        return (min(iter_counts), max(iter_counts))

    @property
    def n_ensemble(self) -> int:
        """Number of ensemble members."""
        return len(self.data)

    @property
    def c_components(self) -> int:
        """Number of components in each history."""
        c_components: int = self.data[0].c_components
        assert all(history.c_components == c_components for history in self.data), (
            "All histories must have the same number of components"
        )
        return c_components

    @property
    def shape(self) -> tuple[int, int, int]:
        """Shape of the ensemble data."""
        return (self.n_ensemble, self.n_iters_min, self.c_components)

    @property
    def merges_array(self) -> MergesArray:
        """Stack per-history group assignments into an (n_ens, n_iters_min, c) array."""
        n_ens: int = self.n_ensemble
        n_iters: int = self.n_iters_min
        c_components: int = self.c_components

        output: MergesArray = np.full(
            (n_ens, n_iters, c_components),
            fill_value=-1,
            dtype=np.int16,
            # if you have more than 32k components, change this to np.int32
            # if you have more than 2.1b components, rethink your life choices
        )
        for i_ens, history in enumerate(self.data):
            # BUGFIX: only take the first `n_iters` (= n_iters_min) iterations.
            # Histories may differ in length (that is why n_iters_min/n_iters_max
            # exist, and `normalized()` slices with `[: self.n_iters_min]`);
            # iterating a longer history unbounded would IndexError on `output`.
            for i_iter, merge in enumerate(history.merges):
                if i_iter >= n_iters:
                    break
                output[i_ens, i_iter] = merge.group_idxs

        return output

    def normalized(self) -> tuple[MergesArray, dict[str, Any]]:
        """Normalize the component labels across all histories.

        if different histories see different batches, then they might have different dead
        components, and are hence not directly comparable. So, we find the union of all
        component labels across all histories, and then any component missing from a history
        is put into its own group in that history
        """

        unique_labels_set: set[str] = set()
        for history in self.data:
            unique_labels_set.update(history.labels)

        unique_labels_list: list[str] = sorted(unique_labels_set)
        unique_labels: ComponentLabels = ComponentLabels(unique_labels_list)
        c_components: int = len(unique_labels)
        component_label_idxs: dict[str, int] = {
            label: idx for idx, label in enumerate(unique_labels)
        }

        try:
            merges_array: MergesArray = np.full(
                (self.n_ensemble, self.n_iters_min, c_components),
                fill_value=-1,
                dtype=np.int16,
            )
        except Exception as e:
            err_msg = (
                f"failed to create merge array, probably due to issues with getting shape.\n"
                f"{self = }\n"
                f"{self.data = }\n"
            )
            raise RuntimeError(err_msg) from e

        # fraction of the union label set each history actually saw
        overlap_stats: Float[np.ndarray, " n_ens"] = np.full(
            self.n_ensemble,
            fill_value=float("nan"),
            dtype=np.float32,
        )
        i_ens: int
        history: MergeHistory
        for i_ens, history in enumerate(self.data):
            hist_c_labels: list[str] = history.labels
            hist_n_components: int = len(hist_c_labels)
            overlap_stats[i_ens] = hist_n_components / c_components
            # map from old component indices to new component indices
            i_comp_old: int
            comp_label: str
            for i_comp_old, comp_label in enumerate(hist_c_labels):
                i_comp_new: int = component_label_idxs[comp_label]
                merges_array[i_ens, :, i_comp_new] = history.merges.group_idxs[
                    : self.n_iters_min, i_comp_old
                ]

            # assert np.max(merges_array[i_ens]) == hist_n_components - 1, (
            #     f"Max component index in history {i_ens} should be {hist_n_components - 1}, "
            #     f"but got {np.max(merges_array[i_ens])}"
            # )

            # put each missing label into its own group
            hist_missing_labels: set[str] = unique_labels_set - set(hist_c_labels)
            assert len(hist_missing_labels) == c_components - hist_n_components
            idx_missing: int
            missing_label: str
            for idx_missing, missing_label in enumerate(hist_missing_labels):
                i_comp_new_relabel: int = component_label_idxs[missing_label]
                # fresh singleton group ids start just past the history's own ids
                merges_array[i_ens, :, i_comp_new_relabel] = np.full(
                    self.n_iters_min,
                    fill_value=idx_missing + hist_n_components,
                    dtype=np.int16,
                )

        # TODO: Consider logging overlap_stats to WandB if run is available
        # For now, keep using dbg_tensor for overlap_stats analysis
        dbg_tensor(overlap_stats)

        # TODO: double check this
        # Convert any Path objects to strings for JSON serialization
        history_metadatas: list[dict[str, Any] | None] = []
        for history in self.data:
            if history.meta is not None:
                meta_copy = history.meta.copy()
                # Convert Path objects to strings
                for key, value in meta_copy.items():
                    if isinstance(value, Path):
                        meta_copy[key] = str(value)
                history_metadatas.append(meta_copy)
            else:
                history_metadatas.append(None)

        return (
            # TODO: dataclass this
            merges_array,
            dict(
                component_labels=unique_labels,
                n_ensemble=self.n_ensemble,
                n_iters_min=self.n_iters_min,
                n_iters_max=self.n_iters_max,
                n_iters_range=self.n_iters_range,
                c_components=c_components,
                config=self.config.model_dump(mode="json"),
                history_metadatas=history_metadatas,
            ),
        )

    def get_distances(self, method: DistancesMethod = "perm_invariant_hamming") -> DistancesArray:
        """Compute pairwise distances between ensemble members' clusterings."""
        merges_array: MergesArray = self.merges_array
        return compute_distances(
            normalized_merge_array=merges_array,
            method=method,
        )
# --- spd/clustering/merge_run_config.py (new file) ---

"""Configuration for merge clustering runs that combines merge config with run parameters."""

import hashlib
import json
import tomllib
import warnings
from pathlib import Path
from typing import Any, Literal, Self

import yaml
from muutils.misc.numerical import shorten_numerical_to_str
from pydantic import BaseModel, Field, PositiveInt, model_validator

from spd.clustering.consts import DistancesMethod
from spd.clustering.merge_config import MergeConfig
from spd.registry import EXPERIMENT_REGISTRY, ExperimentConfig
from spd.spd_types import TaskName

# Define interval types and defaults
IntervalKey = Literal["stat", "tensor", "plot", "artifact"]

IntervalsDict = dict[IntervalKey, PositiveInt]
"""Type alias for intervals dictionary

- `stat`: logging statistics (e.g., k_groups, merge_pair_cost, mdl_loss)
- `tensor`: logging tensors (e.g., wandb_log_tensor, fraction calculations)
- `plot`: generating plots
- `artifact`: creating artifacts (checkpoints)

"""

# fallback values merged into user-provided intervals in `validate_intervals`
_DEFAULT_INTERVALS: IntervalsDict = {
    "stat": 1,
    "tensor": 100,
    "plot": 100,
    "artifact": 100,
}


def toml_read_file_with_none(path: Path, null_sentinel: str = "__NULL__") -> dict[str, Any]:
    """Read a TOML file and recursively convert sentinel values to None.

    TOML doesn't support null/None values natively, so we use a sentinel string
    that gets converted to None after parsing.

    Args:
        path: Path to the TOML file
        null_sentinel: String value to be converted to None (default: "__NULL__")

    Returns:
        Dictionary with sentinel values replaced by None
    """

    def replace_sentinel_recursive(obj: Any) -> Any:
        """Recursively replace sentinel values with None."""
        if isinstance(obj, dict):
            return {key: replace_sentinel_recursive(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [replace_sentinel_recursive(item) for item in obj]
        elif isinstance(obj, str) and obj == null_sentinel:
            return None
        else:
            return obj

    with path.open("rb") as f:
        data = tomllib.load(f)
    return replace_sentinel_recursive(data)


class ClusteringRunConfig(BaseModel):
    """Configuration for a complete merge clustering run.

    Extends MergeConfig with parameters for model, dataset, and batch configuration.
    CLI parameters (base_path, devices, workers_per_device) have defaults but will always be overridden
    """

    merge_config: MergeConfig = Field(
        description="Merge configuration",
    )

    model_path: str = Field(
        description="WandB path to the model (format: wandb:entity/project/run_id)",
    )
    task_name: TaskName = Field(
        description="Task name for the model (must be explicit)",
    )
    experiment_key: str | None = Field(
        default=None,
        description="Original experiment key if created from spd_exp registry",
    )
    n_batches: PositiveInt = Field(
        default=10,
        description="Number of batches to split the dataset into (ensemble size)",
    )
    batch_size: PositiveInt = Field(
        default=64,
        description="Size of each batch for processing",
    )
    distances_method: DistancesMethod = Field(
        default="perm_invariant_hamming",
        description="Method to use for computing distances between clusterings",
    )

    # Implementation details
    # note that these are *always* overriden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here
    # to avoid type issues with pydantic. however, these defaults should match the defaults in the CLI args.
    base_path: Path = Field(
        default_factory=lambda: Path(".data/clustering/"),
        description="Base path for saving clustering outputs",
    )
    workers_per_device: int = Field(
        default=1,
        description="Maximum number of concurrent clustering processes per device",
    )
    devices: list[str] = Field(
        default_factory=lambda: ["cpu"],
        description="Devices to use for clustering",
    )

    # WandB configuration
    wandb_enabled: bool = Field(
        default=False,
        description="Enable WandB logging for clustering runs",
    )
    wandb_project: str = Field(
        default="spd-cluster",
        description="WandB project name for clustering runs",
    )
    intervals: dict[IntervalKey, PositiveInt] = Field(
        default_factory=lambda: _DEFAULT_INTERVALS.copy(),
        description="Intervals for different logging operations",
    )

    @model_validator(mode="after")
    def validate_model_path(self) -> Self:
        """Validate that model_path is a proper WandB path."""
        if not self.model_path.startswith("wandb:"):
            raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}")

        # NOTE(review): assert-based validation is stripped under `python -O`;
        # pydantic does convert AssertionError to ValidationError, but a raise
        # would be more robust -- confirm intent.
        assert self.task_name in TaskName.__args__, (
            f"Invalid task_name: {self.task_name = }, must be in {TaskName.__args__ = }"
        )
        return self

    @model_validator(mode="after")
    def validate_intervals(self) -> Self:
        """Ensure all required interval keys are present."""
        # warning if any keys are missing
        missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(self.intervals.keys())
        if missing_keys:
            warnings.warn(
                f"Missing interval keys in {self.intervals = }: {missing_keys}. Using defaults for those.",
                UserWarning,
                stacklevel=1,
            )

        # defaults first so user-provided values win
        self.intervals = {
            **_DEFAULT_INTERVALS,
            **self.intervals,
        }

        return self

    @property
    def wandb_decomp_model(self) -> str:
        """Extract the WandB run ID of the source decomposition from the model_path

        Format: wandb:entity/project/run_id or wandb:entity/project/runs/run_id
        """
        parts: list[str] = self.model_path.replace("wandb:", "").split("/")
        if len(parts) >= 3:
            # Handle both formats: with and without 'runs' in path
            return parts[-1] if parts[-1] != "runs" else parts[-2] if len(parts) > 3 else parts[-1]
        else:
            raise ValueError(f"Invalid wandb path format: {self.model_path}")

    @property
    def wandb_group(self) -> str:
        """Generate WandB group name based on parent model"""
        return f"model-{self.wandb_decomp_model}"

    @property
    def _iters_str(self) -> str:
        """Shortened string representation of iterations for run ID"""
        if self.merge_config.iters is None:
            return "_auto"
        return shorten_numerical_to_str(self.merge_config.iters)

    @property
    def config_identifier(self) -> str:
        """Unique identifier for this specific config on this specific model.

        Format: model_abc123-a0.1-i1k-b64-n10-h_12ab
        Allows filtering in WandB for all runs with this exact config and model.
        """
        return f"task_{self.task_name}-w_{self.wandb_decomp_model}-a{self.merge_config.alpha:g}-i{self._iters_str}-b{self.batch_size}-n{self.n_batches}-h_{self.stable_hash}"

    @property
    def stable_hash(self) -> str:
        """Generate a stable hash including all config parameters."""
        return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6]

    @classmethod
    def read(cls, path: Path) -> "ClusteringRunConfig":
        """Load config from JSON, YAML, or TOML file.

        Handles legacy spd_exp: model_path format and enforces consistency.
        For TOML files, the sentinel value "__NULL__" is converted to None.
        """
        # read the file contents, load them according to extension
        data: dict[str, Any]
        content: str
        if path.suffix == ".json":
            content = path.read_text()
            data = json.loads(content)
        elif path.suffix in [".yaml", ".yml"]:
            content = path.read_text()
            data = yaml.safe_load(content)
        elif path.suffix == ".toml":
            data = toml_read_file_with_none(path)
        else:
            raise ValueError(
                f"Unsupported file extension '{path.suffix}' on file '{path}' -- must be .json, .yaml, .yml, or .toml"
            )

        # if we provide an experiment_key, then:
        # 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name
        # 2. check it's consistent with model_path and task_name from the file if those are provided
        experiment_key: str | None = data.get("experiment_key")
        model_path: str | None = data.get("model_path")
        task_name: str | None = data.get("task_name")
        if experiment_key is not None:
            exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key]

            # Enforce consistency if explicit fields present
            if model_path is not None:
                assert model_path == exp_config.canonical_run, (
                    f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})"
                )
            if task_name is not None:
                assert task_name == exp_config.task_name, (
                    f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})"
                )

            # overwrite in data dict
            data["model_path"] = exp_config.canonical_run
            data["task_name"] = exp_config.task_name

        return cls.model_validate(data)

    def save(self, path: Path) -> None:
        """Save config to file (format inferred from extension).

        NOTE(review): unlike `read`, `.toml` is not supported here -- confirm
        whether that asymmetry is deliberate.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        if path.suffix == ".json":
            path.write_text(self.model_dump_json(indent=2))
        elif path.suffix in [".yaml", ".yml"]:
            path.write_text(
                yaml.dump(
                    self.model_dump(mode="json"),
                    default_flow_style=False,
                    sort_keys=False,
                )
            )
        else:
            raise ValueError(f"Unsupported file extension: {path.suffix}")

    def model_dump_with_properties(self) -> dict[str, Any]:
        """Serialize config including computed properties for WandB logging."""
        base_dump: dict[str, Any] = self.model_dump()

        # Add computed properties
        base_dump.update(
            {
                "wandb_decomp_model": self.wandb_decomp_model,
                "wandb_group": self.wandb_group,
                "config_identifier": self.config_identifier,
                "stable_hash": self.stable_hash,
            }
        )

        return base_dump


# --- spd/clustering/pipeline/__init__.py (new file, empty) ---

# --- spd/clustering/pipeline/clustering_pipeline.py (new file) ---

"""Orchestration layer - clustering pipeline coordination"""

import os
from collections.abc import Iterator
from pathlib import Path
from typing import Any

from spd.clustering.merge_run_config import ClusteringRunConfig
from spd.log import logger

# module-level side effect: silence wandb before any pipeline import pulls it in
os.environ["WANDB_QUIET"] = "True"


def main(config: ClusteringRunConfig) -> None:
    """Run the complete clustering pipeline.

    Args:
        config: ClusteringRunConfig containing all pipeline parameters
    """
    logger.section("setup")

    # imports deferred so the WANDB_QUIET env var above takes effect first,
    # and so that `--help`/config errors don't pay the heavy import cost
    from spd.clustering.consts import BatchTensor, DistancesArray, DistancesMethod, MergesArray
    from spd.clustering.math.merge_distances import compute_distances
    from spd.clustering.pipeline.dist_utils import distribute_clustering
    from spd.clustering.pipeline.s1_split_dataset import split_dataset
    from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save
    from spd.clustering.pipeline.s4_compute_distances import create_clustering_report
    from spd.clustering.pipeline.storage import ClusteringStorage

    logger.info("Imports complete")

    # Initialize storage
    storage: ClusteringStorage = ClusteringStorage(
        base_path=config.base_path, run_identifier=config.config_identifier
    )
    logger.info(f"Initialized storage at: {storage.run_path}")

    # Save run configuration
    storage.save_run_config(config)
    logger.info(f"Run record saved to: {storage.run_config_file}")

    # Save config to a path that can be passed to subprocess
    config_path: Path = storage.run_path / "config.json"
    config_path.write_text(config.model_dump_json(indent=2))
    logger.info(f"Config saved to: {config_path}")

    # Split dataset into batches
    logger.info(f"Splitting dataset into {config.n_batches} batches...")
    batches: Iterator[BatchTensor]
    dataset_config: dict[str, Any]
    batches, dataset_config = split_dataset(config=config)
    storage.save_batches(batches=batches, config=dataset_config)
    batch_paths: list[Path] = storage.get_batch_paths()
    n_batch_paths: int = len(batch_paths)
    logger.info(f"Dataset split complete. Saved {n_batch_paths} batches to: {storage._batches_dir}")

    # Process batches in parallel via subprocess shell-out
    logger.section("computing clusterings")
    logger.info(
        f"Processing {n_batch_paths} batches with {config.workers_per_device} workers per device on {config.devices}..."
    )
    # ANSI green tag so distributed-worker log lines are visually distinct
    distribute_prefix: str = "\033[92m[spd-cluster]\033[0m"

    from spd.clustering.pipeline.dist_utils import ClusteringBatchResult

    results: list[ClusteringBatchResult] = distribute_clustering(
        config_path=config_path,
        data_files=batch_paths,
        devices=config.devices,
        base_path=config.base_path,
        run_identifier=config.config_identifier,
        workers_per_device=config.workers_per_device,
        log_fn=lambda msg: logger.info(f"{distribute_prefix} {msg}"),
        log_fn_error=lambda msg: logger.error(f"{distribute_prefix} {msg}"),
    )
    logger.info(f"Batch processing complete. Processed {len(results)} batches")

    logger.section("computing distances")

    # Normalize and save ensemble
    logger.info("Normalizing merge histories across ensemble...")
    normalized_merge_array: MergesArray = normalize_and_save(storage=storage)
    logger.info(
        f"Normalized merge array saved: shape={normalized_merge_array.shape}, dtype={normalized_merge_array.dtype}"
    )

    # Compute distances
    distances_method: DistancesMethod = config.distances_method
    logger.info(f"Computing distances using method: {distances_method}")
    distances: DistancesArray = compute_distances(
        normalized_merge_array=normalized_merge_array,
        method=distances_method,
    )
    storage.save_distances(distances=distances, method=distances_method)
    logger.info(f"Distances computed and saved: shape={distances.shape}")

    # Create clustering report
    wandb_urls: list[str] = [r["wandb_url"] for r in results if r["wandb_url"] is not None]
    logger.info(f"Creating clustering report with {len(wandb_urls)} WandB URLs")
    create_clustering_report(
        distances=distances,
        method=distances_method,
        wandb_urls=wandb_urls,
        config_identifier=config.config_identifier,
    )
    logger.info("Clustering report created successfully")
"""Distribution utilities for parallel clustering via subprocess shell-out."""

import json
import os
import selectors
import subprocess
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import IO, TypedDict

from spd.log import logger
from spd.settings import REPO_ROOT


class ClusteringBatchResult(TypedDict):
    """Result from clustering a single batch."""

    hist_save_path: str
    wandb_url: str | None
    batch_name: str
    config_identifier: str


# Module-global cache for JSON writer in child processes
_JSON_WRITER: IO[str] | None = None


@dataclass
class ActiveProcess:
    """Tracks an active subprocess and its associated metadata."""

    proc: subprocess.Popen[bytes]
    json_fd: IO[bytes]
    dataset_path: Path
    device: str


def launch_child_with_json_fd(cmd: list[str]) -> tuple[subprocess.Popen[bytes], IO[bytes]]:
    """Launch child process with JSON fd via environment variable.

    This allows the child to write structured JSON output to a dedicated file descriptor
    while still allowing stdout/stderr to stream normally to the console.

    Args:
        cmd: Command and arguments to execute

    Returns:
        Tuple of (subprocess handle, read file descriptor for JSON results)
    """
    # get the pipes
    json_read_fd: int
    json_write_fd: int
    json_read_fd, json_write_fd = os.pipe()
    os.set_inheritable(json_write_fd, True)
    os.set_inheritable(json_read_fd, False)

    # Pass the fd number via environment variable
    env: dict[str, str] = dict(os.environ)
    env["JSON_FD"] = str(json_write_fd)

    # launch the child process. If spawning fails (e.g. command not found),
    # close both pipe ends so we do not leak file descriptors.
    try:
        proc: subprocess.Popen[bytes] = subprocess.Popen(
            cmd,
            env=env,
            stdout=None,  # Let stdout stream to console
            stderr=None,  # Let stderr stream to console
            pass_fds=(json_write_fd,),
            close_fds=True,
        )
    except BaseException:
        os.close(json_write_fd)
        os.close(json_read_fd)
        raise

    # In parent process: close the write fd (child has it) and return read fd
    os.close(json_write_fd)
    json_r: IO[bytes] = os.fdopen(json_read_fd, "rb", buffering=0)
    return proc, json_r


def _open_json_fd() -> IO[str]:
    """Open file descriptor for JSON output from environment variable.

    Called by child processes to get the fd for emitting structured results.
    Caches the writer globally to avoid re-wrapping the same FD.

    Returns:
        IO[str]: Text-mode writer (utf-8), line-buffered
    """
    global _JSON_WRITER
    if _JSON_WRITER is None:
        fd_num: int = int(os.environ["JSON_FD"])
        # Use utf-8 explicitly; line-buffered
        _JSON_WRITER = os.fdopen(fd_num, "w", buffering=1, encoding="utf-8")  # pyright: ignore[reportConstantRedefinition]
    return _JSON_WRITER


def emit_result(obj: ClusteringBatchResult) -> None:
    """Emit result JSON via environment fd.

    Called by child processes to return structured results to the parent.

    Args:
        obj: Result dictionary to serialize and emit
    """
    out: IO[str] = _open_json_fd()
    print(json.dumps(obj, separators=(",", ":")), file=out, flush=True)


def _read_json_result(json_r: IO[bytes], dataset_path: Path) -> ClusteringBatchResult:
    """Read JSON result from file descriptor.

    Args:
        json_r: Read file descriptor for JSON data
        dataset_path: Path to dataset being processed (for error messages)

    Returns:
        Parsed JSON result dictionary

    Raises:
        RuntimeError: If no JSON result was received
        ValueError: If JSON parsing failed
    """
    json_line: bytes = json_r.readline()
    if not json_line:
        raise RuntimeError(f"No JSON result received from {dataset_path}")

    json_str: str = json_line.decode("utf-8", errors="strict").strip()
    try:
        result: ClusteringBatchResult = json.loads(json_str)
        return result
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Failed to parse JSON result from {dataset_path}: {e}\nJSON string: {json_str}"
        ) from e


def _collect_one_ready(
    active: list[ActiveProcess],
    log_fn: Callable[[str], None],
) -> tuple[ClusteringBatchResult, ActiveProcess]:
    """Block until any active process has JSON ready, then collect it.

    Uses selectors to wait on multiple FDs simultaneously, avoiding head-of-line blocking.

    Args:
        active: Currently active processes
        log_fn: Logger for info messages

    Returns:
        Tuple of (parsed JSON result, the corresponding ActiveProcess)

    Raises:
        RuntimeError: If subprocess exits with non-zero code
    """
    sel: selectors.BaseSelector = selectors.DefaultSelector()
    try:
        for ap in active:
            sel.register(ap.json_fd, selectors.EVENT_READ, ap)
        key: selectors.SelectorKey
        key, _mask = sel.select()[0]  # select() -> list[(SelectorKey, int)]
        ap: ActiveProcess = key.data  # type: ignore[assignment]
    finally:
        sel.close()

    result: ClusteringBatchResult = _read_json_result(ap.json_fd, ap.dataset_path)
    rc: int | None = ap.proc.wait()
    try:  # noqa: SIM105
        ap.json_fd.close()
    except Exception:
        pass

    if rc != 0:
        raise RuntimeError(
            f"Subprocess {ap.proc.pid} on {ap.device} exited with code {rc} for dataset {ap.dataset_path}"
        )

    log_fn(f"Process {ap.proc.pid} finished, freeing slot on {ap.device}")
    return result, ap


def distribute_clustering(
    config_path: Path,
    data_files: list[Path],
    devices: list[str],
    base_path: Path,
    run_identifier: str,
    workers_per_device: int = 1,
    log_fn: Callable[[str], None] | None = None,
    log_fn_error: Callable[[str], None] | None = None,
) -> list[ClusteringBatchResult]:
    """Distribute clustering tasks across multiple devices via subprocess.

    Launches clustering processes using shell-out approach with JSON fd for structured
    results. Manages concurrency based on workers_per_device and available devices.

    The concurrency model:
    - Total concurrency = workers_per_device x len(devices)
    - Uses round-robin device assignment starting point
    - If target device is full, uses any available device
    - If all devices are full, waits for ANY process to finish (whichever is ready first)

    Args:
        config_path: Path to clustering configuration file
        data_files: List of batch data files to process
        devices: List of device strings (e.g., ['cuda:0', 'cuda:1'])
        base_path: Base directory for clustering outputs
        run_identifier: Unique identifier for this clustering run
        workers_per_device: Maximum concurrent workers per device
        log_fn: Optional logging function for info messages
        log_fn_error: Optional logging function for error messages

    Returns:
        List of result dictionaries from each batch processing

    Raises:
        ValueError: If devices list is empty
        RuntimeError: If subprocess fails or doesn't return results
    """
    # setup logger (direct method references instead of redundant lambda wrappers)
    if log_fn is None:
        log_fn = logger.info
    if log_fn_error is None:
        log_fn_error = logger.error

    # validate parameters
    if workers_per_device < 1:
        raise ValueError("workers_per_device must be >= 1")

    n_devices: int = len(devices)
    if n_devices == 0:
        raise ValueError("devices must be non-empty")

    # Track active processes per device to enforce workers_per_device limit
    device_active_counts: dict[str, int] = {device: 0 for device in devices}
    active: list[ActiveProcess] = []
    results: list[ClusteringBatchResult] = []

    n_files: int = len(data_files)
    try:
        for idx, dataset in enumerate(data_files):
            # Find a device with capacity, starting from round-robin position
            device_idx = idx % n_devices

            # Check if we need to wait for a device to free up
            while all(count >= workers_per_device for count in device_active_counts.values()):
                # All devices are at capacity - wait for ANY process to finish
                log_fn(
                    f"All devices at capacity ({workers_per_device} workers each). Waiting for any process to finish..."
                )

                # Wait for whichever process is ready first
                result_i, finished_ap = _collect_one_ready(active, log_fn)
                results.append(result_i)
                device_active_counts[finished_ap.device] -= 1
                active.remove(finished_ap)

            # Now find a device with capacity, starting from our round-robin position
            for i in range(n_devices):
                check_idx = (device_idx + i) % n_devices
                if device_active_counts[devices[check_idx]] < workers_per_device:
                    device_idx = check_idx
                    break

            device: str = devices[device_idx]

            cmd: list[str] = [
                "uv",
                "run",
                "python",
                str(REPO_ROOT / "spd/clustering/pipeline/s2_clustering.py"),
                "--config",
                str(config_path),
                "--dataset-path",
                str(dataset),
                "--base-path",
                str(base_path),
                "--run-identifier",
                run_identifier,
                "--device",
                device,
            ]
            log_fn("[cmd] " + " ".join(cmd))

            proc, json_r = launch_child_with_json_fd(cmd)
            active_proc = ActiveProcess(
                proc=proc, json_fd=json_r, dataset_path=dataset, device=device
            )
            active.append(active_proc)
            device_active_counts[device] += 1
            log_fn(
                f"Started clustering {idx + 1}/{n_files} on {device} (pid={proc.pid}, active on device: {device_active_counts[device]}/{workers_per_device})\n\t{dataset}"
            )

        # Wait for remaining processes
        while active:
            result_i, finished_ap = _collect_one_ready(active, log_fn)
            results.append(result_i)
            device_active_counts[finished_ap.device] -= 1
            active.remove(finished_ap)
            log_fn(f"Process {finished_ap.proc.pid} finished on {finished_ap.device}")

    except BaseException as e:
        # this means we probably got a KeyboardInterrupt, so kill the child processes
        log_fn_error(f"An error occurred: {e}")
        for active_proc in active:
            try:  # noqa: SIM105
                active_proc.proc.kill()
            except Exception:
                pass
            try:  # noqa: SIM105
                active_proc.json_fd.close()
            except Exception:
                pass
            log_fn_error(f"Killed process {active_proc.proc.pid} due to error")
        raise

    return results
a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py new file mode 100644 index 000000000..d5427e600 --- /dev/null +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -0,0 +1,151 @@ +""" +Loads and splits dataset into batches, returning them as an iterator. +""" + +from collections.abc import Generator, Iterator +from typing import Any + +import torch +from muutils.spinner import SpinnerContext +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.clustering.consts import BatchTensor +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.experiments.resid_mlp.configs import ResidMLPModelConfig, ResidMLPTaskConfig +from spd.experiments.resid_mlp.models import ResidMLP +from spd.models.component_model import ComponentModel, SPDRunInfo + + +def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], dict[str, Any]]: + """Split a dataset into n_batches of batch_size, returning iterator and config""" + ds: Generator[BatchTensor, None, None] + ds_config_dict: dict[str, Any] + match config.task_name: + case "lm": + ds, ds_config_dict = _get_dataloader_lm( + model_path=config.model_path, + batch_size=config.batch_size, + ) + case "resid_mlp": + ds, ds_config_dict = _get_dataloader_resid_mlp( + model_path=config.model_path, + batch_size=config.batch_size, + ) + case name: + raise ValueError( + f"Unsupported task name '{name}'. Supported tasks are 'lm' and 'resid_mlp'. 
{config.model_path=}, {name=}" + ) + + # Limit iterator to n_batches + def limited_iterator() -> Iterator[BatchTensor]: + batch_idx: int + batch: BatchTensor + for batch_idx, batch in tqdm(enumerate(ds), total=config.n_batches, unit="batch"): + if batch_idx >= config.n_batches: + break + yield batch + + return limited_iterator(), ds_config_dict + + +def _get_dataloader_lm( + model_path: str, + batch_size: int, +) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: + """split up a SS dataset into n_batches of batch_size, returned the saved paths + + 1. load the config for a SimpleStories SPD Run given by model_path + 2. create a DataLoader for the dataset + 3. iterate over the DataLoader and save each batch to a file + + + """ + with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): + spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) + cfg: Config = spd_run.config + + try: + pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] + assert pretrained_model_name is not None + except Exception as e: + raise AttributeError( + "Could not find 'pretrained_model_name' in the SPD Run config, but called `_get_dataloader_lm`" + ) from e + + assert isinstance(cfg.task_config, LMTaskConfig), ( + f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" + ) + + dataset_config: DatasetConfig = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + is_tokenized=False, + streaming=False, + seed=0, + column_name=cfg.task_config.column_name, + ) + + with SpinnerContext(message="getting dataloader..."): + dataloader: DataLoader[dict[str, torch.Tensor]] + dataloader, _tokenizer = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=cfg.seed, + 
ddp_rank=0, + ddp_world_size=1, + ) + + return (batch["input_ids"] for batch in dataloader), dataset_config.model_dump(mode="json") + + +def _get_dataloader_resid_mlp( + model_path: str, + batch_size: int, +) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]: + """Split a ResidMLP dataset into n_batches of batch_size and save the batches.""" + from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): + spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) + # SPD_RUN = SPDRunInfo.from_path(EXPERIMENT_REGISTRY["resid_mlp3"].canonical_run) + component_model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) + cfg: Config = spd_run.config + + with SpinnerContext(message="Creating ResidMLPDataset..."): + assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( + f"Expected task_config to be of type ResidMLPTaskConfig since using `_get_dataloader_resid_mlp`, but got {type(cfg.task_config) = }" + ) + assert isinstance(component_model.target_model, ResidMLP), ( + f"Expected patched_model to be of type ResidMLP since using `_get_dataloader_resid_mlp`, but got {type(component_model.patched_model) = }" + ) + + assert isinstance(component_model.target_model.config, ResidMLPModelConfig), ( + f"Expected patched_model.config to be of type ResidMLPModelConfig since using `_get_dataloader_resid_mlp`, but got {type(component_model.target_model.config) = }" + ) + resid_mlp_dataset_kwargs: dict[str, Any] = dict( + n_features=component_model.target_model.config.n_features, + feature_probability=cfg.task_config.feature_probability, + device="cpu", + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=None, + label_coeffs=None, + data_generation_type=cfg.task_config.data_generation_type, + ) + dataset: ResidMLPDataset = ResidMLPDataset(**resid_mlp_dataset_kwargs) + + 
dataloader: DatasetGeneratedDataLoader[tuple[Tensor, Tensor]] = DatasetGeneratedDataLoader( + dataset, batch_size=batch_size, shuffle=False + ) + + return (batch[0] for batch in dataloader), resid_mlp_dataset_kwargs diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py new file mode 100644 index 000000000..d04b16bc5 --- /dev/null +++ b/spd/clustering/pipeline/s2_clustering.py @@ -0,0 +1,409 @@ +"""Stage 2: Run clustering on individual batches (CLI script interface).""" + +import argparse +import os +import tempfile +from collections.abc import Callable +from dataclasses import dataclass +from functools import partial +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +import wandb +from jaxtyping import Float, Int +from matplotlib.figure import Figure +from torch import Tensor +from wandb.sdk.wandb_run import Run + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.consts import ( + ActivationsTensor, + BatchTensor, + ClusterCoactivationShaped, + ComponentLabels, +) +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.math.semilog import semilog +from spd.clustering.merge import _BATCH_PREFIX_FMT, merge_iteration +from spd.clustering.merge_history import MergeHistory +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.pipeline.dist_utils import emit_result +from spd.clustering.pipeline.storage import ClusteringStorage +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration +from spd.clustering.wandb_tensor_info import wandb_log_tensor +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo + +os.environ["WANDB_QUIET"] = "True" + +LogCallback = Callable[ + [ + ClusterCoactivationShaped, + ComponentLabels, + 
GroupMerge, + ClusterCoactivationShaped, + MergeHistory, + int, + int, + float, + float, + float, + Float[Tensor, " k_groups"], + ], + None, +] + + +@dataclass +class ClusteringResult: + history_save_path: Path + wandb_url: str | None + + +def run_clustering( + config: ClusteringRunConfig, + data_path: Path, + base_path: Path, + run_identifier: str, + device: str, +) -> ClusteringResult: + """Run clustering on a single batch. + + Args: + config: Clustering configuration + data_path: Path to batch data file + base_path: Base directory for storage + run_identifier: Unique identifier for this clustering run + device: Device to run on (e.g., 'cuda:0', 'cpu') + + Returns: + ClusteringResult with save path and optional WandB URL + """ + batch_id: str = data_path.stem + prefix: str = _BATCH_PREFIX_FMT.format(batch_id=batch_id) + + def logger_call(msg: str) -> None: + logger.info(f"{prefix} {msg}") + + logger_call("starting batch") + storage: ClusteringStorage = ClusteringStorage( + base_path=base_path, run_identifier=run_identifier + ) + + run: Run | None = ( + _setup_wandb(batch_id=batch_id, config=config) if config.wandb_enabled else None + ) + logger_call("wandb setup complete") + + this_merge_plots_dir: Path = storage.history_path(batch_id).parent / "plots" + + spd_run: SPDRunInfo = SPDRunInfo.from_path(config.model_path) + logger_call("loaded spd run info") + + model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path).to(device) + logger_call("loaded model") + + batch: BatchTensor = storage.load_batch(data_path).to(device) + logger_call(f"loaded batch {batch_id} with shape {batch.shape}") + + activations_dict: ( + dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] + ) = component_activations( + model=model, + batch=batch, + device=device, + sigmoid_type=spd_run.config.sigmoid_type, + ) + logger_call("computed activations") + + processed_activations: ProcessedActivations = process_activations( + 
activations=activations_dict, + filter_dead_threshold=config.merge_config.filter_dead_threshold, + seq_mode="concat" if config.task_name == "lm" else None, + filter_modules=config.merge_config.filter_modules, + ) + logger_call("processed activations") + + wandb_url: str | None + if run is not None: + wandb_log_tensor( + run=run, + data=processed_activations.activations, + name="processed_activations", + step=0, + single=True, + ) + wandb_url = run.url + else: + wandb_url = None + + # Use original activations for raw plots, but filtered data for concat/coact/histograms + logger_call("plotting") + plot_activations( + processed_activations=processed_activations, + save_dir=this_merge_plots_dir, + n_samples_max=256, # TODO: make this configurable? + wandb_run=run, + ) + logger_call(f"plots saved to {this_merge_plots_dir}") + + logger_call("cleaning up memory") + activations: ActivationsTensor = processed_activations.activations + component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) + del processed_activations # we copied what we needed + del activations_dict # processed already + del model # already did the forward pass + del batch # already did the forward pass + + log_callback: LogCallback | None = ( + partial(_log_callback, run=run, batch_id=batch_id, config=config) + if run is not None + else None + ) + + logger_call("starting merging") + history: MergeHistory = merge_iteration( + merge_config=config.merge_config, + activations=activations, + component_labels=component_labels, + log_callback=log_callback, + batch_id=batch_id, + ) + logger_call("merging complete") + + history_save_path: Path = storage.history_path(batch_id) + + history.save(history_save_path, wandb_url=wandb_url) + + if run is not None: + _log_merge_history_plots_to_wandb(run, history) + _save_merge_history_to_wandb( + run, history_save_path, batch_id, config.config_identifier, history + ) + + run.finish() + + logger_call("batch complete") + + return 
ClusteringResult(history_save_path=history_save_path, wandb_url=wandb_url) + + +def _setup_wandb( + batch_id: str, + config: ClusteringRunConfig, +) -> Run: + run: Run = wandb.init( + project=config.wandb_project, + name=f"{config.config_identifier}-{batch_id}", + group=config.wandb_group, + config=config.model_dump_with_properties(), + tags=[ + "cluster-run", + f"model:{config.wandb_decomp_model}", + f"task:{config.task_name}", + f"batch:{batch_id}", + f"config:{config.config_identifier}", + ], + ) + logger.info( + f"{_BATCH_PREFIX_FMT.format(batch_id=batch_id)} Initialized WandB run: {run.name} in group {config.wandb_group}" + ) + return run + + +def _log_merge_history_plots_to_wandb(run: Run, history: MergeHistory) -> None: + fig_cs: Figure = plot_merge_history_cluster_sizes(history=history) + run.log( + {"plots/merge_history_cluster_sizes": wandb.Image(fig_cs)}, + step=history.n_iters_current, + ) + plt.close(fig_cs) + + +def _save_merge_history_to_wandb( + run: Run, + history_path: Path, + batch_id: str, + config_identifier: str, + history: MergeHistory, +) -> None: + artifact: wandb.Artifact = wandb.Artifact( + name=f"merge_history_{batch_id}", + type="merge_history", + description=f"Merge history for batch {batch_id}", + metadata={ + "batch_name": batch_id, + "config_identifier": config_identifier, + "n_iters_current": history.n_iters_current, + "filename": history_path, + }, + ) + artifact.add_file(str(history_path)) + run.log_artifact(artifact) + + +def _log_callback( + run: Run, + batch_id: str, + current_coact: ClusterCoactivationShaped, + component_labels: ComponentLabels, + current_merge: GroupMerge, + config: ClusteringRunConfig, + costs: ClusterCoactivationShaped, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: Float[Tensor, " k_groups"], +) -> None: + if iter_idx % config.intervals["stat"] == 0: + run.log( + { + "k_groups": int(k_groups), + 
"merge_pair_cost": merge_pair_cost, + "merge_pair_cost_semilog[1e-3]": semilog(merge_pair_cost, epsilon=1e-3), + "mdl_loss": float(mdl_loss), + "mdl_loss_norm": float(mdl_loss_norm), + }, + step=iter_idx, + ) + + if iter_idx % config.intervals["tensor"] == 0: + group_sizes: Int[Tensor, " k_groups"] = current_merge.components_per_group + + tensor_data: dict[str, Tensor] = { + "coactivation": current_coact, + "costs": costs, + "group_sizes": group_sizes, + "group_activations": diag_acts, + "group_activations_over_sizes": ( + diag_acts / group_sizes.to(device=diag_acts.device).float() + ), + } + + fraction_singleton_groups: float = (group_sizes == 1).float().mean().item() + if fraction_singleton_groups > 0: + tensor_data["group_sizes.log1p"] = torch.log1p(group_sizes.float()) + + fraction_zero_coacts: float = (current_coact == 0).float().mean().item() + if fraction_zero_coacts > 0: + tensor_data["coactivation.log1p"] = torch.log1p(current_coact.float()) + + wandb_log_tensor(run, tensor_data, name="iters", step=iter_idx) + + run.log( + { + "fraction_singleton_groups": float(fraction_singleton_groups), + "fraction_zero_coacts": float(fraction_zero_coacts), + }, + step=iter_idx, + ) + + if iter_idx > 0 and iter_idx % config.intervals["artifact"] == 0: + with tempfile.NamedTemporaryFile() as tmp_file: + file: Path = Path(tmp_file.name) + merge_history.save(file) + artifact: wandb.Artifact = wandb.Artifact( + name=f"merge_hist_iter.{batch_id}.iter_{iter_idx}", + type="merge_hist_iter", + description=f"Group indices for batch {batch_id} at iteration {iter_idx}", + metadata={ + "batch_name": batch_id, + "iteration": iter_idx, + "config": merge_history.merge_config.model_dump(mode="json"), + # TODO: had to remove identifiers on config due to MergeConfig <--> ClusteringRunConfig (formerly MergeRunConfig) split + # "config_identifier": merge_history.merge_config.config_identifier, + }, + ) + artifact.add_file(str(file)) + run.log_artifact(artifact) + + if iter_idx % 
config.intervals["plot"] == 0: + fig: Figure = plot_merge_iteration( + current_merge=current_merge, + current_coact=current_coact, + costs=costs, + iteration=iter_idx, + component_labels=component_labels, + show=False, + ) + run.log({"plots/merges": wandb.Image(fig)}, step=iter_idx) + plt.close(fig) + + +def cli() -> None: + """Command-line interface for running clustering on a single batch.""" + parser: argparse.ArgumentParser = argparse.ArgumentParser( + description="Run clustering on a single batch of data" + ) + parser.add_argument( + "--config", + "-c", + type=Path, + required=True, + help="Path to the clustering run config JSON/YAML file", + ) + parser.add_argument( + "--dataset-path", + "-d", + type=Path, + required=True, + help="Path to the dataset batch file (e.g., batch_00.npz)", + ) + parser.add_argument( + "--base-path", + "-b", + type=Path, + required=True, + help="Base directory for clustering outputs", + ) + parser.add_argument( + "--run-identifier", + "-r", + type=str, + required=True, + help="Unique identifier for this clustering run", + ) + parser.add_argument( + "--device", + "-D", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu", + help="Device to run on (e.g., 'cuda:0', 'cpu')", + ) + + args: argparse.Namespace = parser.parse_args() + + # Load config + config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) + + # Run clustering + result: ClusteringResult = run_clustering( + config=config, + data_path=args.dataset_path, + base_path=args.base_path, + run_identifier=args.run_identifier, + device=args.device, + ) + + # Emit structured result for parent process + emit_result( + { + "hist_save_path": str(result.history_save_path), + "wandb_url": result.wandb_url, + "batch_name": args.dataset_path.stem, + "config_identifier": config.config_identifier, + } + ) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/pipeline/s3_normalize_histories.py b/spd/clustering/pipeline/s3_normalize_histories.py new file 
from pathlib import Path
from typing import Any

from spd.clustering.consts import MergesArray
from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble
from spd.clustering.pipeline.storage import ClusteringStorage, NormalizedEnsemble
from spd.log import logger


def normalize_and_save(storage: ClusteringStorage) -> MergesArray:
    """Load merge histories from storage, normalize, and save ensemble.

    Args:
        storage: Storage layer to read per-batch histories from and write the
            normalized ensemble back to.

    Returns:
        The normalized merge array.
    """
    # load
    histories: list[MergeHistory] = storage.load_histories()
    ensemble: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories)

    # normalize
    normalized_merge_array: MergesArray
    normalized_merge_meta: dict[str, Any]
    normalized_merge_array, normalized_merge_meta = ensemble.normalized()

    # save
    ensemble_data: NormalizedEnsemble = NormalizedEnsemble(
        merge_array=normalized_merge_array,
        metadata=normalized_merge_meta,
    )
    metadata_path: Path
    array_path: Path
    metadata_path, array_path = storage.save_ensemble(ensemble_data)
    logger.info(f"metadata saved to {metadata_path}")
    logger.info(f"merge array saved to {array_path}")

    return normalized_merge_array


# --- spd/clustering/pipeline/s4_compute_distances.py ---

import wandb
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from spd.clustering.consts import (
    DistancesArray,
    DistancesMethod,
)
from spd.clustering.plotting.merge import plot_dists_distribution


def create_clustering_report(
    distances: DistancesArray,
    method: DistancesMethod,
    wandb_urls: list[str],
    config_identifier: str,
) -> None:
    """Create a WandB report with clustering results and distances plot.

    Args:
        distances: Pairwise distance array to plot and attach to the report
        method: Distance method used (tagged onto the report)
        wandb_urls: WandB URLs of the per-batch clustering runs; the first is
            used to infer the entity/project for the summary run
        config_identifier: Identifier for this clustering configuration
    """
    # Without any batch-run URLs we can't determine entity/project — warn and
    # bail instead of raising IndexError on `wandb_urls[0]`.
    if not wandb_urls:
        logger.warning("No WandB URLs provided; skipping clustering summary report")
        return

    # Extract entity/project from first URL for the report
    first_url: str = wandb_urls[0]
    entity: str
    project: str

    if first_url.startswith("wandb:"):
        run_path_parts: list[str] = first_url.replace("wandb:", "").split("/")
        entity, project = run_path_parts[0], run_path_parts[1]
    else:
        # Parse full URL, e.g. https://wandb.ai/<entity>/<project>/runs/<id>
        parts: list[str] = first_url.split("/")
        if "runs" in parts:
            run_idx: int = parts.index("runs") + 1
            entity, project = parts[run_idx - 3], parts[run_idx - 2]
        else:
            logger.warning(f"Could not parse WandB URL: {first_url}")
            return

    # Initialize WandB run for the summary report
    with wandb.init(
        project=project,
        entity=entity,
        name=f"clustering-summary-{config_identifier}",
        tags=["clustering-summary", f"config:{config_identifier}", f"method:{method}"],
        job_type="clustering-analysis",
        config=dict(config_identifier=config_identifier, method=method),
    ) as run:
        # Create and log the distances distribution plot
        ax: Axes = plot_dists_distribution(
            distances=distances, mode="points", label=f"{method} distances"
        )
        plt.title(f"Distance Distribution ({method})")

        # Only add legend if there are labeled artists
        handles, _labels = ax.get_legend_handles_labels()
        if handles:
            plt.legend()

        # Get the figure from the axes
        fig: plt.Figure | None = ax.get_figure()
        assert fig is not None

        # Log the plot
        run.log(
            {
                f"distances/{method}": wandb.Image(fig),
                "clustering/config_identifier": config_identifier,
            }
        )

        plt.close(fig)

        # Log metadata about the batch runs
        run.log(
            {
                "batch_runs/urls": wandb_urls,
            }
        )

        # Create a summary table of run information
        run_ids: list[str] = []
        for url in wandb_urls:
            if "runs/" in url:
                run_id: str = url.split("runs/")[-1]
                run_ids.append(run_id)

        if run_ids:
            run.log({"batch_runs/run_ids": run_ids})

        logger.info(
            f"Created wandb clustering summary report with {len(wandb_urls)} batch runs from config {config_identifier}:\n{run.url}/overview"
        )
a/spd/clustering/pipeline/storage.py b/spd/clustering/pipeline/storage.py new file mode 100644 index 000000000..cb38befd8 --- /dev/null +++ b/spd/clustering/pipeline/storage.py @@ -0,0 +1,300 @@ +"""Storage layer for clustering pipeline - handles all persistence operations.""" + +import json +from collections.abc import Iterator +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import numpy as np +import torch +from torch import Tensor + +from spd.clustering.consts import BatchTensor, DistancesArray, DistancesMethod, MergesArray +from spd.clustering.merge_run_config import ClusteringRunConfig + +if TYPE_CHECKING: + from spd.clustering.merge_history import MergeHistory + + +@dataclass +class DatasetBatches: + """Container for dataset batches and their configuration.""" + + batches: list[Tensor] + config: dict[str, Any] + + +@dataclass +class NormalizedEnsemble: + """Container for normalized merge array and metadata.""" + + merge_array: MergesArray + metadata: dict[str, Any] + + +def _write_text_to_path_and_return(path: Path, data: str) -> Path: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(data) + return path + + +class ClusteringStorage: + """Handles all file I/O operations for the clustering pipeline. + + This class provides a clean separation between data transformations and persistence, + making the pipeline more testable and flexible. + + Filesystem structure: + / + └── / # Optional run-specific subdirectory + ├── run_config.json # Run configuration parameters + ├── dataset/ # Dataset and batch storage + │ ├── dataset_config.json # Dataset configuration metadata + │ └── batches/ # Individual data batches + │ ├── batch_00.npz # Batch 0 (input_ids array) + │ ├── batch_01.npz # Batch 1 + │ └── ... + ├── merge_histories/ # Merge history per batch + │ ├── data_/ # Per-batch history directory + │ │ └── merge_history.zip # Compressed merge history + │ └── ... 
+ ├── ensemble/ # Normalized ensemble results + │ ├── ensemble_meta.json # Ensemble metadata + │ └── ensemble_merge_array.npz # Normalized merge array + └── distances/ # Distance matrices + ├── distances..npz # Distance array for each method + └── ... + """ + + # Directory structure constants + _DATASET_DIR: str = "dataset" + _BATCHES_DIR: str = "batches" + _HISTORIES_DIR: str = "merge_histories" + _ENSEMBLE_DIR: str = "ensemble" + _DISTANCES_DIR: str = "distances" + _DASHBOARD_DIR: str = "dashboard" + + # File naming constants + _RUN_CONFIG_FILE: str = "run_config.json" + _DATASET_CONFIG_FILE: str = "dataset_config.json" + _ENSEMBLE_META_FILE: str = "ensemble_meta.json" + _ENSEMBLE_ARRAY_FILE: str = "ensemble_merge_array.npz" + _BATCH_FILE_FMT: str = "batch_{batch_idx:02d}.npz" + _HISTORY_FILE_FMT: str = "{batch_id}" + _MERGE_HISTORY_FILE: str = "merge_history.zip" + _DISTANCES_FILE_FMT: str = "distances.{method}.npz" + _MODEL_INFO_FILE: str = "model_info.json" + _MAX_ACTIVATIONS_FILE_FMT: str = "max_activations_i{iteration}_n{n_samples}.json" + + def __init__(self, base_path: Path, run_identifier: str | None = None): + """Initialize storage with base path and optional run identifier. 
+ + Args: + base_path: Root directory for all storage operations + run_identifier: Optional identifier to create a subdirectory for this run + """ + self._base_path: Path = base_path + if run_identifier: + self._run_path = base_path / run_identifier + else: + self._run_path = base_path + + # Ensure base directory exists + self._run_path.mkdir(parents=True, exist_ok=True) + + # directories + + # make base and run path properties so we don't accidentally modify them + @property + def base_path(self) -> Path: + return self._base_path + + @property + def run_path(self) -> Path: + return self._run_path + + @property + def _dataset_dir(self) -> Path: + return self.run_path / self._DATASET_DIR + + # directories themselves private, use the storage/read methods to interact with them + @property + def _batches_dir(self) -> Path: + return self._dataset_dir / self._BATCHES_DIR + + @property + def _histories_dir(self) -> Path: + return self.run_path / self._HISTORIES_DIR + + @property + def _ensemble_dir(self) -> Path: + return self.run_path / self._ENSEMBLE_DIR + + @property + def _distances_dir(self) -> Path: + return self.run_path / self._DISTANCES_DIR + + @property + def _dashboard_dir(self) -> Path: + return self.run_path / self._DASHBOARD_DIR + + # files + @property + def run_config_file(self) -> Path: + return self.run_path / self._RUN_CONFIG_FILE + + @property + def dataset_config_file(self) -> Path: + return self._dataset_dir / self._DATASET_CONFIG_FILE + + @property + def ensemble_meta_file(self) -> Path: + return self._ensemble_dir / self._ENSEMBLE_META_FILE + + @property + def ensemble_array_file(self) -> Path: + return self._ensemble_dir / self._ENSEMBLE_ARRAY_FILE + + @property + def model_info_file(self) -> Path: + return self.run_path / self._MODEL_INFO_FILE + + @property + def dashboard_model_info_file(self) -> Path: + return self._dashboard_dir / self._MODEL_INFO_FILE + + # dynamic + + def batch_path(self, batch_idx: int) -> Path: + return self._batches_dir / 
self._BATCH_FILE_FMT.format(batch_idx=batch_idx) + + def history_path(self, batch_id: str) -> Path: + return ( + self._histories_dir + / self._HISTORY_FILE_FMT.format(batch_id=batch_id) + / self._MERGE_HISTORY_FILE + ) + + def max_activations_path(self, iteration: int, n_samples: int) -> Path: + return self._dashboard_dir / self._MAX_ACTIVATIONS_FILE_FMT.format( + iteration=iteration, n_samples=n_samples + ) + + # Batch storage methods + + def save_dataset_config(self, config: dict[str, Any]) -> Path: + return _write_text_to_path_and_return( + self.dataset_config_file, json.dumps(config, indent=2) + ) + + def save_batch(self, batch: BatchTensor, batch_idx: int) -> Path: + batch_path: Path = self.batch_path(batch_idx) + batch_path.parent.mkdir(parents=True, exist_ok=True) + + np.savez_compressed(batch_path, input_ids=batch.cpu().numpy()) + return batch_path + + def save_batches(self, batches: Iterator[BatchTensor], config: dict[str, Any]) -> list[Path]: + paths: list[Path] = [] + + self.save_dataset_config(config) + + for idx, batch in enumerate(batches): + path: Path = self.save_batch(batch, idx) + paths.append(path) + + return paths + + def load_batch(self, batch_path: Path) -> BatchTensor: + data: dict[str, np.ndarray] = np.load(batch_path) + return torch.tensor(data["input_ids"]) + + def get_batch_paths(self) -> list[Path]: + return sorted(self._batches_dir.glob("batch_*.npz")) + + # History storage methods + + def save_history(self, history: "MergeHistory", batch_id: str) -> Path: + history_path: Path = self.history_path(batch_id) + history_path.parent.mkdir(parents=True, exist_ok=True) + history.save(history_path) + return history_path + + def load_history(self, batch_id: str) -> "MergeHistory": + # Import only at runtime to avoid circular dependencies + from spd.clustering.merge_history import MergeHistory + + return MergeHistory.read(self.history_path(batch_id)) + + def get_history_paths(self) -> list[Path]: + return 
sorted(self._histories_dir.glob(f"*/{self._MERGE_HISTORY_FILE}")) + + def load_histories(self) -> list["MergeHistory"]: + # Import only at runtime to avoid circular dependencies + from spd.clustering.merge_history import MergeHistory + + return [MergeHistory.read(path) for path in self.get_history_paths()] + + # Ensemble related storage methods + + def save_ensemble(self, ensemble: NormalizedEnsemble) -> tuple[Path, Path]: + """Save normalized ensemble data""" + self._ensemble_dir.mkdir(parents=True, exist_ok=True) + + # Save metadata + metadata_path: Path = self.ensemble_meta_file + metadata_path.write_text(json.dumps(ensemble.metadata, indent=2)) + + # Save merge array + array_path: Path = self.ensemble_array_file + np.savez_compressed(array_path, merges=ensemble.merge_array) + + return metadata_path, array_path + + def save_distances(self, distances: DistancesArray, method: DistancesMethod) -> Path: + self._distances_dir.mkdir(parents=True, exist_ok=True) + + distances_path: Path = self._distances_dir / self._DISTANCES_FILE_FMT.format(method=method) + np.savez_compressed(distances_path, distances=distances) + return distances_path + + def load_distances(self, method: DistancesMethod) -> DistancesArray: + distances_path: Path = self._distances_dir / self._DISTANCES_FILE_FMT.format(method=method) + data: dict[str, np.ndarray] = np.load(distances_path) + return data["distances"] + + def save_run_config(self, config: ClusteringRunConfig) -> Path: + return _write_text_to_path_and_return( + self.run_config_file, config.model_dump_json(indent=2) + ) + + def load_run_config(self) -> ClusteringRunConfig: + return ClusteringRunConfig.read(self.run_config_file) + + # Dashboard storage methods + + def save_max_activations( + self, data: dict[int, dict[str, list[dict[str, Any]]]], iteration: int, n_samples: int + ) -> Path: + """Save max activations data to dashboard directory.""" + max_act_path: Path = self.max_activations_path(iteration, n_samples) + return 
_write_text_to_path_and_return(max_act_path, json.dumps(data, indent=2)) + + def save_model_info(self, model_info: dict[str, Any]) -> Path: + """Save model info to run directory.""" + return _write_text_to_path_and_return( + self.model_info_file, json.dumps(model_info, indent=2) + ) + + def save_model_info_to_dashboard(self, model_info: dict[str, Any]) -> Path: + """Save or copy model info to dashboard directory.""" + return _write_text_to_path_and_return( + self.dashboard_model_info_file, json.dumps(model_info, indent=2) + ) + + def load_model_info(self) -> dict[str, Any] | None: + """Load model info from run directory if it exists.""" + if self.model_info_file.exists(): + return json.loads(self.model_info_file.read_text()) + return None diff --git a/spd/clustering/plotting/__init__.py b/spd/clustering/plotting/__init__.py new file mode 100644 index 000000000..b048d1d24 --- /dev/null +++ b/spd/clustering/plotting/__init__.py @@ -0,0 +1 @@ +"""Plotting utilities for clustering module.""" diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py new file mode 100644 index 000000000..2411eca38 --- /dev/null +++ b/spd/clustering/plotting/activations.py @@ -0,0 +1,379 @@ +"""Plotting functions for activation visualizations.""" + +from collections.abc import Sequence +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb +import wandb.sdk.wandb_run +from jaxtyping import Float, Int +from torch import Tensor + +from spd.clustering.activations import ProcessedActivations, compute_coactivatons +from spd.clustering.consts import ActivationsTensor, ClusterCoactivationShaped, ComponentLabels + + +def plot_activations( + processed_activations: ProcessedActivations, + save_dir: Path, + n_samples_max: int, + pdf_prefix: str = "activations", + figsize_raw: tuple[int, int] = (12, 4), + figsize_concat: tuple[int, int] = (12, 2), + figsize_coact: tuple[int, int] = (8, 
6), + hist_scales: tuple[str, str] = ("lin", "log"), + hist_bins: int = 100, + do_sorted_samples: bool = False, + wandb_run: wandb.sdk.wandb_run.Run | None = None, +) -> None: + """Plot activation visualizations including raw, concatenated, sorted, and coactivations. + + Args: + activations: Dictionary of raw activations by module + act_concat: Concatenated activations tensor + coact: Coactivation matrix + labels: Component labels + save_dir: The directory to save the plots to + pdf_prefix: Prefix for PDF filenames + figsize_raw: Figure size for raw activations + figsize_concat: Figure size for concatenated activations + figsize_coact: Figure size for coactivations + hist_scales: Tuple of (x_scale, y_scale) where each is "lin" or "log" + hist_bins: Number of bins for histograms + """ + save_dir.mkdir(parents=True, exist_ok=True) + + act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw + act_concat: ActivationsTensor = processed_activations.activations + coact: ClusterCoactivationShaped = compute_coactivatons(act_concat) + labels: ComponentLabels = ComponentLabels(processed_activations.labels) + n_samples: int = act_concat.shape[0] + + # trim the activations if n_samples_max is specified + # clone here so we don't modify the original tensor + act_concat = act_concat[:n_samples_max].clone() + # we don't use the stuff in this dict again, so we can modify it in-place + for key in act_dict: + act_dict[key] = act_dict[key][:n_samples_max] + + # Update n_samples to reflect the truncated size + n_samples = act_concat.shape[0] + + # Raw activations + axs_act: Sequence[plt.Axes] + _fig1: plt.Figure + _fig1, axs_act = plt.subplots(len(act_dict), 1, figsize=figsize_raw) # pyright: ignore[reportAssignmentType] + if len(act_dict) == 1: + assert isinstance(axs_act, plt.Axes) + axs_act = [axs_act] + for i, (key, act) in enumerate(act_dict.items()): + act_raw_data: np.ndarray = act.T.cpu().numpy() + axs_act[i].matshow( + act_raw_data, aspect="auto", 
vmin=act_raw_data.min(), vmax=act_raw_data.max() + ) + axs_act[i].set_ylabel(f"components\n{key}") + axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") + + fig1_fname = save_dir / f"{pdf_prefix}_raw.pdf" + _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/raw": wandb.Image(_fig1)}, step=0) + + # Close figure to free memory + plt.close(_fig1) + + # Concatenated activations + fig2: plt.Figure + ax2: plt.Axes + fig2, ax2 = plt.subplots(figsize=figsize_concat) + act_data: np.ndarray = act_concat.T.cpu().numpy() + im2 = ax2.matshow(act_data, aspect="auto", vmin=act_data.min(), vmax=act_data.max()) + ax2.set_title("Concatenated Activations") + + # Add component labeling on y-axis + add_component_labeling(ax2, labels, axis="y") + + plt.colorbar(im2) + + fig2_fname: Path = save_dir / f"{pdf_prefix}_concatenated.pdf" + fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated": wandb.Image(fig2)}, step=0) + + # Close figure to free memory + plt.close(fig2) + + # Concatenated activations, sorted samples + if do_sorted_samples: + # TODO: move sample sorting logic to its own function, see + # https://github.com/goodfire-ai/spd/pull/172/files#r2387275601 + fig3: plt.Figure + ax3: plt.Axes + fig3, ax3 = plt.subplots(figsize=figsize_concat) + + # Compute gram matrix (sample similarity) and sort samples using greedy ordering + gram_matrix: Float[Tensor, "samples samples"] = act_concat @ act_concat.T + + # Normalize gram matrix to get cosine similarity + norms: Float[Tensor, "samples 1"] = torch.norm(act_concat, dim=1, keepdim=True) + norms = torch.where(norms > 1e-8, norms, torch.ones_like(norms)) + similarity_matrix: Float[Tensor, "samples samples"] = gram_matrix / (norms @ norms.T) + + # Greedy ordering: start with sample most similar to all others 
+ avg_similarity: Float[Tensor, " samples"] = similarity_matrix.mean(dim=1) + start_idx: int = int(torch.argmax(avg_similarity).item()) + + # Build ordering greedily + ordered_indices: list[int] = [start_idx] + remaining: set[int] = set(range(n_samples)) + remaining.remove(start_idx) + + # Greedily add the nearest unvisited sample + current_idx: int = start_idx + while remaining: + # Find the unvisited sample most similar to current + best_similarity: float = -1 + best_idx: int = -1 + for idx in remaining: + sim: float = similarity_matrix[current_idx, idx].item() + if sim > best_similarity: + best_similarity = sim + best_idx = idx + + ordered_indices.append(best_idx) + remaining.remove(best_idx) + current_idx = best_idx + + sorted_indices: Int[Tensor, " samples"] = torch.tensor( + ordered_indices, dtype=torch.long, device=act_concat.device + ) + act_concat_sorted: ActivationsTensor = act_concat[sorted_indices] + + # Handle log10 properly - add small epsilon to avoid log(0) + act_sorted_data: np.ndarray = act_concat_sorted.T.cpu().numpy() + act_sorted_log: np.ndarray = np.log10(act_sorted_data + 1e-10) + im3 = ax3.matshow( + act_sorted_log, aspect="auto", vmin=act_sorted_log.min(), vmax=act_sorted_log.max() + ) + ax3.set_title("Concatenated Activations $\\log_{10}$, Sorted Samples") + + # Add component labeling on y-axis + add_component_labeling(ax3, labels, axis="y") + + plt.colorbar(im3) + + fig3_fname: Path = save_dir / f"{pdf_prefix}_concatenated_sorted.pdf" + fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated_sorted": wandb.Image(fig3)}, step=0) + + # Close figure to free memory + plt.close(fig3) + + # Coactivations + fig4: plt.Figure + ax4: plt.Axes + fig4, ax4 = plt.subplots(figsize=figsize_coact) + coact_data: np.ndarray = coact.cpu().numpy() + im4 = ax4.matshow(coact_data, aspect="auto", vmin=coact_data.min(), vmax=coact_data.max()) + 
ax4.set_title("Coactivations") + + # Add component labeling on both axes + add_component_labeling(ax4, labels, axis="x") + add_component_labeling(ax4, labels, axis="y") + + plt.colorbar(im4) + + fig4_fname: Path = save_dir / f"{pdf_prefix}_coactivations.pdf" + fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/coactivations": wandb.Image(fig4)}, step=0) + + # Close figure to free memory + plt.close(fig4) + + # log coactivations + fig4_log: plt.Figure + ax4_log: plt.Axes + fig4_log, ax4_log = plt.subplots(figsize=figsize_coact) + assert np.all(coact_data >= 0) + coact_log_data: np.ndarray = np.log10(coact_data + 1e-6) + im4_log = ax4_log.matshow( + coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max() + ) + ax4_log.set_title("Coactivations $\\log_{10}$") + # Add component labeling on both axes + add_component_labeling(ax4_log, labels, axis="x") + add_component_labeling(ax4_log, labels, axis="y") + plt.colorbar(im4_log) + fig4_log_fname: Path = save_dir / f"{pdf_prefix}_coactivations_log.pdf" + fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/coactivations_log": wandb.Image(fig4_log)}, step=0) + + # Close figure to free memory + plt.close(fig4_log) + + # Activation histograms + fig5: plt.Figure + ax5a: plt.Axes + ax5b: plt.Axes + ax5c: plt.Axes + fig5, (ax5a, ax5b, ax5c) = plt.subplots(1, 3, figsize=(15, 4)) # pyright: ignore[reportGeneralTypeIssues] + + x_scale: str + y_scale: str + x_scale, y_scale = hist_scales + + # Histogram 1: All activations + all_activations: Float[Tensor, " samples*n_components"] = act_concat.flatten() + all_vals: np.ndarray = all_activations.cpu().numpy() + hist_counts: np.ndarray + bin_edges: np.ndarray + hist_counts, bin_edges = np.histogram(all_vals, bins=hist_bins) + bin_centers: np.ndarray = 
(bin_edges[:-1] + bin_edges[1:]) / 2 + ax5a.plot(bin_centers, hist_counts, color="blue", linewidth=2) + ax5a.set_title("All Activations") + ax5a.set_xlabel("Activation Value") + ax5a.set_ylabel("Count") + if x_scale == "log": + ax5a.set_xscale("log") + if y_scale == "log": + ax5a.set_yscale("log") + ax5a.grid(True, alpha=0.3) + + # Histogram 2: Activations per component + n_components: int = act_concat.shape[1] + + # Common bin edges for all component histograms + all_min: float = float(all_vals.min()) + all_max: float = float(all_vals.max()) + common_bins: np.ndarray = np.linspace(all_min, all_max, hist_bins) + common_centers: np.ndarray = (common_bins[:-1] + common_bins[1:]) / 2 + + # Get unique label prefixes and assign colors + label_prefixes: list[str] = [label.split(":")[0] for label in labels] + unique_prefixes: list[str] = list(dict.fromkeys(label_prefixes)) # Preserve order + colors: Sequence[tuple[int, int, int]] = mpl.colormaps["tab10"]( + np.linspace(0, 1, len(unique_prefixes)) + ) # pyright: ignore[reportAssignmentType] + prefix_colors: dict[str, tuple[int, int, int]] = { + prefix: colors[i] for i, prefix in enumerate(unique_prefixes) + } + + for comp_idx in range(n_components): + component_activations: Float[Tensor, " n_samples"] = act_concat[:, comp_idx] + comp_vals: np.ndarray = component_activations.cpu().numpy() + hist_counts, _ = np.histogram(comp_vals, bins=common_bins, density=True) + + # Get color based on label prefix + prefix: str = label_prefixes[comp_idx] + color: tuple[int, int, int] = prefix_colors[prefix] + + ax5b.plot(common_centers, hist_counts, color=color, alpha=0.1, linewidth=1) + + ax5b.set_title(f"Per Component ({n_components} components)") + ax5b.set_xlabel("Activation Value") + ax5b.set_ylabel("Density") + if x_scale == "log": + ax5b.set_xscale("log") + if y_scale == "log": + ax5b.set_yscale("log") + ax5b.grid(True, alpha=0.3) + + # Histogram 3: Activations per sample + for sample_idx in range(n_samples): + sample_activations: 
Float[Tensor, " n_components"] = act_concat[sample_idx, :] + sample_vals: np.ndarray = sample_activations.cpu().numpy() + hist_counts, _ = np.histogram(sample_vals, bins=common_bins, density=True) + ax5c.plot(common_centers, hist_counts, color="blue", alpha=0.1, linewidth=1) + + ax5c.set_title(f"Per Sample ({n_samples} samples)") + ax5c.set_xlabel("Activation Value") + ax5c.set_ylabel("Density") + if x_scale == "log": + ax5c.set_xscale("log") + if y_scale == "log": + ax5c.set_yscale("log") + ax5c.grid(True, alpha=0.3) + + plt.tight_layout() + + fig5_fname: Path = save_dir / f"{pdf_prefix}_histograms.pdf" + fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/histograms": wandb.Image(fig5)}, step=0) + + # Close figure to free memory + plt.close(fig5) + + +def add_component_labeling( + ax: plt.Axes, component_labels: ComponentLabels, axis: str = "x" +) -> None: + """Add component labeling using major/minor ticks to show module boundaries. 
+ + Args: + ax: Matplotlib axis to modify + component_labels: List of component labels in format "module:index" + axis: Which axis to label ('x' or 'y') + """ + if not component_labels: + return + + # Extract module information + module_changes: list[int] = [] + current_module: str = component_labels[0].split(":")[0] + module_labels: list[str] = [] + + for i, label in enumerate(component_labels): + module: str = label.split(":")[0] + if module != current_module: + module_changes.append(i) + module_labels.append(current_module) + current_module = module + module_labels.append(current_module) + + # Set up major and minor ticks + # Minor ticks: every 10 components + minor_ticks: list[int] = list(range(0, len(component_labels), 10)) + + # Major ticks: module boundaries (start of each module) + major_ticks: list[int] = [0] + module_changes + major_labels: list[str] = module_labels + + if axis == "x": + ax.set_xticks(minor_ticks, minor=True) + ax.set_xticks(major_ticks) + ax.set_xticklabels(major_labels) + ax.set_xlim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="x", which="minor", length=2, width=0.5) + ax.tick_params(axis="x", which="major", length=6, width=1.5) + for x in major_ticks: + ax.axvline(x - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) + else: + ax.set_yticks(minor_ticks, minor=True) + ax.set_yticks(major_ticks) + ax.set_yticklabels(major_labels) + ax.set_ylim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="y", which="minor", length=2, width=0.5) + ax.tick_params(axis="y", which="major", length=6, width=1.5) + for y in major_ticks: + ax.axhline(y - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py new file mode 100644 index 000000000..e470b3114 --- /dev/null +++ b/spd/clustering/plotting/merge.py @@ -0,0 +1,327 @@ +"""Plotting functions for merge visualizations.""" + +from typing import 
Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, ComponentLabels, DistancesArray +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.merge_history import MergeHistory +from spd.clustering.util import format_scientific_latex + +DEFAULT_PLOT_CONFIG: dict[str, Any] = dict( + figsize=(16, 10), + tick_spacing=5, + save_pdf=False, + pdf_prefix="merge_iteration", +) + + +def plot_merge_matrix( + merge_matrix: Bool[Tensor, "k_groups n_components"], + show: bool = True, + figsize: tuple[int, int] = (10, 3), + show_row_sums: bool | None = None, + ax: "plt.Axes | None" = None, + component_labels: ComponentLabels | None = None, +) -> None: + import matplotlib.pyplot as plt + + k_groups: int + k_groups, _ = merge_matrix.shape + group_sizes: Int[Tensor, " k_groups"] = merge_matrix.sum(dim=1) + + if show_row_sums is None: + show_row_sums = k_groups <= 20 + + ax_lbl: plt.Axes | None = None + if ax is not None: + show_row_sums = False # don't show row sums if we have an ax to plot on + ax_mat = ax + assert not show_row_sums + else: + if show_row_sums: + _fig, (ax_mat, ax_lbl) = plt.subplots( # pyright: ignore[reportGeneralTypeIssues] + 1, 2, figsize=figsize, gridspec_kw={"width_ratios": [10, 1]} + ) + else: + _fig, ax_mat = plt.subplots(figsize=figsize) + + ax_mat.matshow(merge_matrix.cpu(), aspect="auto", cmap="Blues", interpolation="nearest") + ax_mat.set_xlabel("Components") + ax_mat.set_ylabel("Groups") + ax_mat.set_title("Merge Matrix") + + # Add component labeling if component labels are provided + if component_labels is not None: + # Import the function here to avoid circular imports + from spd.clustering.plotting.activations import add_component_labeling + + add_component_labeling(ax_mat, component_labels, axis="x") + + if show_row_sums: + assert ax_lbl is not None + ax_lbl.set_xlim(0, 1) 
+ ax_lbl.set_ylim(-0.5, k_groups - 0.5) + ax_lbl.invert_yaxis() + ax_lbl.set_title("Row Sums") + ax_lbl.axis("off") + for i, size in enumerate(group_sizes): + ax_lbl.text(0.5, i, str(size.item()), va="center", ha="center", fontsize=12) + + plt.tight_layout() + if show: + plt.show() + + +def plot_merge_iteration( + current_merge: GroupMerge, + current_coact: ClusterCoactivationShaped, + costs: ClusterCoactivationShaped, + # pair_cost: float, + iteration: int, + component_labels: ComponentLabels | None = None, + plot_config: dict[str, Any] | None = None, + nan_diag: bool = True, + show: bool = False, +) -> plt.Figure: + """Plot merge iteration results with merge tree, coactivations, and costs. + + Args: + current_merge: Current merge state + current_coact: Current coactivation matrix + costs: Current cost matrix + pair_cost: Cost of selected merge pair + iteration: Current iteration number + component_labels: Component labels for axis labeling + plot_config: Plot configuration settings + nan_diag: Whether to set diagonal to NaN for visualization + show: Whether to display the plot (default: False) + + Returns: + The matplotlib figure object + + Note: + Caller is responsible for closing the returned figure with plt.close(fig) + to prevent memory leaks. 
+ """ + plot_config_: dict[str, Any] = { + **DEFAULT_PLOT_CONFIG, + **(plot_config or {}), + } + axs: list[plt.Axes] + fig, axs = plt.subplots( # pyright: ignore[reportAssignmentType] + 1, 3, figsize=plot_config_["figsize"], sharey=True, gridspec_kw={"width_ratios": [2, 1, 1]} + ) + + # Merge plot + plot_merge_matrix( + current_merge.to_matrix(), + ax=axs[0], + show=False, + component_labels=component_labels, + ) + + axs[0].set_title("Merge") + + # Coactivations plot + coact_min: float = current_coact.min().item() + coact_max: float = current_coact.max().item() + if nan_diag: + current_coact = current_coact.clone() + current_coact.fill_diagonal_(np.nan) + axs[1].matshow(current_coact.cpu().numpy(), aspect="equal") + coact_min_str: str = format_scientific_latex(coact_min) + coact_max_str: str = format_scientific_latex(coact_max) + axs[1].set_title(f"Coactivations\n[{coact_min_str}, {coact_max_str}]") + + # Setup ticks for coactivations + k_groups: int = current_coact.shape[0] + minor_ticks: list[int] = list(range(0, k_groups, plot_config_["tick_spacing"])) + axs[1].set_yticks(minor_ticks) + axs[1].set_xticks(minor_ticks) + axs[1].set_xticklabels([]) # Remove x-axis tick labels but keep ticks + + # Costs plot + costs_min: float = costs.min().item() + costs_max: float = costs.max().item() + if nan_diag: + costs = costs.clone() + costs.fill_diagonal_(np.nan) + axs[2].matshow(costs.cpu().numpy(), aspect="equal") + costs_min_str: str = format_scientific_latex(costs_min) + costs_max_str: str = format_scientific_latex(costs_max) + axs[2].set_title(f"Costs\n[{costs_min_str}, {costs_max_str}]") + + # Setup ticks for costs + axs[2].set_yticks(minor_ticks) + axs[2].set_xticks(minor_ticks) + axs[2].set_xticklabels([]) # Remove x-axis tick labels but keep ticks + + # fig.suptitle(f"Iteration {iteration} with cost {pair_cost:.4f}") + fig.suptitle(f"Iteration {iteration}") + plt.tight_layout() + + if plot_config_["save_pdf"]: + fig.savefig( + 
f"{plot_config_['pdf_prefix']}_iter_{iteration:03d}.pdf", bbox_inches="tight", dpi=300 + ) + + if show: + plt.show() + + return fig + + +def plot_dists_distribution( + distances: DistancesArray, + mode: Literal["points", "dist"] = "points", + label: str | None = None, + ax: plt.Axes | None = None, + kwargs_fig: dict[str, Any] | None = None, + kwargs_plot: dict[str, Any] | None = None, +) -> plt.Axes: + n_iters: int = distances.shape[0] + n_ens: int = distances.shape[1] + assert distances.shape[2] == n_ens, "Distances must be square" + + # Ensure ax and kwargs_fig are not both provided + if ax is not None and kwargs_fig is not None: + raise ValueError("Cannot provide both ax and kwargs_fig") + + dists_flat: Float[np.ndarray, " n_iters n_ens*n_ens"] = distances.reshape( + distances.shape[0], -1 + ) + + # Create figure if ax not provided + if ax is None: + _fig, ax_ = plt.subplots( # pyright: ignore[reportCallIssue] + 1, + 1, + **dict( + figsize=(8, 5), # pyright: ignore[reportArgumentType] + **(kwargs_fig or {}), + ), + ) + else: + ax_ = ax + + if mode == "points": + # Original points mode + n_samples: int = dists_flat.shape[1] + for i in range(n_iters): + ax_.plot( + np.full((n_samples), i), + dists_flat[i], + **dict( # pyright: ignore[reportArgumentType] + marker="o", + linestyle="", + color="blue", + alpha=min(1, 10 / (n_ens * n_ens)), + markersize=5, + markeredgewidth=0, + **(kwargs_plot or {}), + ), + ) + elif mode == "dist": + # Distribution statistics mode + # Generate a random color for this plot + color: Float[np.ndarray, " 3"] = np.random.rand(3) + + # Calculate statistics for each iteration + mins: list[float] = [] + maxs: list[float] = [] + means: list[float] = [] + medians: list[float] = [] + q1s: list[float] = [] + q3s: list[float] = [] + + for i in range(n_iters): + # Filter out NaN values (diagonal and upper triangle) + valid_dists: Float[np.ndarray, " n_valid"] = dists_flat[i][~np.isnan(dists_flat[i])] + if len(valid_dists) > 0: + 
mins.append(np.min(valid_dists)) + maxs.append(np.max(valid_dists)) + means.append(float(np.mean(valid_dists))) + medians.append(float(np.median(valid_dists))) + q1s.append(float(np.percentile(valid_dists, 25))) + q3s.append(float(np.percentile(valid_dists, 75))) + else: + # Handle case with no valid distances + mins.append(np.nan) + maxs.append(np.nan) + means.append(np.nan) + medians.append(np.nan) + q1s.append(np.nan) + q3s.append(np.nan) + + iterations: Int[np.ndarray, " n_iters"] = np.arange(n_iters) + + # Plot statistics + ax_.plot(iterations, mins, "-", color=color, alpha=0.5) + ax_.plot(iterations, maxs, "-", color=color, alpha=0.5) + ax_.plot(iterations, means, "-", color=color, linewidth=2, label=label) + ax_.plot(iterations, medians, "--", color=color, linewidth=2) + ax_.plot(iterations, q1s, ":", color=color, alpha=0.7) + ax_.plot(iterations, q3s, ":", color=color, alpha=0.7) + + # Shade between quartiles + ax_.fill_between(iterations, q1s, q3s, color=color, alpha=0.2) + + ax_.set_xlabel("Iteration #") + ax_.set_ylabel("distance") + ax_.set_title("Distribution of pairwise distances between group merges in an ensemble") + + return ax_ + + +def plot_merge_history_cluster_sizes( + history: MergeHistory, + figsize: tuple[int, int] = (10, 5), + fmt: str = "png", + file_prefix: str | None = None, +) -> plt.Figure: + """Plot cluster sizes over iterations. + + Note: + Caller is responsible for closing the returned figure with plt.close(fig) + to prevent memory leaks. 
+ """ + k_groups_t: Int[Tensor, " n_iters"] = history.merges.k_groups + valid_mask: Bool[Tensor, " n_iters"] = k_groups_t.ne(-1) + has_data: bool = bool(valid_mask.any().item()) + if not has_data: + raise ValueError("No populated iterations in history.k_groups") + + group_idxs_all: Int[Tensor, " n_iters n_components"] = history.merges.group_idxs[valid_mask] + k_groups_all: Int[Tensor, " n_iters"] = k_groups_t[valid_mask] + max_k: int = int(k_groups_all.max().item()) + + counts_list: list[Int[Tensor, " max_k"]] = [ + torch.bincount(row[row.ge(0)], minlength=max_k) # per-iteration cluster sizes + for row in group_idxs_all + ] + counts: Int[Tensor, " n_iters max_k"] = torch.stack(counts_list, dim=0) + + mask_pos: Bool[Tensor, " n_iters max_k"] = counts.gt(0) + it_idx_t, grp_idx_t = torch.nonzero(mask_pos, as_tuple=True) + xs_t: Float[Tensor, " n_points"] = it_idx_t.to(torch.float32) + sizes_t: Float[Tensor, " n_points"] = counts[it_idx_t, grp_idx_t].to(torch.float32) + + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + xs_t.cpu().numpy(), sizes_t.cpu().numpy(), "bo", markersize=3, alpha=0.15, markeredgewidth=0 + ) + ax.set_xlabel("Iteration") + ax.set_ylabel("Cluster size") + ax.set_yscale("log") + ax.set_title("Distribution of cluster sizes over time") + + if file_prefix is not None: + fig.savefig(f"{file_prefix}_cluster_sizes.{fmt}", bbox_inches="tight", dpi=300) + + return fig diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py new file mode 100644 index 000000000..56cee4d84 --- /dev/null +++ b/spd/clustering/scripts/main.py @@ -0,0 +1,81 @@ +import argparse +from pathlib import Path + +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.pipeline.clustering_pipeline import main +from spd.log import logger +from spd.settings import REPO_ROOT + + +def cli() -> None: + """Command-line interface for clustering.""" + + logger.set_format("console", style="terse") + + parser: argparse.ArgumentParser = 
argparse.ArgumentParser( + description="Run clustering on a dataset using clean architecture" + ) + parser.add_argument( + "--config", + "-c", + type=Path, + required=True, + help="Path to the merge run config JSON/YAML/TOML file", + ) + parser.add_argument( + "--base-path", + "-p", + type=Path, + default=REPO_ROOT / ".data/clustering/", + help="Base path for saving clustering outputs", + ) + parser.add_argument( + "--devices", + "-d", + type=str, + default=None, + help="Comma-separated list of devices to use for clustering (e.g., 'cuda:0,cuda:1')", + ) + parser.add_argument( + "--workers-per-device", + "-x", + type=int, + default=1, + help="Maximum number of concurrent clustering processes per device (default: 1)", + ) + args: argparse.Namespace = parser.parse_args() + + logger.info("Starting clustering pipeline") + + # Parse devices + devices: list[str] + if args.devices is None: + import torch + + devices = ["cuda" if torch.cuda.is_available() else "cpu"] + logger.info(f"No devices specified, auto-detected: {devices}") + else: + devices = args.devices.split(",") + logger.info(f"Using specified devices: {devices}") + + # Load and augment config + # Note that the defaults for args here always override the default values in `RunConfig` itself, + # but we must have those defaults to avoid type issues + logger.info(f"Loading config from {args.config}") + config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) + config.base_path = args.base_path + config.devices = devices + config.workers_per_device = args.workers_per_device + + logger.info(f"Configuration loaded: {config.config_identifier}") + logger.info(f"Base path: {config.base_path}") + logger.info(f"{config.workers_per_device = }, {config.devices = }, {config.n_batches = }") + + # Run + main(config=config) + + logger.info("Clustering pipeline completed successfully") + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/util.py b/spd/clustering/util.py new file mode 100644 index 
000000000..bd11e2fd4 --- /dev/null +++ b/spd/clustering/util.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + + +def format_scientific_latex(value: float) -> str: + """Format a number in LaTeX scientific notation style.""" + if value == 0: + return r"$0$" + + import math + + exponent: int = int(math.floor(math.log10(abs(value)))) + mantissa: float = value / (10**exponent) + + return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + + +ModuleFilterSource = str | Callable[[str], bool] | set[str] | None +ModuleFilterFunc = Callable[[str], bool] diff --git a/spd/clustering/wandb_tensor_info.py b/spd/clustering/wandb_tensor_info.py new file mode 100644 index 000000000..14463ab88 --- /dev/null +++ b/spd/clustering/wandb_tensor_info.py @@ -0,0 +1,169 @@ +"""Minimal WandB tensor logging utilities using muutils.""" + +import warnings +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import wandb +import wandb.sdk.wandb_run +from muutils.dbg import dbg_tensor +from muutils.tensor_info import array_info +from torch import Tensor + + +def wandb_log_tensor( + run: wandb.sdk.wandb_run.Run, + data: Tensor | dict[str, Tensor], + name: str, + step: int, + single: bool = False, +) -> None: + """Log tensor(s) with stats to WandB as metrics and histograms. 
+ + Args: + run: Current WandB run (None if WandB disabled) + data: Either a Tensor or dict[str, Tensor] + name: Name for logging + step: WandB step + single: True if this tensor is only logged once (component activations) + """ + try: + if isinstance(data, dict): + # Handle dict of tensors + for key, tensor in data.items(): + full_name: str = f"{name}.{key}" + _log_one(run, tensor, full_name, step, single=single) + else: + # Handle single tensor + _log_one(run, data, name, step, single=single) + except Exception as e: + warnings.warn(f"Failed to log tensor {name}: {e}") # noqa: B028 + dbg_tensor(data) + raise e + + +def _create_histogram( + info: dict[str, Any], tensor: Tensor, name: str, logy: bool = True +) -> plt.Figure: + """Create matplotlib histogram with stats markers.""" + # sanity check + if info["status"] != "ok" or info["size"] == 0: + fig: plt.Figure + ax: plt.Axes + fig, ax = plt.subplots(figsize=(8, 6)) + ax.text(0.5, 0.5, f"{info['status']}", ha="center", va="center") + ax.set_title(f"{name} - {info['status']}") + return fig + + # make basic hist + values: np.ndarray = tensor.flatten().detach().cpu().numpy() + if info["has_nans"]: + values = values[~np.isnan(values)] + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.hist(values, bins=50, alpha=0.7, edgecolor="black", linewidth=0.5) + + # Add stat lines + mean_val: float = info["mean"] or float("nan") + median_val: float = info["median"] or float("nan") + std_val: float = info["std"] or float("nan") + + if info["mean"] is not None: + ax.axvline( + mean_val, + color="red", + linestyle="-", + linewidth=2, + label="$\\mu$", + ) + ax.axvline( + median_val, + color="blue", + linestyle="-", + linewidth=2, + label="$\\tilde{x}$", + ) + if std_val: + ax.axvline( + mean_val + std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu+\\sigma$", + ) + ax.axvline( + mean_val - std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu-\\sigma$", + ) + 
+ # Build informative title with tensor stats + shape_str: str = str(tuple(info["shape"])) if "shape" in info else "unknown" + dtype_str: str = str(info.get("dtype", "unknown")).replace("torch.", "") + + title_line1: str = f"{name}" + title_line2: str = f"shape={shape_str}, dtype={dtype_str}" + title_line3: str = ( + f"range=[{info['min']:.3g}, {info['max']:.3g}], " + f"$\\mu$={mean_val:.3g}, $\\tilde{{x}}$={median_val:.3g}, $\\sigma$={std_val:.3g}" + ) + + # Combine into multi-line title + full_title: str = f"{title_line1}\n{title_line2}\n{title_line3}" + ax.set_title(full_title, fontsize=10) + ax.set_xlabel("Value") + ax.set_ylabel("Count") + ax.legend() + ax.grid(True, alpha=0.3) + if logy: + ax.set_yscale("log") + + plt.tight_layout() + return fig + + +def _log_one( + run: wandb.sdk.wandb_run.Run, + tensor_: Tensor, + name: str, + step: int, + single: bool = False, + # use_log_counts: bool = True, +) -> None: + """Log a single tensor.""" + info: dict[str, Any] = array_info(tensor_) + + if single: + # For single-use logging, log a single histogram as a figure + hist_fig: plt.Figure = _create_histogram(info=info, tensor=tensor_, name=name) + histogram_key: str = f"single_hists/{name}" + run.log({histogram_key: wandb.Image(hist_fig)}, step=step) + plt.close(hist_fig) # Close figure to free memory + else: + # Log numeric stats as metrics (viewable like loss) using dict comprehension + stats_to_log: dict[str, float | wandb.Histogram] = { + f"tensor_metrics/{name}/{key}": info[key] + for key in ["mean", "std", "median", "min", "max"] + if key in info and info[key] is not None + } + + # For regular logging, use wandb.Histogram directly + hist_key: str = f"tensor_histograms/{name}" + stats_to_log[hist_key] = wandb.Histogram(tensor_.flatten().cpu().numpy()) # pyright: ignore[reportArgumentType] + + # Add nan_percent if present + nan_percent: float | None = info["nan_percent"] + # TODO: this is a hack for when the tensor is empty + if nan_percent is None: + 
dbg_tensor(tensor_) + nan_percent = float("nan") + if nan_percent > 0: + stats_to_log[f"tensor_metrics/{name}/nan_percent"] = nan_percent + + if stats_to_log: + run.log(stats_to_log, step=step) diff --git a/spd/models/component_model.py b/spd/models/component_model.py index e0fa0a17d..1ba4064c0 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -308,6 +308,7 @@ def __call__( def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: return super().__call__(*args, **kwargs) + # TODO: why doesnt this have overrides??? @override def forward( self, diff --git a/tests/clustering/math/test_perm_invariant_hamming.py b/tests/clustering/math/test_perm_invariant_hamming.py new file mode 100644 index 000000000..7d2bf4740 --- /dev/null +++ b/tests/clustering/math/test_perm_invariant_hamming.py @@ -0,0 +1,123 @@ +from itertools import permutations + +import numpy as np +import pytest + +from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix + +# pyright complains about the types when calling perm_invariant_hamming +# pyright: reportCallIssue=false + + +def brute_force_min_hamming(a: np.ndarray, b: np.ndarray) -> int: + """Exhaustive check for small k.""" + k = int(max(a.max(), b.max()) + 1) + best = len(a) + for perm in permutations(range(k)): + mapping = np.array(perm) + best = min(best, int((mapping[a] != b).sum())) + return best + + +def test_identity() -> None: + """a == b should give distance 0.""" + a = np.array([0, 1, 2, 1, 0]) + b = a.copy() + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + # Distance between row 1 and row 0 should be 0 + assert D[1, 0] == 0 + + +def test_all_one_group() -> None: + """All rows belong to one group in both arrays (possibly different labels).""" + a = np.zeros(10, dtype=int) + b = np.ones(10, dtype=int) # different label but identical grouping + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def 
test_permuted_labels() -> None: + a = np.array([0, 2, 1, 1, 0]) + b = np.array([1, 0, 0, 2, 1]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 1 + + +def test_swap_two_labels() -> None: + a = np.array([0, 0, 1, 1]) + b = np.array([1, 1, 0, 0]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def test_random_small_bruteforce() -> None: + rng = np.random.default_rng(0) + for _ in range(50): + n = 7 + k = 3 + a = rng.integers(0, k, size=n) + b = rng.integers(0, k, size=n) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + d_alg = D[1, 0] + d_true = brute_force_min_hamming(a, b) + assert d_alg == d_true + + +def test_shape_mismatch() -> None: + a = np.array([0, 1, 2]) + b = np.array([0, 1]) + with pytest.raises((ValueError, IndexError)): + # This should fail when trying to create the matrix due to shape mismatch + X = np.array([a, b]) + perm_invariant_hamming_matrix(X) + + +def test_matrix_multiple_pairs() -> None: + """Test the matrix function with multiple label vectors.""" + a = np.array([0, 0, 1, 1]) + b = np.array([2, 2, 3, 3]) # Should be distance 0 (perfect mapping) + c = np.array([0, 1, 0, 1]) # Should be distance 2 from both a and b + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + assert D[1, 0] == 0 # a and b should have distance 0 + assert D[2, 0] == 2 # a and c should have distance 2 + assert D[2, 1] == 2 # b and c should have distance 2 + + +def test_matrix_upper_triangle_nan() -> None: + """Test that upper triangle and diagonal are NaN.""" + a = np.array([0, 1, 0]) + b = np.array([1, 0, 1]) + c = np.array([0, 0, 1]) + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + # Diagonal should be NaN + assert np.isnan(D[0, 0]) + assert np.isnan(D[1, 1]) + assert np.isnan(D[2, 2]) + + # Upper triangle should be NaN + assert np.isnan(D[0, 1]) + assert np.isnan(D[0, 2]) + assert np.isnan(D[1, 2]) + + # Lower triangle should have actual 
distances + assert not np.isnan(D[1, 0]) + assert not np.isnan(D[2, 0]) + assert not np.isnan(D[2, 1]) + + +def test_unused_labels() -> None: + """Test when arrays don't use all labels 0..k-1.""" + a = np.array([0, 0, 3, 3]) # skips 1, 2 + b = np.array([1, 1, 2, 2]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py new file mode 100644 index 000000000..df8dd0a2f --- /dev/null +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -0,0 +1,196 @@ +# %% +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import torch +from muutils.dbg import dbg_auto +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import ( + plot_dists_distribution, + plot_merge_iteration, +) +from spd.configs import Config +from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.registry import EXPERIMENT_REGISTRY +from spd.utils.data_utils import DatasetGeneratedDataLoader + +DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" +TEMP_DIR: Path = Path( + "tests/.temp" +) # save to an actual dir that is gitignored, so users can view plots +TEMP_DIR.mkdir(parents=True, exist_ok=True) + + +# pyright: reportUnusedParameter=false + +# magic autoreload +# %load_ext autoreload +# %autoreload 2 + +# %% +# Load model +# ============================================================ +_CANONICAL_RUN: str | None = 
EXPERIMENT_REGISTRY["resid_mlp2"].canonical_run +assert _CANONICAL_RUN is not None, "No canonical run found for resid_mlp2 experiment" +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(_CANONICAL_RUN) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG: Config = SPD_RUN.config + +# %% +# Setup dataset and dataloader +# ============================================================ +N_SAMPLES: int = 128 + +DATASET: ResidMLPDataset = ResidMLPDataset( + n_features=MODEL.target_model.config.n_features, # pyright: ignore[reportAttributeAccessIssue, reportArgumentType], + feature_probability=SPD_CONFIG.task_config.feature_probability, # pyright: ignore[reportAttributeAccessIssue] + device=DEVICE, + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=None, + label_coeffs=None, + data_generation_type=SPD_CONFIG.task_config.data_generation_type, # pyright: ignore[reportAttributeAccessIssue] +) + +dbg_auto( + dict( + n_features=DATASET.n_features, + feature_probability=DATASET.feature_probability, + data_generation_type=DATASET.data_generation_type, + ) +) +DATALOADER = DatasetGeneratedDataLoader(DATASET, batch_size=N_SAMPLES, shuffle=False) + +# %% +# Get component activations +# ============================================================ +# Get a single batch from the dataloader +BATCH_DATA: tuple[Tensor, Tensor] = next(iter(DATALOADER)) +BATCH: Tensor = BATCH_DATA[0] + +COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + device=DEVICE, + batch=BATCH, + sigmoid_type="hard", +) + +dbg_auto(COMPONENT_ACTS) + +# %% + +FILTER_DEAD_THRESHOLD: float = 0.1 + +# Process activations +# ============================================================ +PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, +) + + +plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + 
n_samples_max=256, + wandb_run=None, +) + +# %% +# run the merge iteration +# ============================================================ + +MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.1, + alpha=1, + iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9), + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.0}, + pop_component_prob=0, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, +) + + +def _plot_func( + current_coact: torch.Tensor, + component_labels: ComponentLabels, + current_merge: Any, + costs: torch.Tensor, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: torch.Tensor, +) -> None: + if (iter_idx % 50 == 0 and iter_idx > 0) or iter_idx == 1: + plot_merge_iteration( + current_merge=current_merge, + current_coact=current_coact, + costs=costs, + iteration=iter_idx, + component_labels=component_labels, + show=True, # Show the plot interactively + ) + + +MERGE_HIST: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id="batch_0", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=_plot_func, +) + +# %% +# Plot merge history +# ============================================================ + +# plt.hist(mh[270]["merges"].components_per_group, bins=np.linspace(0, 56, 57)) +# plt.yscale("log") +# plt.xscale("log") + + +# %% +# compute and plot distances in an ensemble +# ============================================================ + +# Modern approach: run merge_iteration multiple times to create ensemble +ENSEMBLE_SIZE: int = 4 +HISTORIES: list[MergeHistory] = [] +for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) + 
+ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) + +DISTANCES = ENSEMBLE.get_distances(method="perm_invariant_hamming") + +plot_dists_distribution( + distances=DISTANCES, + mode="points", + # label="v1" +) +plt.legend() diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py new file mode 100644 index 000000000..6ede368f0 --- /dev/null +++ b/tests/clustering/scripts/cluster_ss.py @@ -0,0 +1,129 @@ +# %% +from pathlib import Path + +import matplotlib.pyplot as plt +import torch +from jaxtyping import Int +from muutils.dbg import dbg_auto +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.pipeline.s1_split_dataset import split_dataset +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import plot_dists_distribution +from spd.models.component_model import ComponentModel, SPDRunInfo + +DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" +TEMP_DIR: Path = Path( + "tests/.temp" +) # save to an actual dir that is gitignored, so users can view plots +TEMP_DIR.mkdir(parents=True, exist_ok=True) + +# magic autoreload +# %load_ext autoreload +# %autoreload 2 + +# %% +# Load model and dataset +# ============================================================ +MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG = SPD_RUN.config + +# Use split_dataset with RunConfig to get real data +CONFIG: ClusteringRunConfig = 
ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + task_name="lm", + n_batches=1, + batch_size=2, +) +BATCHES, _ = split_dataset(config=CONFIG) + +# %% +# Load data batch +# ============================================================ +DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) + +# %% +# Get component activations +# ============================================================ +COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, + sigmoid_type="hard", +) + +_ = dbg_auto(COMPONENT_ACTS) +# %% +# Process activations +# ============================================================ +FILTER_DEAD_THRESHOLD: float = 0.001 +FILTER_MODULES: str = "model.layers.0" + +PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", +) + +plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, +) + +# %% +# Compute ensemble merge iterations +# ============================================================ +MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + module_name_filter=FILTER_MODULES, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, +) + +# Modern approach: run merge_iteration multiple times to create ensemble +ENSEMBLE_SIZE: int = 2 +HISTORIES: list[MergeHistory] = [] +for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) + +ENSEMBLE: MergeHistoryEnsemble = 
MergeHistoryEnsemble(data=HISTORIES) + + +# %% +# Compute and plot distances +# ============================================================ +DISTANCES = ENSEMBLE.get_distances() + +plot_dists_distribution( + distances=DISTANCES, + mode="points", +) +plt.legend() diff --git a/tests/clustering/test_clustering_experiments.py b/tests/clustering/test_clustering_experiments.py new file mode 100644 index 000000000..5031adfce --- /dev/null +++ b/tests/clustering/test_clustering_experiments.py @@ -0,0 +1,99 @@ +"""Tests for clustering experiments and notebook-style scripts.""" + +import subprocess +import sys +from pathlib import Path + +import pytest + +# Test resource directories +NOTEBOOK_DIR: Path = Path("tests/clustering/scripts") +CONFIG_DIR: Path = Path("spd/clustering/configs") + + +@pytest.mark.slow +def test_cluster_resid_mlp_notebook(): + """Test running the cluster_resid_mlp.py notebook-style script.""" + script_path = NOTEBOOK_DIR / "cluster_resid_mlp.py" + assert script_path.exists(), f"Script not found: {script_path}" + + # Run the script as-is + result = subprocess.run( + [sys.executable, str(script_path)], + capture_output=True, + text=True, + ) + + # Check that the script ran without errors + if result.returncode != 0: + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + assert result.returncode == 0, f"Script failed with return code {result.returncode}" + + +@pytest.mark.slow +def test_clustering_with_resid_mlp1_config(): + """Test running clustering with test-resid_mlp1.json config.""" + config_path = CONFIG_DIR / "test-resid_mlp1.json" + assert config_path.exists(), f"Config not found: {config_path}" + + # Run the clustering main script with the test config + result = subprocess.run( + [ + "spd-cluster", + "--config", + str(config_path), + ], + capture_output=True, + text=True, + ) + + # Check that the script ran without errors + if result.returncode != 0: + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + 
assert result.returncode == 0, f"Clustering failed with return code {result.returncode}" + + +@pytest.mark.slow +def test_cluster_ss_notebook(): + """Test running the cluster_ss.py notebook-style script.""" + script_path = NOTEBOOK_DIR / "cluster_ss.py" + assert script_path.exists(), f"Script not found: {script_path}" + + # Run the script as-is + result = subprocess.run( + [sys.executable, str(script_path)], + capture_output=True, + text=True, + ) + + # Check that the script ran without errors + if result.returncode != 0: + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + assert result.returncode == 0, f"Script failed with return code {result.returncode}" + + +@pytest.mark.slow +def test_clustering_with_simplestories_config(): + """Test running clustering with test-simplestories.json config.""" + config_path = CONFIG_DIR / "test-simplestories.json" + assert config_path.exists(), f"Config not found: {config_path}" + + # Run the clustering main script with the test config + result = subprocess.run( + [ + "spd-cluster", + "--config", + str(config_path), + ], + capture_output=True, + text=True, + ) + + # Check that the script ran without errors + if result.returncode != 0: + print(f"stdout: {result.stdout}") + print(f"stderr: {result.stderr}") + assert result.returncode == 0, f"Clustering failed with return code {result.returncode}" diff --git a/tests/clustering/test_filter_dead_components.py b/tests/clustering/test_filter_dead_components.py new file mode 100644 index 000000000..654631f37 --- /dev/null +++ b/tests/clustering/test_filter_dead_components.py @@ -0,0 +1,131 @@ +"""Tests for filter_dead_components function in activations.py""" + +import pytest +import torch +from torch import Tensor + +from spd.clustering.activations import FilteredActivations, filter_dead_components +from spd.clustering.consts import ComponentLabels + + +@pytest.mark.parametrize( + "max_values,threshold,expected_alive_indices", + [ + # No filtering when threshold is 
0 + ([0.1, 0.2, 0.3], 0.0, [0, 1, 2]), + # Filter all when all below threshold + ([0.005, 0.003, 0.004], 0.01, []), + # Filter some components + ([0.0, 0.02, 0.0, 0.03, 0.0], 0.01, [1, 3]), + # Boundary cases: at threshold is kept + ([0.009, 0.01, 0.011], 0.01, [1, 2]), + # High threshold filters everything + ([0.1, 0.2, 0.3], 2.0, []), + # Negative threshold filters nothing + ([0.1, 0.2, 0.3], -0.01, [0, 1, 2]), + # Single component above threshold + ([0.5], 0.01, [0]), + ], +) +def test_filter_dead_components_thresholds( + max_values: list[float], + threshold: float, + expected_alive_indices: list[int], +) -> None: + """Test filtering with various max values and thresholds.""" + n_steps: int = 10 + n_components: int = len(max_values) + + activations: Tensor + labels: ComponentLabels + if n_components == 0: + activations = torch.zeros(n_steps, 0) + labels = ComponentLabels([]) + else: + activations = torch.zeros(n_steps, n_components) + # Set max values in first row + for i, val in enumerate(max_values): + activations[0, i] = val + labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=labels, filter_dead_threshold=threshold + ) + + assert result.labels == [f"comp_{i}" for i in expected_alive_indices] + assert result.n_alive == len(expected_alive_indices) + assert result.n_dead == n_components - len(expected_alive_indices) + assert result.activations.shape == (n_steps, len(expected_alive_indices)) + + # Check dead components labels + if threshold <= 0 or all(v >= threshold for v in max_values): + # No filtering occurred + assert result.dead_components_labels is None or result.dead_components_labels == [] + else: + dead_indices: list[int] = [ + i for i in range(n_components) if i not in expected_alive_indices + ] + expected_dead: list[str] = [f"comp_{i}" for i in dead_indices] + assert result.dead_components_labels is not None + assert 
set(result.dead_components_labels) == set(expected_dead) + + +@pytest.mark.parametrize( + "step_locations,threshold", + [ + # Max at different steps + ([0, 5, 9], 0.01), + # All at same step + ([0, 0, 0], 0.01), + # Random steps + ([3, 7, 1, 8], 0.05), + ], +) +def test_max_across_steps(step_locations: list[int], threshold: float) -> None: + """Verify that filter_dead_components correctly finds the maximum activation + across ALL time steps for each component, not just looking at a single step. + + This test creates components where the maximum activation occurs at different + time steps, ensuring the function scans the entire temporal dimension.""" + n_steps: int = 10 + n_components: int = len(step_locations) + activations: Tensor = torch.zeros(n_steps, n_components) + + # Set values above threshold at specified steps + for i, step in enumerate(step_locations): + activations[step, i] = threshold + 0.01 + + labels: ComponentLabels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=labels, filter_dead_threshold=threshold + ) + + # All components should be alive since their max is above threshold + assert result.n_alive == n_components + assert result.n_dead == 0 + assert result.labels == labels + + +@pytest.mark.parametrize("threshold", [0.001, 0.01, 0.1, 0.5]) +def test_linear_gradient_thresholds(threshold: float) -> None: + """Test with linearly spaced activation values.""" + n_steps: int = 10 + n_components: int = 10 + activations: Tensor = torch.zeros(n_steps, n_components) + + # Create linearly spaced max values: 0, 0.1, 0.2, ..., 0.9 + for i in range(n_components): + activations[0, i] = i * 0.1 + + labels: list[str] = [f"comp_{i}" for i in range(n_components)] + + result: FilteredActivations = filter_dead_components( + activations=activations, labels=ComponentLabels(labels), filter_dead_threshold=threshold + ) + + # Count how many components should be alive + 
expected_alive: int = sum(i * 0.1 >= threshold for i in range(n_components)) + + assert result.n_alive == expected_alive + assert result.n_dead == n_components - expected_alive diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py new file mode 100644 index 000000000..9f191075b --- /dev/null +++ b/tests/clustering/test_merge_config.py @@ -0,0 +1,181 @@ +"""Tests for MergeConfig with new sampling system.""" + +import pytest +import torch + +from spd.clustering.merge_config import MergeConfig + + +class TestMergeConfigSampling: + """Test MergeConfig integration with sampling system.""" + + def test_default_config(self): + """Test default MergeConfig uses range sampler.""" + config = MergeConfig() + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.05} + + def test_range_sampler_config(self): + """Test MergeConfig with range sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.1} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + + def test_mcmc_sampler_config(self): + """Test MergeConfig with MCMC sampler.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0} + ) + + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 2.0} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) 
== 2 + assert pair[0] != pair[1] + + def test_invalid_sampler_method(self): + """Test that invalid sampler method raises error.""" + from pydantic import ValidationError + + # Pydantic validates at construction time + with pytest.raises(ValidationError): + _config = MergeConfig(merge_pair_sampling_method="invalid") # pyright: ignore[reportArgumentType] + + def test_config_with_all_parameters(self): + """Test MergeConfig with all parameters set.""" + config = MergeConfig( + activation_threshold=0.01, + alpha=1.5, + iters=200, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.5}, + pop_component_prob=0.1, + filter_dead_threshold=0.001, + module_name_filter="model.layers", + ) + + assert config.activation_threshold == 0.01 + assert config.alpha == 1.5 + assert config.iters == 200 + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} + assert config.pop_component_prob == 0.1 + assert config.filter_dead_threshold == 0.001 + assert config.module_name_filter == "model.layers" + + def test_config_serialization(self): + """Test that config can be serialized and deserialized.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.5} + ) + + # Serialize to dict + config_dict = config.model_dump() + assert config_dict["merge_pair_sampling_method"] == "mcmc" + assert config_dict["merge_pair_sampling_kwargs"] == {"temperature": 1.5} + + # Deserialize from dict + config2 = MergeConfig(**config_dict) + assert config2.merge_pair_sampling_method == "mcmc" + assert config2.merge_pair_sampling_kwargs == {"temperature": 1.5} + + def test_config_json_serialization(self): + """Test JSON serialization of config.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + + # Serialize to JSON string + json_str = config.model_dump_json() + assert "range" in json_str + assert "0.2" 
in json_str + + # Parse back from JSON + import json + + config_dict = json.loads(json_str) + config2 = MergeConfig(**config_dict) + + assert config2.merge_pair_sampling_method == "range" + assert config2.merge_pair_sampling_kwargs == {"threshold": 0.2} + + def test_stable_hash_changes_with_sampling_params(self): + """Test that stable_hash changes when sampling parameters change.""" + config1 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + config2 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + config3 = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0} + ) + + # Different configs should have different hashes + assert config1.stable_hash != config2.stable_hash + assert config1.stable_hash != config3.stable_hash + assert config2.stable_hash != config3.stable_hash + + # Same config should have same hash + config4 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + assert config1.stable_hash == config4.stable_hash + + def test_empty_kwargs(self): + """Test that empty kwargs dict works.""" + config = MergeConfig(merge_pair_sampling_method="range", merge_pair_sampling_kwargs={}) + + # Should work with default parameters of the sampler + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Range sampler has default threshold=0.05 + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] + + def test_extra_kwargs_filtered(self): + """Test that only valid kwargs are used by sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.3} + ) + + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Should work with config's method + pair = 
config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py new file mode 100644 index 000000000..6463ad07b --- /dev/null +++ b/tests/clustering/test_merge_integration.py @@ -0,0 +1,201 @@ +"""Integration tests for the merge system with new samplers.""" + +import torch + +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig + + +class TestMergeIntegration: + """Test the full merge iteration with different samplers.""" + + def test_merge_with_range_sampler(self): + """Test merge iteration with range sampler.""" + # Create test data + n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with range sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, + batch_id="test_merge_with_range_sampler", + merge_config=config, + component_labels=component_labels, + ) + + # Check results + assert history is not None + assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # After iterations, should have fewer groups (merges reduce count) + # Exact count depends on early stopping conditions + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 # Should stop before going below 2 + + def test_merge_with_mcmc_sampler(self): + """Test merge iteration with MCMC sampler.""" + # Create test data + 
n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with MCMC sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 1.0}, + pop_component_prob=0, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, + batch_id="test_merge_with_mcmc_sampler", + merge_config=config, + component_labels=component_labels, + ) + + # Check results + assert history is not None + assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # Should have fewer groups after iterations + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 + + def test_merge_with_popping(self): + """Test merge iteration with component popping.""" + # Create test data + n_samples = 100 + n_components = 15 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with popping enabled + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=10, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.05}, + pop_component_prob=0.3, # 30% chance of popping + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, + batch_id="test_merge_with_popping", + merge_config=config, + component_labels=component_labels, + ) + + # Check results + assert history is not None + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # Final group count depends on pops, but should be less than initial + 
assert history.merges.k_groups[-1].item() < n_components + + def test_merge_comparison_samplers(self): + """Compare behavior of different samplers with same data.""" + # Create test data with clear structure + n_samples = 100 + n_components = 8 + activations = torch.rand(n_samples, n_components) + + # Make some components more active to create cost structure + activations[:, 0] *= 2 # Component 0 is very active + activations[:, 1] *= 0.1 # Component 1 is rarely active + + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Run with range sampler (threshold=0 for deterministic minimum selection) + config_range = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum + pop_component_prob=0, + ) + + history_range = merge_iteration( + activations=activations.clone(), + batch_id="test_merge_comparison_samplers_range", + merge_config=config_range, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Run with MCMC sampler (low temperature for near-deterministic) + config_mcmc = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp + pop_component_prob=0, + ) + + history_mcmc = merge_iteration( + activations=activations.clone(), + batch_id="test_merge_comparison_samplers_mcmc", + merge_config=config_mcmc, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Both should reduce groups from initial count + assert history_range.merges.k_groups[-1].item() < n_components + assert history_mcmc.merges.k_groups[-1].item() < n_components + assert history_range.merges.k_groups[-1].item() >= 2 + assert history_mcmc.merges.k_groups[-1].item() >= 2 + + def test_merge_with_small_components(self): + """Test merge with very few components.""" + # Edge case: only 3 components + n_samples = 
50 + n_components = 3 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=1, # Just one merge + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 2.0}, + pop_component_prob=0, + ) + + history = merge_iteration( + activations=activations, + batch_id="test_merge_with_small_components", + merge_config=config, + component_labels=component_labels, + ) + + # First entry is after first merge, so should be 3 - 1 = 2 + assert history.merges.k_groups[0].item() == 2 + # Early stopping may occur at 2 groups, so final count could be 2 or 3 + assert history.merges.k_groups[-1].item() >= 2 + assert history.merges.k_groups[-1].item() <= 3 diff --git a/tests/clustering/test_merge_pair_samplers.py b/tests/clustering/test_merge_pair_samplers.py new file mode 100644 index 000000000..e400b0dd3 --- /dev/null +++ b/tests/clustering/test_merge_pair_samplers.py @@ -0,0 +1,274 @@ +"""Tests for merge pair sampling functionality.""" + +import pytest +import torch + +from spd.clustering.math.merge_pair_samplers import ( + MERGE_PAIR_SAMPLERS, + mcmc_sampler, + range_sampler, +) + + +class TestRangeSampler: + """Test range-based merge pair sampling.""" + + def test_range_sampler_basic(self): + """Test basic functionality of range sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 # Make symmetric + costs.fill_diagonal_(float("inf")) # No self-merges + + # Test with different thresholds + pair_low = range_sampler(costs, threshold=0.0) + pair_mid = range_sampler(costs, threshold=0.5) + pair_high = range_sampler(costs, threshold=1.0) + + # All should return valid pairs + assert pair_low[0] != pair_low[1] + assert pair_mid[0] != pair_mid[1] + assert pair_high[0] != pair_high[1] + + # All indices should be in valid range + for pair in [pair_low, pair_mid, pair_high]: + assert 0 <= pair[0] 
< k + assert 0 <= pair[1] < k + + def test_range_sampler_threshold_zero(self): + """Test that threshold=0 always selects minimum cost pair.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find the true minimum + min_val = float("inf") + _min_pair = None + for i in range(k): + for j in range(k): + if i != j and costs[i, j] < min_val: + min_val = costs[i, j].item() + _min_pair = (i, j) + + # Sample multiple times with threshold=0 + for _ in range(10): + pair = range_sampler(costs, threshold=0.0) + # Should always get the minimum (or its symmetric equivalent) + assert costs[pair[0], pair[1]] == min_val or costs[pair[1], pair[0]] == min_val + + def test_range_sampler_threshold_one(self): + """Test that threshold=1 can select any non-diagonal pair.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times to check we get different pairs + pairs_seen = set() + for _ in range(100): + pair = range_sampler(costs, threshold=1.0) + # Normalize pair order for comparison + normalized = tuple(sorted(pair)) + pairs_seen.add(normalized) + + # With threshold=1, we should see multiple different pairs + assert len(pairs_seen) > 1 + + def test_range_sampler_small_matrix(self): + """Test range sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 1.0], [1.0, float("inf")]]) + + pair = range_sampler(costs, threshold=0.5) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + +class TestMCMCSampler: + """Test MCMC-based merge pair sampling.""" + + def test_mcmc_sampler_basic(self): + """Test basic functionality of MCMC sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test with different temperatures + pair_low_temp = mcmc_sampler(costs, temperature=0.1) + pair_mid_temp = mcmc_sampler(costs, temperature=1.0) + pair_high_temp = mcmc_sampler(costs, 
temperature=10.0) + + # All should return valid pairs + for pair in [pair_low_temp, pair_mid_temp, pair_high_temp]: + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + def test_mcmc_sampler_low_temperature(self): + """Test that low temperature favors low-cost pairs.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find minimum cost + min_val = float("inf") + for i in range(k): + for j in range(k): + if i != j: + min_val = min(min_val, costs[i, j].item()) + + # Sample many times with very low temperature + low_cost_count = 0 + n_samples = 100 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=0.01) + cost = costs[pair[0], pair[1]].item() + # Check if it's close to minimum + if abs(cost - min_val) < 0.5: # Within 0.5 of minimum + low_cost_count += 1 + + # Most samples should be near minimum with low temperature + assert low_cost_count > n_samples * 0.7 + + def test_mcmc_sampler_high_temperature(self): + """Test that high temperature gives more uniform sampling.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times with high temperature + pairs_count = {} + n_samples = 1000 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=100.0) + # Normalize pair order for counting + normalized = tuple(sorted(pair)) + pairs_count[normalized] = pairs_count.get(normalized, 0) + 1 + + # With high temperature, distribution should be relatively uniform + # There are k*(k-1)/2 unique pairs + expected_count = n_samples / (k * (k - 1) / 2) + for count in pairs_count.values(): + # Each pair count should be within reasonable range of expected + assert expected_count * 0.3 < count < expected_count * 1.7 + + def test_mcmc_sampler_small_matrix(self): + """Test MCMC sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 1.0], [1.0, float("inf")]]) + + pair = 
mcmc_sampler(costs, temperature=1.0) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + def test_mcmc_sampler_extreme_costs(self): + """Test MCMC sampler with extreme cost differences.""" + k = 3 + # Create matrix with one very low cost and rest high + costs = torch.full((k, k), 1000.0) + costs[0, 1] = costs[1, 0] = 1.0 # One low-cost pair + costs.fill_diagonal_(float("inf")) + + # With low temperature, should almost always select the low-cost pair + low_cost_selected = 0 + for _ in range(100): + pair = mcmc_sampler(costs, temperature=0.1) + if set(pair) == {0, 1}: + low_cost_selected += 1 + + assert low_cost_selected > 95 # Should almost always select (0,1) + + +class TestSamplerRegistry: + """Test the sampler registry.""" + + def test_registry_contains_samplers(self): + """Test that registry contains expected samplers.""" + assert "range" in MERGE_PAIR_SAMPLERS + assert "mcmc" in MERGE_PAIR_SAMPLERS + assert MERGE_PAIR_SAMPLERS["range"] is range_sampler + assert MERGE_PAIR_SAMPLERS["mcmc"] is mcmc_sampler + + def test_registry_samplers_callable(self): + """Test that all registry samplers are callable with correct signature.""" + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + for name, sampler in MERGE_PAIR_SAMPLERS.items(): + # Should be callable + assert callable(sampler) + + # Test with default kwargs + if name == "range": + pair = sampler(costs, threshold=0.5) + elif name == "mcmc": + pair = sampler(costs, temperature=1.0) + else: + pytest.fail(f"Unknown sampler {name}") + + # Should return valid pair + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + +class TestSamplerIntegration: + """Integration tests for samplers with edge cases.""" + + def test_samplers_with_gpu_tensors(self): + """Test samplers work with GPU tensors if available.""" + if not torch.cuda.is_available(): + 
pytest.skip("CUDA not available") + + k = 4 + costs = torch.randn(k, k, device="cuda") + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Both samplers should work with GPU tensors + pair_range = range_sampler(costs, threshold=0.5) + pair_mcmc = mcmc_sampler(costs, temperature=1.0) + + assert isinstance(pair_range, tuple) + assert isinstance(pair_mcmc, tuple) + + def test_samplers_deterministic_with_seed(self): + """Test that samplers are deterministic with fixed seed.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test range sampler + torch.manual_seed(42) + pair1 = range_sampler(costs, threshold=0.5) + torch.manual_seed(42) + pair2 = range_sampler(costs, threshold=0.5) + # Can't guarantee exact match due to Python's random module + # but both should be valid + assert pair1[0] != pair1[1] + assert pair2[0] != pair2[1] + + # Test MCMC sampler + torch.manual_seed(42) + pair1 = mcmc_sampler(costs, temperature=1.0) + torch.manual_seed(42) + pair2 = mcmc_sampler(costs, temperature=1.0) + assert pair1 == pair2 # Should be deterministic with same seed + + def test_samplers_all_infinite_costs(self): + """Test samplers handle all-infinite costs gracefully.""" + k = 3 + costs = torch.full((k, k), float("inf")) + + # This is an edge case - no valid pairs exist + # Samplers should handle this without crashing + # (though the result may not be meaningful) + with pytest.raises((ValueError, RuntimeError, IndexError)): + range_sampler(costs, threshold=0.5) diff --git a/tests/clustering/test_storage.py b/tests/clustering/test_storage.py new file mode 100644 index 000000000..d5e3d535e --- /dev/null +++ b/tests/clustering/test_storage.py @@ -0,0 +1,351 @@ +"""Comprehensive tests for ClusteringStorage.""" + +import tempfile +from collections.abc import Iterator +from pathlib import Path + +import numpy as np +import pytest +import torch + +from spd.clustering.consts import ComponentLabels, 
DistancesMethod +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.pipeline.storage import ClusteringStorage, NormalizedEnsemble + + +@pytest.fixture +def temp_storage() -> Iterator[ClusteringStorage]: + """Create a temporary ClusteringStorage instance.""" + with tempfile.TemporaryDirectory() as tmp_dir: + storage = ClusteringStorage(base_path=Path(tmp_dir), run_identifier="test_run") + yield storage + + +@pytest.fixture +def sample_config() -> MergeConfig: + """Create a sample MergeConfig for testing.""" + return MergeConfig( + iters=5, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + +class TestStorageInitialization: + """Test storage initialization and directory structure.""" + + def test_storage_creates_run_directory(self): + """Test that storage creates the run directory on initialization.""" + with tempfile.TemporaryDirectory() as tmp_dir: + base_path = Path(tmp_dir) + storage = ClusteringStorage(base_path=base_path, run_identifier="test_run") + + assert storage.run_path.exists() + assert storage.run_path == base_path / "test_run" + + def test_storage_without_run_identifier(self): + """Test that storage works without a run identifier.""" + with tempfile.TemporaryDirectory() as tmp_dir: + base_path = Path(tmp_dir) + storage = ClusteringStorage(base_path=base_path, run_identifier=None) + + assert storage.run_path == base_path + + def test_storage_paths_are_consistent(self, temp_storage: ClusteringStorage): + """Test that all storage paths are under the run path.""" + assert str(temp_storage._dataset_dir).startswith(str(temp_storage.run_path)) + assert str(temp_storage._batches_dir).startswith(str(temp_storage.run_path)) + assert str(temp_storage._histories_dir).startswith(str(temp_storage.run_path)) + assert str(temp_storage._ensemble_dir).startswith(str(temp_storage.run_path)) + assert 
str(temp_storage._distances_dir).startswith(str(temp_storage.run_path)) + + +class TestRunConfigStorage: + """Test run configuration storage.""" + + def test_save_and_load_run_config(self, temp_storage: ClusteringStorage): + """Test saving and loading RunConfig.""" + # Create a minimal RunConfig + config = ClusteringRunConfig( + merge_config=MergeConfig( + iters=10, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ), + model_path="wandb:entity/project/run_id", + task_name="lm", + n_batches=5, + batch_size=32, + base_path=temp_storage.base_path, + workers_per_device=1, + devices=["cuda"], + ) + + # Save config + saved_path = temp_storage.save_run_config(config) + assert saved_path.exists() + assert saved_path == temp_storage.run_config_file + + # Load and verify + loaded_config = ClusteringRunConfig.read(saved_path) + assert loaded_config.n_batches == 5 + assert loaded_config.batch_size == 32 + assert loaded_config.task_name == "lm" + + +class TestBatchStorage: + """Test batch data storage.""" + + def test_save_single_batch(self, temp_storage: ClusteringStorage): + """Test saving a single batch.""" + batch = torch.randint(0, 100, (8, 16)) # batch_size=8, seq_len=16 + batch_idx = 0 + + saved_path = temp_storage.save_batch(batch, batch_idx) + assert saved_path.exists() + assert saved_path.name == "batch_00.npz" + + def test_save_and_load_batch(self, temp_storage: ClusteringStorage): + """Test saving and loading a batch.""" + original_batch = torch.randint(0, 100, (8, 16)) + batch_idx = 0 + + # Save + temp_storage.save_batch(original_batch, batch_idx) + + # Load + loaded_batch = temp_storage.load_batch(temp_storage.batch_path(batch_idx)) + + # Verify + assert torch.equal(loaded_batch, original_batch) + + def test_save_multiple_batches(self, temp_storage: ClusteringStorage): + """Test saving multiple batches using save_batches.""" + batches = [torch.randint(0, 100, (8, 16)) for _ in range(3)] + config = {"test": "config"} + + saved_paths = 
temp_storage.save_batches(iter(batches), config) + + assert len(saved_paths) == 3 + assert all(p.exists() for p in saved_paths) + assert temp_storage.dataset_config_file.exists() + + def test_get_batch_paths(self, temp_storage: ClusteringStorage): + """Test retrieving all batch paths.""" + # Save some batches + for i in range(3): + temp_storage.save_batch(torch.randint(0, 100, (8, 16)), i) + + # Get paths + batch_paths = temp_storage.get_batch_paths() + + assert len(batch_paths) == 3 + assert all(p.exists() for p in batch_paths) + # Should be sorted + assert batch_paths == sorted(batch_paths) + + +class TestHistoryStorage: + """Test merge history storage.""" + + def test_save_and_load_history( + self, temp_storage: ClusteringStorage, sample_config: MergeConfig + ): + """Test saving and loading merge history.""" + # Create history + history = MergeHistory.from_config( + merge_config=sample_config, + labels=ComponentLabels(["comp0", "comp1", "comp2"]), + ) + + batch_id = "batch_00" + + # Save + saved_path = temp_storage.save_history(history, batch_id) + assert saved_path.exists() + assert "batch_00" in str(saved_path) + + # Load + loaded_history = temp_storage.load_history(batch_id) + assert loaded_history is not None + assert len(loaded_history.labels) == 3 + + def test_load_multiple_histories( + self, temp_storage: ClusteringStorage, sample_config: MergeConfig + ): + """Test loading all histories.""" + # Save multiple histories + for i in range(3): + history = MergeHistory.from_config( + merge_config=sample_config, + labels=ComponentLabels([f"comp{j}" for j in range(4)]), + ) + temp_storage.save_history(history, batch_id=f"batch_{i:02d}") + + # Load all + histories = temp_storage.load_histories() + assert len(histories) == 3 + + def test_get_history_paths(self, temp_storage: ClusteringStorage, sample_config: MergeConfig): + """Test getting all history paths.""" + # Save histories + for i in range(2): + history = MergeHistory.from_config( + 
merge_config=sample_config, + labels=ComponentLabels(["comp0", "comp1"]), + ) + temp_storage.save_history(history, batch_id=f"batch_{i:02d}") + + # Get paths + history_paths = temp_storage.get_history_paths() + assert len(history_paths) == 2 + assert all(p.exists() for p in history_paths) + + +class TestEnsembleStorage: + """Test ensemble data storage.""" + + def test_save_ensemble(self, temp_storage: ClusteringStorage): + """Test saving ensemble data.""" + # Create dummy ensemble data + merge_array = np.random.randint(0, 10, size=(2, 5, 8)) # n_ens, n_iters, c_components + metadata = {"n_ensemble": 2, "n_iters": 5} + + ensemble = NormalizedEnsemble(merge_array=merge_array, metadata=metadata) + + # Save + meta_path, array_path = temp_storage.save_ensemble(ensemble) + + assert meta_path.exists() + assert array_path.exists() + assert meta_path == temp_storage.ensemble_meta_file + assert array_path == temp_storage.ensemble_array_file + + def test_ensemble_data_integrity(self, temp_storage: ClusteringStorage): + """Test that ensemble data can be saved and loaded correctly.""" + # Create ensemble data + original_array = np.random.randint(0, 10, size=(2, 5, 8)) + metadata = {"test": "value", "n_ensemble": 2} + + ensemble = NormalizedEnsemble(merge_array=original_array, metadata=metadata) + + # Save + _, array_path = temp_storage.save_ensemble(ensemble) + + # Load and verify + loaded_data = np.load(array_path) + loaded_array = loaded_data["merges"] + + assert np.array_equal(loaded_array, original_array) + + +class TestDistancesStorage: + """Test distance matrix storage.""" + + def test_save_distances(self, temp_storage: ClusteringStorage): + """Test saving distance matrix.""" + distances = np.random.rand(5, 3, 3) # n_iters, n_ens, n_ens + method: DistancesMethod = "perm_invariant_hamming" + + saved_path = temp_storage.save_distances(distances, method) + + assert saved_path.exists() + assert method in saved_path.name + + def test_save_and_load_distances(self, temp_storage: 
ClusteringStorage): + """Test saving and loading distances.""" + original_distances = np.random.rand(5, 3, 3) + method: DistancesMethod = "perm_invariant_hamming" + + # Save + temp_storage.save_distances(original_distances, method) + + # Load + loaded_distances = temp_storage.load_distances(method) + + assert np.array_equal(loaded_distances, original_distances) + + +class TestStorageIntegration: + """Test integration scenarios.""" + + def test_full_pipeline_storage_flow( + self, temp_storage: ClusteringStorage, sample_config: MergeConfig + ): + """Test a complete storage workflow.""" + # 1. Save run config + run_config = ClusteringRunConfig( + merge_config=sample_config, + model_path="wandb:entity/project/run_id", + task_name="lm", + n_batches=2, + batch_size=8, + base_path=temp_storage.base_path, + workers_per_device=1, + devices=["cpu"], + ) + temp_storage.save_run_config(run_config) + + # 2. Save batches + batches = [torch.randint(0, 100, (8, 16)) for _ in range(2)] + temp_storage.save_batches(iter(batches), {"dataset": "test"}) + + # 3. Save histories + for i in range(2): + history = MergeHistory.from_config( + merge_config=sample_config, + labels=ComponentLabels(["comp0", "comp1", "comp2"]), + ) + temp_storage.save_history(history, batch_id=f"batch_{i:02d}") + + # 4. Save ensemble + merge_array = np.random.randint(0, 3, size=(2, 5, 3)) + ensemble = NormalizedEnsemble( + merge_array=merge_array, + metadata={"n_ensemble": 2, "n_iters": 5}, + ) + temp_storage.save_ensemble(ensemble) + + # 5. 
Save distances + distances = np.random.rand(5, 2, 2) + temp_storage.save_distances(distances, "perm_invariant_hamming") + + # Verify all files exist + assert temp_storage.run_config_file.exists() + assert temp_storage.dataset_config_file.exists() + assert len(temp_storage.get_batch_paths()) == 2 + assert len(temp_storage.get_history_paths()) == 2 + assert temp_storage.ensemble_meta_file.exists() + assert temp_storage.ensemble_array_file.exists() + + def test_storage_filesystem_structure(self, temp_storage: ClusteringStorage): + """Test that the filesystem structure matches documentation.""" + # Create minimal data to generate structure + temp_storage.save_run_config( + ClusteringRunConfig( + merge_config=MergeConfig( + iters=1, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ), + model_path="wandb:e/p/r", + task_name="lm", + n_batches=1, + batch_size=1, + base_path=temp_storage.base_path, + workers_per_device=1, + devices=["cpu"], + ) + ) + + # Verify structure + assert (temp_storage.run_path / "run_config.json").exists() + + # The directories are created lazily, so trigger their creation + temp_storage.save_batch(torch.tensor([[1, 2, 3]]), 0) + assert (temp_storage.run_path / "dataset" / "batches").exists() diff --git a/tests/clustering/test_wandb_integration.py b/tests/clustering/test_wandb_integration.py new file mode 100644 index 000000000..cf400ca2b --- /dev/null +++ b/tests/clustering/test_wandb_integration.py @@ -0,0 +1,153 @@ +"""Quick sanity tests for WandB integration features.""" + +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.pipeline.s2_clustering import _save_merge_history_to_wandb +from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save + + +def 
test_wandb_url_parsing_short_format(): + """Test that normalize_and_save can process merge histories using storage.""" + from spd.clustering.pipeline.storage import ClusteringStorage + + # Create temporary directory for storage + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + + # Create ClusteringStorage instance + storage = ClusteringStorage(base_path=tmp_path, run_identifier="test_run") + + # Create mock merge histories + config = MergeConfig( + iters=5, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + # Save histories using storage + for idx in range(2): + history = MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels([f"comp{j}" for j in range(5)]), + ) + storage.save_history(history, batch_id=f"batch_{idx:02d}") + + # Test normalize_and_save with storage + result = normalize_and_save(storage=storage) + + # Basic checks + assert result is not None + assert storage.ensemble_meta_file.exists() + assert storage.ensemble_array_file.exists() + + # Verify we can load the histories back + loaded_histories = storage.load_histories() + assert len(loaded_histories) == 2 + + +def test_merge_history_ensemble(): + """Test that MergeHistoryEnsemble can handle multiple histories.""" + + # Create test merge histories + config = MergeConfig( + iters=3, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + histories = [] + for _idx in range(2): + history = MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels([f"comp{j}" for j in range(4)]), + ) + histories.append(history) + + # Test ensemble creation + ensemble = MergeHistoryEnsemble(data=histories) + assert len(ensemble.data) == 2 + + # Test normalization + normalized_array, metadata = ensemble.normalized() + assert normalized_array is not None + assert metadata is not None + + +def test_save_merge_history_to_wandb(): + """Test that _save_merge_history_to_wandb creates the expected artifact.""" + + # Create a 
real MergeHistory + config = MergeConfig( + iters=5, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + history = MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels(["comp0", "comp1", "comp2"]), + ) + + # Mock wandb run and artifact + mock_wandb_run = Mock() + mock_artifact = Mock() + + with tempfile.TemporaryDirectory() as tmp_dir: + history_path = Path(tmp_dir) / "test_history.zip" + history.save(history_path) + + with patch("spd.clustering.pipeline.s2_clustering.wandb.Artifact") as mock_artifact_class: + mock_artifact_class.return_value = mock_artifact + + # Call the function + _save_merge_history_to_wandb( + run=mock_wandb_run, + history_path=history_path, + batch_id="batch_01", + config_identifier="test_config", + history=history, + ) + + # Check that artifact was created and logged + mock_artifact_class.assert_called_once() + mock_wandb_run.log_artifact.assert_called_once_with(mock_artifact) + + # Check artifact creation parameters + call_args = mock_artifact_class.call_args + assert call_args.kwargs["name"] == "merge_history_batch_01" + assert call_args.kwargs["type"] == "merge_history" + assert "batch_01" in call_args.kwargs["description"] + + +def test_wandb_url_field_in_merge_history(): + """Test that MergeHistory can store and serialize wandb_url.""" + + # Create a simple config + config = MergeConfig( + iters=10, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + # Create MergeHistory with wandb_url + history = MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels(["comp0", "comp1", "comp2", "comp3", "comp4"]), + ) + # Check that it can be serialized and deserialized + with tempfile.TemporaryDirectory() as tmp_dir: + save_path = Path(tmp_dir) / "test_history.zip" + history.save(save_path) + loaded_history = MergeHistory.read(save_path) + + assert loaded_history is not None + assert loaded_history.merges.group_idxs.shape == (10, 5) # (iters, n_components) diff 
--git a/uv.lock b/uv.lock index 49be9f8a8..29dbd67e9 100644 --- a/uv.lock +++ b/uv.lock @@ -1034,6 +1034,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] +[[package]] +name = "muutils" +version = "0.8.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ae/80/38cfd93c6e17356cb5be1d31d06e835b1a9f603f7fe35acce98d009db744/muutils-0.8.11.tar.gz", hash = "sha256:391abd59c57c81df5a2eef2a12217d4797b735256c6b01e20ed27b49bc475505", size = 3094363, upload-time = "2025-07-08T03:20:07.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/00/e4872f5da08e12ee3130889a96a2074c783b35e3cd096004203e62d3d659/muutils-0.8.11-py3-none-any.whl", hash = "sha256:a98718c4b216f37637bd6c2480494a330de758dfd5f334c2c28bbd18799ee767", size = 126722, upload-time = "2025-07-08T03:20:04.876Z" }, +] + [[package]] name = "narwhals" version = "2.7.0" @@ -1164,7 +1173,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1175,7 +1184,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 
'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1202,9 +1211,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1215,7 +1224,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -1965,6 +1974,47 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = 
"sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" }, ] +[[package]] +name = "scipy" +version = "1.16.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/3b/546a6f0bfe791bbb7f8d591613454d15097e53f906308ec6f7c1ce588e8e/scipy-1.16.2.tar.gz", hash = "sha256:af029b153d243a80afb6eabe40b0a07f8e35c9adc269c019f364ad747f826a6b", size = 30580599, upload-time = "2025-09-11T17:48:08.271Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/8d/6396e00db1282279a4ddd507c5f5e11f606812b608ee58517ce8abbf883f/scipy-1.16.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:89d6c100fa5c48472047632e06f0876b3c4931aac1f4291afc81a3644316bb0d", size = 36646259, upload-time = "2025-09-11T17:40:39.329Z" }, + { url = "https://files.pythonhosted.org/packages/3b/93/ea9edd7e193fceb8eef149804491890bde73fb169c896b61aa3e2d1e4e77/scipy-1.16.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:ca748936cd579d3f01928b30a17dc474550b01272d8046e3e1ee593f23620371", size = 28888976, upload-time = "2025-09-11T17:40:46.82Z" }, + { url = "https://files.pythonhosted.org/packages/91/4d/281fddc3d80fd738ba86fd3aed9202331180b01e2c78eaae0642f22f7e83/scipy-1.16.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:fac4f8ce2ddb40e2e3d0f7ec36d2a1e7f92559a2471e59aec37bd8d9de01fec0", size = 20879905, upload-time = "2025-09-11T17:40:52.545Z" }, + { url = "https://files.pythonhosted.org/packages/69/40/b33b74c84606fd301b2915f0062e45733c6ff5708d121dd0deaa8871e2d0/scipy-1.16.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:033570f1dcefd79547a88e18bccacff025c8c647a330381064f561d43b821232", size = 23553066, upload-time = "2025-09-11T17:40:59.014Z" }, + { url = 
"https://files.pythonhosted.org/packages/55/a7/22c739e2f21a42cc8f16bc76b47cff4ed54fbe0962832c589591c2abec34/scipy-1.16.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ea3421209bf00c8a5ef2227de496601087d8f638a2363ee09af059bd70976dc1", size = 33336407, upload-time = "2025-09-11T17:41:06.796Z" }, + { url = "https://files.pythonhosted.org/packages/53/11/a0160990b82999b45874dc60c0c183d3a3a969a563fffc476d5a9995c407/scipy-1.16.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f66bd07ba6f84cd4a380b41d1bf3c59ea488b590a2ff96744845163309ee8e2f", size = 35673281, upload-time = "2025-09-11T17:41:15.055Z" }, + { url = "https://files.pythonhosted.org/packages/96/53/7ef48a4cfcf243c3d0f1643f5887c81f29fdf76911c4e49331828e19fc0a/scipy-1.16.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5e9feab931bd2aea4a23388c962df6468af3d808ddf2d40f94a81c5dc38f32ef", size = 36004222, upload-time = "2025-09-11T17:41:23.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7f/71a69e0afd460049d41c65c630c919c537815277dfea214031005f474d78/scipy-1.16.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:03dfc75e52f72cf23ec2ced468645321407faad8f0fe7b1f5b49264adbc29cb1", size = 38664586, upload-time = "2025-09-11T17:41:31.021Z" }, + { url = "https://files.pythonhosted.org/packages/34/95/20e02ca66fb495a95fba0642fd48e0c390d0ece9b9b14c6e931a60a12dea/scipy-1.16.2-cp312-cp312-win_amd64.whl", hash = "sha256:0ce54e07bbb394b417457409a64fd015be623f36e330ac49306433ffe04bc97e", size = 38550641, upload-time = "2025-09-11T17:41:36.61Z" }, + { url = "https://files.pythonhosted.org/packages/92/ad/13646b9beb0a95528ca46d52b7babafbe115017814a611f2065ee4e61d20/scipy-1.16.2-cp312-cp312-win_arm64.whl", hash = "sha256:2a8ffaa4ac0df81a0b94577b18ee079f13fecdb924df3328fc44a7dc5ac46851", size = 25456070, upload-time = "2025-09-11T17:41:41.3Z" }, + { url = 
"https://files.pythonhosted.org/packages/c1/27/c5b52f1ee81727a9fc457f5ac1e9bf3d6eab311805ea615c83c27ba06400/scipy-1.16.2-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:84f7bf944b43e20b8a894f5fe593976926744f6c185bacfcbdfbb62736b5cc70", size = 36604856, upload-time = "2025-09-11T17:41:47.695Z" }, + { url = "https://files.pythonhosted.org/packages/32/a9/15c20d08e950b540184caa8ced675ba1128accb0e09c653780ba023a4110/scipy-1.16.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:5c39026d12edc826a1ef2ad35ad1e6d7f087f934bb868fc43fa3049c8b8508f9", size = 28864626, upload-time = "2025-09-11T17:41:52.642Z" }, + { url = "https://files.pythonhosted.org/packages/4c/fc/ea36098df653cca26062a627c1a94b0de659e97127c8491e18713ca0e3b9/scipy-1.16.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e52729ffd45b68777c5319560014d6fd251294200625d9d70fd8626516fc49f5", size = 20855689, upload-time = "2025-09-11T17:41:57.886Z" }, + { url = "https://files.pythonhosted.org/packages/dc/6f/d0b53be55727f3e6d7c72687ec18ea6d0047cf95f1f77488b99a2bafaee1/scipy-1.16.2-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:024dd4a118cccec09ca3209b7e8e614931a6ffb804b2a601839499cb88bdf925", size = 23512151, upload-time = "2025-09-11T17:42:02.303Z" }, + { url = "https://files.pythonhosted.org/packages/11/85/bf7dab56e5c4b1d3d8eef92ca8ede788418ad38a7dc3ff50262f00808760/scipy-1.16.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7a5dc7ee9c33019973a470556081b0fd3c9f4c44019191039f9769183141a4d9", size = 33329824, upload-time = "2025-09-11T17:42:07.549Z" }, + { url = "https://files.pythonhosted.org/packages/da/6a/1a927b14ddc7714111ea51f4e568203b2bb6ed59bdd036d62127c1a360c8/scipy-1.16.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c2275ff105e508942f99d4e3bc56b6ef5e4b3c0af970386ca56b777608ce95b7", size = 35681881, upload-time = "2025-09-11T17:42:13.255Z" }, + { url = 
"https://files.pythonhosted.org/packages/c1/5f/331148ea5780b4fcc7007a4a6a6ee0a0c1507a796365cc642d4d226e1c3a/scipy-1.16.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:af80196eaa84f033e48444d2e0786ec47d328ba00c71e4299b602235ffef9acb", size = 36006219, upload-time = "2025-09-11T17:42:18.765Z" }, + { url = "https://files.pythonhosted.org/packages/46/3a/e991aa9d2aec723b4a8dcfbfc8365edec5d5e5f9f133888067f1cbb7dfc1/scipy-1.16.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9fb1eb735fe3d6ed1f89918224e3385fbf6f9e23757cacc35f9c78d3b712dd6e", size = 38682147, upload-time = "2025-09-11T17:42:25.177Z" }, + { url = "https://files.pythonhosted.org/packages/a1/57/0f38e396ad19e41b4c5db66130167eef8ee620a49bc7d0512e3bb67e0cab/scipy-1.16.2-cp313-cp313-win_amd64.whl", hash = "sha256:fda714cf45ba43c9d3bae8f2585c777f64e3f89a2e073b668b32ede412d8f52c", size = 38520766, upload-time = "2025-09-11T17:43:25.342Z" }, + { url = "https://files.pythonhosted.org/packages/1b/a5/85d3e867b6822d331e26c862a91375bb7746a0b458db5effa093d34cdb89/scipy-1.16.2-cp313-cp313-win_arm64.whl", hash = "sha256:2f5350da923ccfd0b00e07c3e5cfb316c1c0d6c1d864c07a72d092e9f20db104", size = 25451169, upload-time = "2025-09-11T17:43:30.198Z" }, + { url = "https://files.pythonhosted.org/packages/09/d9/60679189bcebda55992d1a45498de6d080dcaf21ce0c8f24f888117e0c2d/scipy-1.16.2-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:53d8d2ee29b925344c13bda64ab51785f016b1b9617849dac10897f0701b20c1", size = 37012682, upload-time = "2025-09-11T17:42:30.677Z" }, + { url = "https://files.pythonhosted.org/packages/83/be/a99d13ee4d3b7887a96f8c71361b9659ba4ef34da0338f14891e102a127f/scipy-1.16.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:9e05e33657efb4c6a9d23bd8300101536abd99c85cca82da0bffff8d8764d08a", size = 29389926, upload-time = "2025-09-11T17:42:35.845Z" }, + { url = 
"https://files.pythonhosted.org/packages/bf/0a/130164a4881cec6ca8c00faf3b57926f28ed429cd6001a673f83c7c2a579/scipy-1.16.2-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:7fe65b36036357003b3ef9d37547abeefaa353b237e989c21027b8ed62b12d4f", size = 21381152, upload-time = "2025-09-11T17:42:40.07Z" }, + { url = "https://files.pythonhosted.org/packages/47/a6/503ffb0310ae77fba874e10cddfc4a1280bdcca1d13c3751b8c3c2996cf8/scipy-1.16.2-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:6406d2ac6d40b861cccf57f49592f9779071655e9f75cd4f977fa0bdd09cb2e4", size = 23914410, upload-time = "2025-09-11T17:42:44.313Z" }, + { url = "https://files.pythonhosted.org/packages/fa/c7/1147774bcea50d00c02600aadaa919facbd8537997a62496270133536ed6/scipy-1.16.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ff4dc42bd321991fbf611c23fc35912d690f731c9914bf3af8f417e64aca0f21", size = 33481880, upload-time = "2025-09-11T17:42:49.325Z" }, + { url = "https://files.pythonhosted.org/packages/6a/74/99d5415e4c3e46b2586f30cdbecb95e101c7192628a484a40dd0d163811a/scipy-1.16.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:654324826654d4d9133e10675325708fb954bc84dae6e9ad0a52e75c6b1a01d7", size = 35791425, upload-time = "2025-09-11T17:42:54.711Z" }, + { url = "https://files.pythonhosted.org/packages/1b/ee/a6559de7c1cc710e938c0355d9d4fbcd732dac4d0d131959d1f3b63eb29c/scipy-1.16.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63870a84cd15c44e65220eaed2dac0e8f8b26bbb991456a033c1d9abfe8a94f8", size = 36178622, upload-time = "2025-09-11T17:43:00.375Z" }, + { url = "https://files.pythonhosted.org/packages/4e/7b/f127a5795d5ba8ece4e0dce7d4a9fb7cb9e4f4757137757d7a69ab7d4f1a/scipy-1.16.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:fa01f0f6a3050fa6a9771a95d5faccc8e2f5a92b4a2e5440a0fa7264a2398472", size = 38783985, upload-time = "2025-09-11T17:43:06.661Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/9f/bc81c1d1e033951eb5912cd3750cc005943afa3e65a725d2443a3b3c4347/scipy-1.16.2-cp313-cp313t-win_amd64.whl", hash = "sha256:116296e89fba96f76353a8579820c2512f6e55835d3fad7780fece04367de351", size = 38631367, upload-time = "2025-09-11T17:43:14.44Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5e/2cc7555fd81d01814271412a1d59a289d25f8b63208a0a16c21069d55d3e/scipy-1.16.2-cp313-cp313t-win_arm64.whl", hash = "sha256:98e22834650be81d42982360382b43b17f7ba95e0e6993e2a4f5b9ad9283a94d", size = 25787992, upload-time = "2025-09-11T17:43:19.745Z" }, +] + [[package]] name = "sentry-sdk" version = "2.40.0" @@ -2036,9 +2086,11 @@ dependencies = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, + { name = "muutils" }, { name = "numpy" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "scipy" }, { name = "simple-stories-train" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -2069,9 +2121,11 @@ requires-dist = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib", specifier = "==3.9.1" }, + { name = "muutils" }, { name = "numpy" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "scipy", specifier = ">=1.14.1" }, { name = "simple-stories-train", git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -2213,27 +2267,27 @@ wheels = [ [[package]] name = "tokenizers" -version = "0.22.1" +version = "0.21.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/c2/2f/402986d0823f8d7ca139d969af2917fefaa9b947d1fb32f6168c509f2492/tokenizers-0.21.4.tar.gz", hash = "sha256:fa23f85fbc9a02ec5c6978da172cdcbac23498c3ca9f3645c5c68740ac007880", size = 351253, upload-time = "2025-07-28T15:48:54.325Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/33/f4b2d94ada7ab297328fc671fed209368ddb82f965ec2224eb1892674c3a/tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73", size = 3069318, upload-time = "2025-09-19T09:49:11.848Z" }, - { url = "https://files.pythonhosted.org/packages/1c/58/2aa8c874d02b974990e89ff95826a4852a8b2a273c7d1b4411cdd45a4565/tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc", size = 2926478, upload-time = "2025-09-19T09:49:09.759Z" }, - { url = "https://files.pythonhosted.org/packages/1e/3b/55e64befa1e7bfea963cf4b787b2cea1011362c4193f5477047532ce127e/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a", size = 3256994, upload-time = "2025-09-19T09:48:56.701Z" }, - { url = "https://files.pythonhosted.org/packages/71/0b/fbfecf42f67d9b7b80fde4aabb2b3110a97fac6585c9470b5bff103a80cb/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7", size = 3153141, upload-time = "2025-09-19T09:48:59.749Z" }, - { url = "https://files.pythonhosted.org/packages/17/a9/b38f4e74e0817af8f8ef925507c63c6ae8171e3c4cb2d5d4624bf58fca69/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21", size = 3508049, upload-time = "2025-09-19T09:49:05.868Z" }, - { url = 
"https://files.pythonhosted.org/packages/d2/48/dd2b3dac46bb9134a88e35d72e1aa4869579eacc1a27238f1577270773ff/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214", size = 3710730, upload-time = "2025-09-19T09:49:01.832Z" }, - { url = "https://files.pythonhosted.org/packages/93/0e/ccabc8d16ae4ba84a55d41345207c1e2ea88784651a5a487547d80851398/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f", size = 3412560, upload-time = "2025-09-19T09:49:03.867Z" }, - { url = "https://files.pythonhosted.org/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4", size = 3250221, upload-time = "2025-09-19T09:49:07.664Z" }, - { url = "https://files.pythonhosted.org/packages/d7/a6/2c8486eef79671601ff57b093889a345dd3d576713ef047776015dc66de7/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879", size = 9345569, upload-time = "2025-09-19T09:49:14.214Z" }, - { url = "https://files.pythonhosted.org/packages/6b/16/32ce667f14c35537f5f605fe9bea3e415ea1b0a646389d2295ec348d5657/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446", size = 9271599, upload-time = "2025-09-19T09:49:16.639Z" }, - { url = "https://files.pythonhosted.org/packages/51/7c/a5f7898a3f6baa3fc2685c705e04c98c1094c523051c805cdd9306b8f87e/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a", size = 9533862, upload-time = "2025-09-19T09:49:19.146Z" }, - { url = 
"https://files.pythonhosted.org/packages/36/65/7e75caea90bc73c1dd8d40438adf1a7bc26af3b8d0a6705ea190462506e1/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390", size = 9681250, upload-time = "2025-09-19T09:49:21.501Z" }, - { url = "https://files.pythonhosted.org/packages/30/2c/959dddef581b46e6209da82df3b78471e96260e2bc463f89d23b1bf0e52a/tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82", size = 2472003, upload-time = "2025-09-19T09:49:27.089Z" }, - { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, + { url = "https://files.pythonhosted.org/packages/98/c6/fdb6f72bf6454f52eb4a2510be7fb0f614e541a2554d6210e370d85efff4/tokenizers-0.21.4-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:2ccc10a7c3bcefe0f242867dc914fc1226ee44321eb618cfe3019b5df3400133", size = 2863987, upload-time = "2025-07-28T15:48:44.877Z" }, + { url = "https://files.pythonhosted.org/packages/8d/a6/28975479e35ddc751dc1ddc97b9b69bf7fcf074db31548aab37f8116674c/tokenizers-0.21.4-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:5e2f601a8e0cd5be5cc7506b20a79112370b9b3e9cb5f13f68ab11acd6ca7d60", size = 2732457, upload-time = "2025-07-28T15:48:43.265Z" }, + { url = "https://files.pythonhosted.org/packages/aa/8f/24f39d7b5c726b7b0be95dca04f344df278a3fe3a4deb15a975d194cbb32/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:39b376f5a1aee67b4d29032ee85511bbd1b99007ec735f7f35c8a2eb104eade5", size = 3012624, upload-time = "2025-07-28T13:22:43.895Z" }, + { url = 
"https://files.pythonhosted.org/packages/58/47/26358925717687a58cb74d7a508de96649544fad5778f0cd9827398dc499/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2107ad649e2cda4488d41dfd031469e9da3fcbfd6183e74e4958fa729ffbf9c6", size = 2939681, upload-time = "2025-07-28T13:22:47.499Z" }, + { url = "https://files.pythonhosted.org/packages/99/6f/cc300fea5db2ab5ddc2c8aea5757a27b89c84469899710c3aeddc1d39801/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c73012da95afafdf235ba80047699df4384fdc481527448a078ffd00e45a7d9", size = 3247445, upload-time = "2025-07-28T15:48:39.711Z" }, + { url = "https://files.pythonhosted.org/packages/be/bf/98cb4b9c3c4afd8be89cfa6423704337dc20b73eb4180397a6e0d456c334/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f23186c40395fc390d27f519679a58023f368a0aad234af145e0f39ad1212732", size = 3428014, upload-time = "2025-07-28T13:22:49.569Z" }, + { url = "https://files.pythonhosted.org/packages/75/c7/96c1cc780e6ca7f01a57c13235dd05b7bc1c0f3588512ebe9d1331b5f5ae/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cc88bb34e23a54cc42713d6d98af5f1bf79c07653d24fe984d2d695ba2c922a2", size = 3193197, upload-time = "2025-07-28T13:22:51.471Z" }, + { url = "https://files.pythonhosted.org/packages/f2/90/273b6c7ec78af547694eddeea9e05de771278bd20476525ab930cecaf7d8/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51b7eabb104f46c1c50b486520555715457ae833d5aee9ff6ae853d1130506ff", size = 3115426, upload-time = "2025-07-28T15:48:41.439Z" }, + { url = "https://files.pythonhosted.org/packages/91/43/c640d5a07e95f1cf9d2c92501f20a25f179ac53a4f71e1489a3dcfcc67ee/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:714b05b2e1af1288bd1bc56ce496c4cebb64a20d158ee802887757791191e6e2", size = 9089127, upload-time = "2025-07-28T15:48:46.472Z" }, + { url = 
"https://files.pythonhosted.org/packages/44/a1/dd23edd6271d4dca788e5200a807b49ec3e6987815cd9d0a07ad9c96c7c2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:1340ff877ceedfa937544b7d79f5b7becf33a4cfb58f89b3b49927004ef66f78", size = 9055243, upload-time = "2025-07-28T15:48:48.539Z" }, + { url = "https://files.pythonhosted.org/packages/21/2b/b410d6e9021c4b7ddb57248304dc817c4d4970b73b6ee343674914701197/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:3c1f4317576e465ac9ef0d165b247825a2a4078bcd01cba6b54b867bdf9fdd8b", size = 9298237, upload-time = "2025-07-28T15:48:50.443Z" }, + { url = "https://files.pythonhosted.org/packages/b7/0a/42348c995c67e2e6e5c89ffb9cfd68507cbaeb84ff39c49ee6e0a6dd0fd2/tokenizers-0.21.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:c212aa4e45ec0bb5274b16b6f31dd3f1c41944025c2358faaa5782c754e84c24", size = 9461980, upload-time = "2025-07-28T15:48:52.325Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d3/dacccd834404cd71b5c334882f3ba40331ad2120e69ded32cf5fda9a7436/tokenizers-0.21.4-cp39-abi3-win32.whl", hash = "sha256:6c42a930bc5f4c47f4ea775c91de47d27910881902b0f20e4990ebe045a415d0", size = 2329871, upload-time = "2025-07-28T15:48:56.841Z" }, + { url = "https://files.pythonhosted.org/packages/41/f2/fd673d979185f5dcbac4be7d09461cbb99751554ffb6718d0013af8604cb/tokenizers-0.21.4-cp39-abi3-win_amd64.whl", hash = "sha256:475d807a5c3eb72c59ad9b5fcdb254f6e17f53dfcbb9903233b0dfa9c943b597", size = 2507568, upload-time = "2025-07-28T15:48:55.456Z" }, ] [[package]] @@ -2354,7 +2408,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.57.0" +version = "4.53.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2368,9 +2422,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f3/5c/a22c39dac2687f3fe2a6b97e2c1ae516e91cd4d3976a7a2b7c24ff2fae48/transformers-4.57.0.tar.gz", hash = 
"sha256:d045753f3d93f9216e693cdb168698dfd2e9d3aad1bb72579a5d60ebf1545a8b", size = 10142956, upload-time = "2025-10-03T17:03:47.177Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/67/80f51466ec447028fd84469b208eb742533ce06cc8fad2e3181380199e5c/transformers-4.53.2.tar.gz", hash = "sha256:6c3ed95edfb1cba71c4245758f1b4878c93bf8cde77d076307dacb2cbbd72be2", size = 9201233, upload-time = "2025-07-11T12:39:08.742Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/2b/4d2708ac1ff5cd708b6548f4c5812d0ae40d1c28591c4c1c762b6dbdef2d/transformers-4.57.0-py3-none-any.whl", hash = "sha256:9d7c6d098c026e40d897e017ed1f481ab803cbac041021dbc6ae6100e4949b55", size = 11990588, upload-time = "2025-10-03T17:03:43.629Z" }, + { url = "https://files.pythonhosted.org/packages/96/88/beb33a79a382fcd2aed0be5222bdc47f41e4bfe7aaa90ae1374f1d8ea2af/transformers-4.53.2-py3-none-any.whl", hash = "sha256:db8f4819bb34f000029c73c3c557e7d06fc1b8e612ec142eecdae3947a9c78bf", size = 10826609, upload-time = "2025-07-11T12:39:05.461Z" }, ] [[package]] @@ -2378,7 +2432,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools" }, + { name = "setuptools", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/d0/66/b1eb52839f563623d185f0927eb3530ee4d5ffe9d377cdaf5346b306689e/triton-3.4.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31c1d84a5c0ec2c0f8e8a072d7fd150cab84a9c239eaddc6706c081bfae4eb04", size = 155560068, upload-time = "2025-07-30T19:58:37.081Z" }, From 26c2957db60e29cb8632421eccd2dfa182ffa360 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 13:18:53 +0100 Subject: [PATCH 02/77] timing cluster_ss.py --- .github/workflows/checks.yaml | 3 + tests/clustering/scripts/cluster_ss.py | 149 ++++++++++++++----------- 2 files changed, 84 insertions(+), 68 deletions(-) diff --git a/.github/workflows/checks.yaml 
b/.github/workflows/checks.yaml index 258935d2b..49b66908d 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -48,6 +48,9 @@ jobs: - name: Run ruff format run: uv run ruff format . + + - name: "[TEMP] run cluster_ss.py" + run: uv run python tests/clustering/scripts/cluster_ss.py - name: Run tests run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 6ede368f0..020a7ddd5 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -6,6 +6,7 @@ from jaxtyping import Int from muutils.dbg import dbg_auto from torch import Tensor +from muutils.spinner import SpinnerContext from spd.clustering.activations import ( ProcessedActivations, @@ -34,96 +35,108 @@ # %% # Load model and dataset # ============================================================ -MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" - -SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) -MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) -MODEL.to(DEVICE) -SPD_CONFIG = SPD_RUN.config - -# Use split_dataset with RunConfig to get real data -CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), - model_path=MODEL_PATH, - task_name="lm", - n_batches=1, - batch_size=2, -) -BATCHES, _ = split_dataset(config=CONFIG) +with SpinnerContext(message="Load model"): + MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + + SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) + MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) + MODEL.to(DEVICE) + SPD_CONFIG = SPD_RUN.config + + # Use split_dataset with RunConfig to get real data + CONFIG: ClusteringRunConfig = ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + task_name="lm", + n_batches=1, + batch_size=2, + ) + +with 
SpinnerContext(message="Load data"): + BATCHES, _ = split_dataset(config=CONFIG) # %% # Load data batch # ============================================================ -DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) +with SpinnerContext(message="Load data batch"): + DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations # ============================================================ -COMPONENT_ACTS: dict[str, Tensor] = component_activations( - model=MODEL, - batch=DATA_BATCH, - device=DEVICE, - sigmoid_type="hard", -) +with SpinnerContext(message="Get component activations"): + COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, + sigmoid_type="hard", + ) -_ = dbg_auto(COMPONENT_ACTS) + _ = dbg_auto(COMPONENT_ACTS) # %% # Process activations # ============================================================ -FILTER_DEAD_THRESHOLD: float = 0.001 -FILTER_MODULES: str = "model.layers.0" - -PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( - activations=COMPONENT_ACTS, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, - filter_modules=lambda x: x.startswith(FILTER_MODULES), - seq_mode="concat", -) +with SpinnerContext(message="Process activations"): + FILTER_DEAD_THRESHOLD: float = 0.001 + FILTER_MODULES: str = "model.layers.0" + + PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", + ) -plot_activations( - processed_activations=PROCESSED_ACTIVATIONS, - save_dir=TEMP_DIR, - n_samples_max=256, - wandb_run=None, -) +with SpinnerContext(message="Plot activations"): + + plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, + ) # %% # Compute ensemble merge iterations # 
============================================================ -MERGE_CFG: MergeConfig = MergeConfig( - activation_threshold=0.01, - alpha=0.01, - iters=2, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, - module_name_filter=FILTER_MODULES, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, -) - -# Modern approach: run merge_iteration multiple times to create ensemble -ENSEMBLE_SIZE: int = 2 -HISTORIES: list[MergeHistory] = [] -for i in range(ENSEMBLE_SIZE): - HISTORY: MergeHistory = merge_iteration( - merge_config=MERGE_CFG, - batch_id=f"batch_{i}", - activations=PROCESSED_ACTIVATIONS.activations, - component_labels=PROCESSED_ACTIVATIONS.labels, - log_callback=None, +with SpinnerContext(message="Compute merge iterations"): + MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + module_name_filter=FILTER_MODULES, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) - HISTORIES.append(HISTORY) -ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) + # Modern approach: run merge_iteration multiple times to create ensemble + ENSEMBLE_SIZE: int = 2 + HISTORIES: list[MergeHistory] = [] + for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) + + ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) # %% # Compute and plot distances # ============================================================ -DISTANCES = ENSEMBLE.get_distances() +with SpinnerContext(message="compute distances"): + DISTANCES = ENSEMBLE.get_distances() -plot_dists_distribution( - distances=DISTANCES, - mode="points", -) -plt.legend() + +with 
SpinnerContext(message="plot distances"): + plot_dists_distribution( + distances=DISTANCES, + mode="points", + ) + plt.legend() From 55be0d8de5ea085cdf8f7c98af57ea82652a1d21 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 13:20:03 +0100 Subject: [PATCH 03/77] Revert "timing cluster_ss.py" This reverts commit 26c2957db60e29cb8632421eccd2dfa182ffa360. --- .github/workflows/checks.yaml | 3 - tests/clustering/scripts/cluster_ss.py | 149 +++++++++++-------------- 2 files changed, 68 insertions(+), 84 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 49b66908d..258935d2b 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -48,9 +48,6 @@ jobs: - name: Run ruff format run: uv run ruff format . - - - name: "[TEMP] run cluster_ss.py" - run: uv run python tests/clustering/scripts/cluster_ss.py - name: Run tests run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 020a7ddd5..6ede368f0 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -6,7 +6,6 @@ from jaxtyping import Int from muutils.dbg import dbg_auto from torch import Tensor -from muutils.spinner import SpinnerContext from spd.clustering.activations import ( ProcessedActivations, @@ -35,108 +34,96 @@ # %% # Load model and dataset # ============================================================ -with SpinnerContext(message="Load model"): - MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" - - SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) - MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) - MODEL.to(DEVICE) - SPD_CONFIG = SPD_RUN.config - - # Use split_dataset with RunConfig to get real data - CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), - model_path=MODEL_PATH, - 
task_name="lm", - n_batches=1, - batch_size=2, - ) - -with SpinnerContext(message="Load data"): - BATCHES, _ = split_dataset(config=CONFIG) +MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG = SPD_RUN.config + +# Use split_dataset with RunConfig to get real data +CONFIG: ClusteringRunConfig = ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + task_name="lm", + n_batches=1, + batch_size=2, +) +BATCHES, _ = split_dataset(config=CONFIG) # %% # Load data batch # ============================================================ -with SpinnerContext(message="Load data batch"): - DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) +DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations # ============================================================ -with SpinnerContext(message="Get component activations"): - COMPONENT_ACTS: dict[str, Tensor] = component_activations( - model=MODEL, - batch=DATA_BATCH, - device=DEVICE, - sigmoid_type="hard", - ) +COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, + sigmoid_type="hard", +) - _ = dbg_auto(COMPONENT_ACTS) +_ = dbg_auto(COMPONENT_ACTS) # %% # Process activations # ============================================================ -with SpinnerContext(message="Process activations"): - FILTER_DEAD_THRESHOLD: float = 0.001 - FILTER_MODULES: str = "model.layers.0" - - PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( - activations=COMPONENT_ACTS, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, - filter_modules=lambda x: x.startswith(FILTER_MODULES), - seq_mode="concat", - ) - -with SpinnerContext(message="Plot activations"): +FILTER_DEAD_THRESHOLD: float = 0.001 +FILTER_MODULES: str = "model.layers.0" + 
+PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", +) - plot_activations( - processed_activations=PROCESSED_ACTIVATIONS, - save_dir=TEMP_DIR, - n_samples_max=256, - wandb_run=None, - ) +plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, +) # %% # Compute ensemble merge iterations # ============================================================ -with SpinnerContext(message="Compute merge iterations"): - MERGE_CFG: MergeConfig = MergeConfig( - activation_threshold=0.01, - alpha=0.01, - iters=2, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, - module_name_filter=FILTER_MODULES, - filter_dead_threshold=FILTER_DEAD_THRESHOLD, - ) +MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + pop_component_prob=0, + module_name_filter=FILTER_MODULES, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, +) - # Modern approach: run merge_iteration multiple times to create ensemble - ENSEMBLE_SIZE: int = 2 - HISTORIES: list[MergeHistory] = [] - for i in range(ENSEMBLE_SIZE): - HISTORY: MergeHistory = merge_iteration( - merge_config=MERGE_CFG, - batch_id=f"batch_{i}", - activations=PROCESSED_ACTIVATIONS.activations, - component_labels=PROCESSED_ACTIVATIONS.labels, - log_callback=None, - ) - HISTORIES.append(HISTORY) +# Modern approach: run merge_iteration multiple times to create ensemble +ENSEMBLE_SIZE: int = 2 +HISTORIES: list[MergeHistory] = [] +for i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + batch_id=f"batch_{i}", + activations=PROCESSED_ACTIVATIONS.activations, + 
component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) - ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) +ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) # %% # Compute and plot distances # ============================================================ -with SpinnerContext(message="compute distances"): - DISTANCES = ENSEMBLE.get_distances() - +DISTANCES = ENSEMBLE.get_distances() -with SpinnerContext(message="plot distances"): - plot_dists_distribution( - distances=DISTANCES, - mode="points", - ) - plt.legend() +plot_dists_distribution( + distances=DISTANCES, + mode="points", +) +plt.legend() From 00db8dd1db707d67f69d8a31f844ebc1e57d0a43 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 10 Oct 2025 08:17:18 -0700 Subject: [PATCH 04/77] [clustering] `cluster_ss.py` speedup ci (#199) ## Description clustering tests were slow (see [`actions/runs/18406085786/job/52446229079`](https://github.com/goodfire-ai/spd/actions/runs/18406085786/job/52446229079), taking nearly 10 minutes [Fig 1]. In particular, `tests/clustering/test_clustering_experiments.py::test_cluster_ss_notebook` was taking ~500 seconds, nearly 10x the next slowest test. This test just runs `tests/clustering/scripts/cluster_ss.py` as a script. Turns out >90% of that time was spent on the dataset [Fig 2], so we switched it to be streaming. this required a minor refactor. Note that after we fixed the above, `test_cluster_ss_notebook` was fast, but streaming was not enabled for `test_clustering_with_simplestories_config` since that tests it thru cli -- hence the downloading was just happening for the latter. 
so, we adjust the config and cli to allow for enabling dataset streaming in see [17bfdc4](https://github.com/goodfire-ai/spd/pull/199/commits/17bfdc44c2e2b11cc28db6de8ab91c2c33c702b0) ### Fig 1 ``` ============================= slowest 10 durations ============================= 504.48s call tests/clustering/test_clustering_experiments.py::test_cluster_ss_notebook 66.45s call tests/clustering/test_clustering_experiments.py::test_clustering_with_resid_mlp1_config 48.46s call tests/test_distributed.py::TestDistributedDeterminicity::test_distributed_determinicity 47.01s call tests/clustering/test_clustering_experiments.py::test_clustering_with_simplestories_config 17.16s call tests/clustering/test_clustering_experiments.py::test_cluster_resid_mlp_notebook 6.18s call tests/metrics/test_alive_components_distributed.py::TestDistributedAliveComponentsTracker::test_distributed_alive_components 5.36s call tests/test_gpt2.py::test_gpt_2_decomposition_happy_path 4.67s call tests/test_wandb_run_loading.py::test_loading_from_wandb[tms_5-2-id-wandb:goodfire/spd/runs/dwalcejo-_from_run_info] 4.56s call tests/test_wandb_run_loading.py::test_loading_from_wandb[resid_mlp2-wandb:goodfire/spd/runs/6vpsvdl5-_from_pretrained] 4.37s call tests/test_wandb_run_loading.py::test_loading_from_wandb[tms_5-2-wandb:goodfire/spd/runs/4stmo6p5-_from_run_info] ``` > from https://github.com/goodfire-ai/spd/actions/runs/18406085786/job/52446229079 ### Fig 2 ``` Timer records (s): Load model 15.36 Load data 511.21 Load data batch 0.05 Get component activations 5.11 Process activations 0.08 Plot activations 8.05 Compute merge iterations 0.51 compute distances 0.18 plot distances 0.01 ``` > from https://github.com/goodfire-ai/spd/actions/runs/18407424091/job/52450587755 ## Overall Impact - total CI runtime reduced from 12 minutes to 4 minutes. time for specifically tests reduced from 11 min to 3 min. 
- before: https://github.com/goodfire-ai/spd/actions/runs/18406085786/job/52446229079 - after: https://github.com/goodfire-ai/spd/actions/runs/18410271787/job/52460485800 - a few minor backwards compatible interface additions - allow specifying `dataset_streaming` for `"lm"` tasks - format specifier for saving figures (I thought svg might be a bit faster. left it in because seems useful generally) ## Related Issue See https://github.com/goodfire-ai/spd/pull/43#discussion_r2278800579 ## Does this PR introduce a breaking change? No, but did modify a few interfaces. # Commits: * timing cluster_ss.py nearly identical to 26c2957 but accidentally committed that to clustering/main * better timing * fix timing * allow saving in formats besides pdf * streaming dataset for notebook to avoid download? * oops * remove timers * minimize diff * remove custom CI timing step * dataset streaming in config+cli for spd-cluster test_cluster_ss_notebook was fast, but streaming was not enabled for test_clustering_with_simplestories_config since that tests it thru cli -- hence the downloading was just happening for the latter. so, we adjust the config and cli to allow for enabling dataset streaming * wip * cuda issues??? * [temp] telemetry for action using https://github.com/catchpoint/workflow-telemetry-action * minor fixes from claude see https://github.com/goodfire-ai/spd/pull/199#issuecomment-3390405175 * [temp] no docker container for workflow telemetry to work * sudo in workflow for apt-get * use --dist worksteal in CI pytest * revert temp CI changes * minor fixes. accidentally removed container in CI lol * link to issue in script we do the weird thing because of CUDA issues. 
we should remove it at some point, and make sure that the CUDA worker catches it https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 --- .github/workflows/checks.yaml | 2 +- spd/clustering/merge_run_config.py | 15 ++++++++++++- .../pipeline/clustering_pipeline.py | 14 ++++++++++++- spd/clustering/pipeline/s1_split_dataset.py | 21 +++++++++++++++---- spd/clustering/plotting/activations.py | 18 +++++++++------- spd/clustering/plotting/merge.py | 6 ++++-- spd/clustering/scripts/main.py | 6 ++++++ tests/clustering/scripts/cluster_ss.py | 20 +++++++++++++++--- .../clustering/test_clustering_experiments.py | 1 + 9 files changed, 83 insertions(+), 20 deletions(-) diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 258935d2b..81027df03 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -50,7 +50,7 @@ jobs: run: uv run ruff format . - name: Run tests - run: uv run pytest tests/ --runslow --durations 10 --numprocesses auto + run: uv run pytest tests/ --runslow --durations 20 --numprocesses auto --dist worksteal env: WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }} OMPI_ALLOW_RUN_AS_ROOT: 1 diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index feba16967..d86cbc7a6 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -71,7 +71,7 @@ class ClusteringRunConfig(BaseModel): """Configuration for a complete merge clustering run. Extends MergeConfig with parameters for model, dataset, and batch configuration. 
- CLI parameters (base_path, devices, workers_per_device) have defaults but will always be overridden + CLI parameters (base_path, devices, workers_per_device, dataset_streaming) have defaults but will always be overridden """ merge_config: MergeConfig = Field( @@ -100,6 +100,10 @@ class ClusteringRunConfig(BaseModel): default="perm_invariant_hamming", description="Method to use for computing distances between clusterings", ) + dataset_streaming: bool = Field( + default=False, + description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", + ) # Implementation details # note that these are *always* overriden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here @@ -161,6 +165,15 @@ def validate_intervals(self) -> Self: return self + @model_validator(mode="after") + def validate_streaming_compatibility(self) -> Self: + """Ensure dataset_streaming is only enabled for compatible tasks.""" + if self.dataset_streaming and self.task_name != "lm": + raise ValueError( + f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" + ) + return self + @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition from the model_path diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py index 1e07c71d7..8c6b72f9d 100644 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ b/spd/clustering/pipeline/clustering_pipeline.py @@ -46,9 +46,21 @@ def main(config: ClusteringRunConfig) -> None: # Split dataset into batches logger.info(f"Splitting dataset into {config.n_batches} batches...") + split_dataset_kwargs: dict[str, Any] = dict() + if config.dataset_streaming: + logger.info("Using streaming dataset loading") + split_dataset_kwargs["config_kwargs"] = dict(streaming=True) + # check this here as well as the model validator because we edit 
`config.dataset_streaming` after init in main() after the CLI args are parsed + # not sure if this is actually a problem though + assert config.task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task." + ) batches: Iterator[BatchTensor] dataset_config: dict[str, Any] - batches, dataset_config = split_dataset(config=config) + batches, dataset_config = split_dataset( + config=config, + **split_dataset_kwargs, + ) storage.save_batches(batches=batches, config=dataset_config) batch_paths: list[Path] = storage.get_batch_paths() n_batch_paths: int = len(batch_paths) diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index d5427e600..711cda65a 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -21,7 +21,10 @@ from spd.models.component_model import ComponentModel, SPDRunInfo -def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], dict[str, Any]]: +def split_dataset( + config: ClusteringRunConfig, + **kwargs: Any, +) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" ds: Generator[BatchTensor, None, None] ds_config_dict: dict[str, Any] @@ -30,11 +33,13 @@ def split_dataset(config: ClusteringRunConfig) -> tuple[Iterator[BatchTensor], d ds, ds_config_dict = _get_dataloader_lm( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case "resid_mlp": ds, ds_config_dict = _get_dataloader_resid_mlp( model_path=config.model_path, batch_size=config.batch_size, + **kwargs, ) case name: raise ValueError( @@ -56,6 +61,7 @@ def limited_iterator() -> Iterator[BatchTensor]: def _get_dataloader_lm( model_path: str, batch_size: int, + config_kwargs: dict[str, Any] | None = None, ) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: 
"""split up a SS dataset into n_batches of batch_size, returned the saved paths @@ -81,15 +87,22 @@ def _get_dataloader_lm( f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" ) + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + streaming=False, + seed=0, + ), + **(config_kwargs or {}), + } + dataset_config: DatasetConfig = DatasetConfig( name=cfg.task_config.dataset_name, hf_tokenizer_path=pretrained_model_name, split=cfg.task_config.train_data_split, n_ctx=cfg.task_config.max_seq_len, - is_tokenized=False, - streaming=False, - seed=0, column_name=cfg.task_config.column_name, + **config_kwargs_, ) with SpinnerContext(message="getting dataloader..."): diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index 2411eca38..dc8c8658c 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -2,6 +2,7 @@ from collections.abc import Sequence from pathlib import Path +from typing import Literal import matplotlib as mpl import matplotlib.pyplot as plt @@ -20,7 +21,7 @@ def plot_activations( processed_activations: ProcessedActivations, save_dir: Path, n_samples_max: int, - pdf_prefix: str = "activations", + figure_prefix: str = "activations", figsize_raw: tuple[int, int] = (12, 4), figsize_concat: tuple[int, int] = (12, 2), figsize_coact: tuple[int, int] = (8, 6), @@ -28,6 +29,7 @@ def plot_activations( hist_bins: int = 100, do_sorted_samples: bool = False, wandb_run: wandb.sdk.wandb_run.Run | None = None, + save_fmt: Literal["pdf", "png", "svg"] = "pdf", ) -> None: """Plot activation visualizations including raw, concatenated, sorted, and coactivations. 
@@ -37,7 +39,7 @@ def plot_activations( coact: Coactivation matrix labels: Component labels save_dir: The directory to save the plots to - pdf_prefix: Prefix for PDF filenames + figure_prefix: Prefix for figure filenames figsize_raw: Figure size for raw activations figsize_concat: Figure size for concatenated activations figsize_coact: Figure size for coactivations @@ -77,7 +79,7 @@ def plot_activations( axs_act[i].set_ylabel(f"components\n{key}") axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") - fig1_fname = save_dir / f"{pdf_prefix}_raw.pdf" + fig1_fname = save_dir / f"{figure_prefix}_raw.{save_fmt}" _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -100,7 +102,7 @@ def plot_activations( plt.colorbar(im2) - fig2_fname: Path = save_dir / f"{pdf_prefix}_concatenated.pdf" + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.{save_fmt}" fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -169,7 +171,7 @@ def plot_activations( plt.colorbar(im3) - fig3_fname: Path = save_dir / f"{pdf_prefix}_concatenated_sorted.pdf" + fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.{save_fmt}" fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -193,7 +195,7 @@ def plot_activations( plt.colorbar(im4) - fig4_fname: Path = save_dir / f"{pdf_prefix}_coactivations.pdf" + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.{save_fmt}" fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) # Log to WandB if available @@ -217,7 +219,7 @@ def plot_activations( add_component_labeling(ax4_log, labels, axis="x") add_component_labeling(ax4_log, labels, axis="y") plt.colorbar(im4_log) - fig4_log_fname: Path = save_dir / f"{pdf_prefix}_coactivations_log.pdf" + fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.{save_fmt}" fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) # Log to WandB if 
available @@ -312,7 +314,7 @@ def plot_activations( plt.tight_layout() - fig5_fname: Path = save_dir / f"{pdf_prefix}_histograms.pdf" + fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.{save_fmt}" fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) # Log to WandB if available diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py index e470b3114..8a2cc20df 100644 --- a/spd/clustering/plotting/merge.py +++ b/spd/clustering/plotting/merge.py @@ -17,7 +17,7 @@ figsize=(16, 10), tick_spacing=5, save_pdf=False, - pdf_prefix="merge_iteration", + figure_prefix="merge_iteration", ) @@ -168,7 +168,9 @@ def plot_merge_iteration( if plot_config_["save_pdf"]: fig.savefig( - f"{plot_config_['pdf_prefix']}_iter_{iteration:03d}.pdf", bbox_inches="tight", dpi=300 + f"{plot_config_['figure_prefix']}_iter_{iteration:03d}.pdf", + bbox_inches="tight", + dpi=300, ) if show: diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 56cee4d84..2104482e5 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -43,6 +43,11 @@ def cli() -> None: default=1, help="Maximum number of concurrent clustering processes per device (default: 1)", ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) args: argparse.Namespace = parser.parse_args() logger.info("Starting clustering pipeline") @@ -66,6 +71,7 @@ def cli() -> None: config.base_path = args.base_path config.devices = devices config.workers_per_device = args.workers_per_device + config.dataset_streaming = args.dataset_streaming logger.info(f"Configuration loaded: {config.config_identifier}") logger.info(f"Base path: {config.base_path}") diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 6ede368f0..00ef733d6 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -1,7 +1,11 @@ # %% +import os + +# Suppress tokenizer parallelism warning when forking +os.environ["TOKENIZERS_PARALLELISM"] = "false" + from pathlib import Path -import matplotlib.pyplot as plt import torch from jaxtyping import Int from muutils.dbg import dbg_auto @@ -49,7 +53,11 @@ n_batches=1, batch_size=2, ) -BATCHES, _ = split_dataset(config=CONFIG) + +BATCHES, _ = split_dataset( + config=CONFIG, + config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 +) # %% # Load data batch @@ -85,6 +93,7 @@ save_dir=TEMP_DIR, n_samples_max=256, wandb_run=None, + save_fmt="svg", ) # %% @@ -126,4 +135,9 @@ distances=DISTANCES, mode="points", ) -plt.legend() + +# %% +# Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown +# see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 +# ============================================================ +os._exit(0) diff --git a/tests/clustering/test_clustering_experiments.py b/tests/clustering/test_clustering_experiments.py index 5031adfce..19ff937f6 100644 --- a/tests/clustering/test_clustering_experiments.py +++ b/tests/clustering/test_clustering_experiments.py @@ -87,6 +87,7 @@ def test_clustering_with_simplestories_config(): "spd-cluster", "--config", str(config_path), + "--dataset-streaming", # see 
https://github.com/goodfire-ai/spd/pull/199 ], capture_output=True, text=True, From 13db0beddb6eafa14f5a4b33a0b112c3a12834b9 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 10 Oct 2025 17:20:20 +0100 Subject: [PATCH 05/77] add num_nonsingleton_groups stat from PR170 see https://github.com/goodfire-ai/spd/pull/170 --- spd/clustering/pipeline/s2_clustering.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py index d04b16bc5..fc782c96e 100644 --- a/spd/clustering/pipeline/s2_clustering.py +++ b/spd/clustering/pipeline/s2_clustering.py @@ -302,6 +302,7 @@ def _log_callback( run.log( { "fraction_singleton_groups": float(fraction_singleton_groups), + "num_nonsingleton_groups": int((group_sizes > 1).sum().item()), "fraction_zero_coacts": float(fraction_zero_coacts), }, step=iter_idx, From 0c74b5dca93cacdd299067bdf8aaeb7bac078221 Mon Sep 17 00:00:00 2001 From: mivanit Date: Mon, 13 Oct 2025 08:33:26 -0700 Subject: [PATCH 06/77] switch BaseModel to BaseConfig, get rid of old save/read logic (#209) adapt `clustering/main` and the associated #198 to use the new `BaseConfig` introduced in #200 * switch BaseModel to BaseConfig, get rid of old save/read logic * fix typo * fix pydantic validation issue * use model_copy to avoid editing frozen dict when updating ClusteringRunConfig from CLI * remove deprecated config fields --- spd/clustering/configs/example.toml | 1 - spd/clustering/configs/example.yaml | 1 - spd/clustering/configs/test-resid_mlp1.json | 3 +- spd/clustering/merge_config.py | 4 +- spd/clustering/merge_run_config.py | 120 ++++++++------------ spd/clustering/pipeline/s2_clustering.py | 2 +- spd/clustering/pipeline/storage.py | 2 +- spd/clustering/scripts/main.py | 15 ++- spd/utils/wandb_utils.py | 4 +- tests/clustering/test_storage.py | 2 +- 10 files changed, 63 insertions(+), 91 deletions(-) diff --git a/spd/clustering/configs/example.toml 
b/spd/clustering/configs/example.toml index d5cfe46d6..98053576b 100644 --- a/spd/clustering/configs/example.toml +++ b/spd/clustering/configs/example.toml @@ -30,7 +30,6 @@ iters = 100 # iterations to run. setting this to exactly the number of componen pop_component_prob = 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold = 0.001 # Threshold for filtering dead components module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." if you want to do only some modules -rank_cost_fn_name = "const_1" # Options: const_1, const_2, log, linear merge_pair_sampling_method = "range" # Method for sampling merge pairs: 'range' or 'mcmc' [merge_config.merge_pair_sampling_kwargs] diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/example.yaml index 5f3cd5fa5..259f1597c 100644 --- a/spd/clustering/configs/example.yaml +++ b/spd/clustering/configs/example.yaml @@ -11,7 +11,6 @@ merge_config: pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules - rank_cost_fn_name: const_1 # Options: const_1, const_2, log, linear # Run configuration model_path: wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh # WandB path to the decomposed model diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json index 75877dd25..fbacff53a 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -7,8 +7,7 @@ "merge_pair_sampling_kwargs": {"threshold": 0.05}, "pop_component_prob": 0, "filter_dead_threshold": 0.1, - "module_name_filter": null, - "rank_cost_fn_name": "const_1" + "module_name_filter": null }, "experiment_key": "resid_mlp1", "n_batches": 2, diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 03c601a9f..3bf8b6d5b 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import ( - BaseModel, Field, PositiveInt, ) +from spd.base_config import BaseConfig from spd.clustering.consts import ClusterCoactivationShaped, MergePair from spd.clustering.math.merge_pair_samplers import ( MERGE_PAIR_SAMPLERS, @@ -44,7 +44,7 @@ def _to_module_filter( raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}") # pyright: ignore[reportUnreachable] -class MergeConfig(BaseModel): +class MergeConfig(BaseConfig): activation_threshold: Probability | None = Field( default=0.01, description="Threshold for considering a component active in a group. 
If None, use raw scalar causal importances", diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index d86cbc7a6..b6b8d6ab6 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -1,16 +1,15 @@ """Configuration for merge clustering runs that combines merge config with run parameters.""" import hashlib -import json import tomllib import warnings from pathlib import Path from typing import Any, Literal, Self -import yaml from muutils.misc.numerical import shorten_numerical_to_str -from pydantic import BaseModel, Field, PositiveInt, model_validator +from pydantic import Field, PositiveInt, model_validator +from spd.base_config import BaseConfig from spd.clustering.consts import DistancesMethod from spd.clustering.merge_config import MergeConfig from spd.registry import EXPERIMENT_REGISTRY, ExperimentConfig @@ -67,7 +66,7 @@ def replace_sentinel_recursive(obj: Any) -> Any: return replace_sentinel_recursive(data) -class ClusteringRunConfig(BaseModel): +class ClusteringRunConfig(BaseConfig): """Configuration for a complete merge clustering run. Extends MergeConfig with parameters for model, dataset, and batch configuration. @@ -146,24 +145,27 @@ def validate_model_path(self) -> Self: ) return self - @model_validator(mode="after") - def validate_intervals(self) -> Self: + @model_validator(mode="before") + @classmethod + def validate_intervals(cls, data: dict[str, Any]) -> dict[str, Any]: """Ensure all required interval keys are present.""" + + data_intervals: dict[IntervalKey, Any] = data.get("intervals", {}) # warning if any keys are missing - missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(self.intervals.keys()) + missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(data_intervals.keys()) if missing_keys: warnings.warn( - f"Missing interval keys in {self.intervals = }: {missing_keys}. 
Using defaults for those.", + f"Missing interval keys in {data_intervals = }: {missing_keys}. Using defaults for those.", UserWarning, stacklevel=1, ) - self.intervals = { + data["intervals"] = { **_DEFAULT_INTERVALS, - **self.intervals, + **data_intervals, } - return self + return data @model_validator(mode="after") def validate_streaming_compatibility(self) -> Self: @@ -174,6 +176,38 @@ def validate_streaming_compatibility(self) -> Self: ) return self + @model_validator(mode="before") + @classmethod + def handle_experiment_key(cls, data: dict[str, Any]) -> dict[str, Any]: + """handle passing experiment key instead of model_path and task_name. + + if we provide an experiment_key, then: + 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name + 2. check it's consistent with model_path and task_name from the file if those are provided + + """ + experiment_key: str | None = data.get("experiment_key") + model_path: str | None = data.get("model_path") + task_name: str | None = data.get("task_name") + if experiment_key is not None: + exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key] + + # Enforce consistency if explicit fields present + if model_path is not None: + assert model_path == exp_config.canonical_run, ( + f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})" + ) + if task_name is not None: + assert task_name == exp_config.task_name, ( + f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})" + ) + + # overwrite in data dict + data["model_path"] = exp_config.canonical_run + data["task_name"] = exp_config.task_name + + return data + @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition from the model_path @@ -213,70 +247,6 @@ def stable_hash(self) -> str: """Generate a stable hash including all config parameters.""" return 
hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6] - @classmethod - def read(cls, path: Path) -> "ClusteringRunConfig": - """Load config from JSON, YAML, or TOML file. - - Handles legacy spd_exp: model_path format and enforces consistency. - For TOML files, the sentinel value "__NULL__" is converted to None. - """ - # read the file contents, load them according to extension - data: dict[str, Any] - content: str - if path.suffix == ".json": - content = path.read_text() - data = json.loads(content) - elif path.suffix in [".yaml", ".yml"]: - content = path.read_text() - data = yaml.safe_load(content) - elif path.suffix == ".toml": - data = toml_read_file_with_none(path) - else: - raise ValueError( - f"Unsupported file extension '{path.suffix}' on file '{path}' -- must be .json, .yaml, .yml, or .toml" - ) - - # if we provide an experiment_key, then: - # 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name - # 2. check it's consistent with model_path and task_name from the file if those are provided - experiment_key: str | None = data.get("experiment_key") - model_path: str | None = data.get("model_path") - task_name: str | None = data.get("task_name") - if experiment_key is not None: - exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key] - - # Enforce consistency if explicit fields present - if model_path is not None: - assert model_path == exp_config.canonical_run, ( - f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})" - ) - if task_name is not None: - assert task_name == exp_config.task_name, ( - f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})" - ) - - # overwrite in data dict - data["model_path"] = exp_config.canonical_run - data["task_name"] = exp_config.task_name - - return cls.model_validate(data) - - def save(self, path: Path) -> None: - """Save config to file 
(format inferred from extension).""" - path.parent.mkdir(parents=True, exist_ok=True) - if path.suffix == ".json": - path.write_text(self.model_dump_json(indent=2)) - elif path.suffix in [".yaml", ".yml"]: - path.write_text( - yaml.dump( - self.model_dump(mode="json"), - default_flow_style=False, - sort_keys=False, - ) - ) - else: - raise ValueError(f"Unsupported file extension: {path.suffix}") - def model_dump_with_properties(self) -> dict[str, Any]: """Serialize config including computed properties for WandB logging.""" base_dump: dict[str, Any] = self.model_dump() diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py index fc782c96e..116d7dc61 100644 --- a/spd/clustering/pipeline/s2_clustering.py +++ b/spd/clustering/pipeline/s2_clustering.py @@ -384,7 +384,7 @@ def cli() -> None: args: argparse.Namespace = parser.parse_args() # Load config - config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) + config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) # Run clustering result: ClusteringResult = run_clustering( diff --git a/spd/clustering/pipeline/storage.py b/spd/clustering/pipeline/storage.py index cb38befd8..febe9bcc8 100644 --- a/spd/clustering/pipeline/storage.py +++ b/spd/clustering/pipeline/storage.py @@ -270,7 +270,7 @@ def save_run_config(self, config: ClusteringRunConfig) -> Path: ) def load_run_config(self) -> ClusteringRunConfig: - return ClusteringRunConfig.read(self.run_config_file) + return ClusteringRunConfig.from_file(self.run_config_file) # Dashboard storage methods diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py index 2104482e5..65e224f5e 100644 --- a/spd/clustering/scripts/main.py +++ b/spd/clustering/scripts/main.py @@ -67,11 +67,16 @@ def cli() -> None: # Note that the defaults for args here always override the default values in `RunConfig` itself, # but we must have those defaults to avoid type issues logger.info(f"Loading config from 
{args.config}") - config: ClusteringRunConfig = ClusteringRunConfig.read(args.config) - config.base_path = args.base_path - config.devices = devices - config.workers_per_device = args.workers_per_device - config.dataset_streaming = args.dataset_streaming + config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) + # Use model_copy to update frozen fields + config = config.model_copy( + update={ + "base_path": args.base_path, + "devices": devices, + "workers_per_device": args.workers_per_device, + "dataset_streaming": args.dataset_streaming, + } + ) logger.info(f"Configuration loaded: {config.config_identifier}") logger.info(f"Base path: {config.base_path}") diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index 2cac6dd80..855440804 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -6,9 +6,9 @@ import wandb_workspaces.reports.v2 as wr import wandb_workspaces.workspaces as ws from dotenv import load_dotenv -from pydantic import BaseModel from wandb.apis.public import File, Run +from spd.base_config import BaseConfig from spd.log import logger from spd.registry import EXPERIMENT_REGISTRY from spd.settings import REPO_ROOT @@ -120,7 +120,7 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: return path -def init_wandb[T_config: BaseModel]( +def init_wandb[T_config: BaseConfig]( config: T_config, project: str, name: str | None = None, tags: list[str] | None = None ) -> T_config: """Initialize Weights & Biases and return a config updated with sweep hyperparameters. 
diff --git a/tests/clustering/test_storage.py b/tests/clustering/test_storage.py index d5e3d535e..389940e54 100644 --- a/tests/clustering/test_storage.py +++ b/tests/clustering/test_storage.py @@ -91,7 +91,7 @@ def test_save_and_load_run_config(self, temp_storage: ClusteringStorage): assert saved_path == temp_storage.run_config_file # Load and verify - loaded_config = ClusteringRunConfig.read(saved_path) + loaded_config = ClusteringRunConfig.from_file(saved_path) assert loaded_config.n_batches == 5 assert loaded_config.batch_size == 32 assert loaded_config.task_name == "lm" From 3f07f420f7608889ff3199a0ef262108ad60b329 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 13 Oct 2025 17:49:01 +0100 Subject: [PATCH 07/77] switch to new run --- spd/clustering/configs/simplestories_dev.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/simplestories_dev.json index c82b11710..552309465 100644 --- a/spd/clustering/configs/simplestories_dev.json +++ b/spd/clustering/configs/simplestories_dev.json @@ -9,7 +9,7 @@ "filter_dead_threshold": 0.1, "module_name_filter": null }, - "model_path": "wandb:goodfire/spd-pre-Sep-2025/runs/rn9klzfs", + "model_path": "wandb:goodfire/spd/runs/lxs77xye", "task_name": "lm", "distances_method": "jaccard", "n_batches": 1, From c06cffe04ef5915a92d962c030e39bc4a0988547 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 16 Oct 2025 16:31:10 +0100 Subject: [PATCH 08/77] wip sigmoid issues --- spd/clustering/activations.py | 8 +- spd/clustering/pipeline/s1_split_dataset.py | 6 +- spd/clustering/pipeline/s2_clustering.py | 1 - .../pipeline/s4_compute_distances.py | 2 +- spd/clustering/plotting/activations.py | 4 +- spd/clustering/plotting/merge.py | 4 +- tests/clustering/scripts/cluster_resid_mlp.py | 1 - tests/clustering/scripts/cluster_ss.py | 1 - .../test_alive_components_distributed.py | 7 +- uv.lock | 218 +++++++++--------- 10 files 
changed, 117 insertions(+), 135 deletions(-) diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index 6b7c51abf..d0b1a51a6 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -14,14 +14,12 @@ ) from spd.clustering.util import ModuleFilterFunc from spd.models.component_model import ComponentModel, OutputWithCache -from spd.models.sigmoids import SigmoidTypes def component_activations( model: ComponentModel, device: torch.device | str, batch: Int[Tensor, "batch_size n_ctx"], - sigmoid_type: SigmoidTypes, ) -> dict[str, ActivationsTensor]: """Get the component activations over a **single** batch.""" causal_importances: dict[str, ActivationsTensor] @@ -31,12 +29,12 @@ def component_activations( cache_type="input", ) - causal_importances, _ = model.calc_causal_importances( + # TODO: !!!IMPORTANT!!! unclear whether pre_sigmoid is the right thing to use here + causal_importances = model.calc_causal_importances( pre_weight_acts=model_output.cache, - sigmoid_type=sigmoid_type, sampling="continuous", detach_inputs=False, - ) + ).pre_sigmoid return causal_importances diff --git a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py index 711cda65a..94ac1a8bf 100644 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ b/spd/clustering/pipeline/s1_split_dataset.py @@ -26,7 +26,7 @@ def split_dataset( **kwargs: Any, ) -> tuple[Iterator[BatchTensor], dict[str, Any]]: """Split a dataset into n_batches of batch_size, returning iterator and config""" - ds: Generator[BatchTensor, None, None] + ds: Generator[BatchTensor] ds_config_dict: dict[str, Any] match config.task_name: case "lm": @@ -62,7 +62,7 @@ def _get_dataloader_lm( model_path: str, batch_size: int, config_kwargs: dict[str, Any] | None = None, -) -> tuple[Generator[BatchTensor, None, None], dict[str, Any]]: +) -> tuple[Generator[BatchTensor], dict[str, Any]]: """split up a SS dataset into n_batches of batch_size, returned the saved 
paths 1. load the config for a SimpleStories SPD Run given by model_path @@ -122,7 +122,7 @@ def _get_dataloader_lm( def _get_dataloader_resid_mlp( model_path: str, batch_size: int, -) -> tuple[Generator[torch.Tensor, None, None], dict[str, Any]]: +) -> tuple[Generator[torch.Tensor], dict[str, Any]]: """Split a ResidMLP dataset into n_batches of batch_size and save the batches.""" from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset from spd.utils.data_utils import DatasetGeneratedDataLoader diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py index 116d7dc61..bfeeadfbe 100644 --- a/spd/clustering/pipeline/s2_clustering.py +++ b/spd/clustering/pipeline/s2_clustering.py @@ -118,7 +118,6 @@ def logger_call(msg: str) -> None: model=model, batch=batch, device=device, - sigmoid_type=spd_run.config.sigmoid_type, ) logger_call("computed activations") diff --git a/spd/clustering/pipeline/s4_compute_distances.py b/spd/clustering/pipeline/s4_compute_distances.py index 5c0e05124..9e1de4974 100644 --- a/spd/clustering/pipeline/s4_compute_distances.py +++ b/spd/clustering/pipeline/s4_compute_distances.py @@ -57,7 +57,7 @@ def create_clustering_report( plt.legend() # Get the figure from the axes - fig: plt.Figure | None = ax.get_figure() + fig = ax.get_figure(root=True) assert fig is not None # Log the plot diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index dc8c8658c..eb7a86b01 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -67,7 +67,7 @@ def plot_activations( # Raw activations axs_act: Sequence[plt.Axes] _fig1: plt.Figure - _fig1, axs_act = plt.subplots(len(act_dict), 1, figsize=figsize_raw) # pyright: ignore[reportAssignmentType] + _fig1, axs_act = plt.subplots(len(act_dict), 1, figsize=figsize_raw) if len(act_dict) == 1: assert isinstance(axs_act, plt.Axes) axs_act = [axs_act] @@ -234,7 +234,7 @@ def 
plot_activations( ax5a: plt.Axes ax5b: plt.Axes ax5c: plt.Axes - fig5, (ax5a, ax5b, ax5c) = plt.subplots(1, 3, figsize=(15, 4)) # pyright: ignore[reportGeneralTypeIssues] + fig5, (ax5a, ax5b, ax5c) = plt.subplots(1, 3, figsize=(15, 4)) x_scale: str y_scale: str diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py index 8a2cc20df..049f06e29 100644 --- a/spd/clustering/plotting/merge.py +++ b/spd/clustering/plotting/merge.py @@ -45,7 +45,7 @@ def plot_merge_matrix( assert not show_row_sums else: if show_row_sums: - _fig, (ax_mat, ax_lbl) = plt.subplots( # pyright: ignore[reportGeneralTypeIssues] + _fig, (ax_mat, ax_lbl) = plt.subplots( 1, 2, figsize=figsize, gridspec_kw={"width_ratios": [10, 1]} ) else: @@ -114,7 +114,7 @@ def plot_merge_iteration( **(plot_config or {}), } axs: list[plt.Axes] - fig, axs = plt.subplots( # pyright: ignore[reportAssignmentType] + fig, axs = plt.subplots( 1, 3, figsize=plot_config_["figsize"], sharey=True, gridspec_kw={"width_ratios": [2, 1, 1]} ) diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index df8dd0a2f..010e5ecfc 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -87,7 +87,6 @@ model=MODEL, device=DEVICE, batch=BATCH, - sigmoid_type="hard", ) dbg_auto(COMPONENT_ACTS) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 00ef733d6..43399d039 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -71,7 +71,6 @@ model=MODEL, batch=DATA_BATCH, device=DEVICE, - sigmoid_type="hard", ) _ = dbg_auto(COMPONENT_ACTS) diff --git a/tests/metrics/test_alive_components_distributed.py b/tests/metrics/test_alive_components_distributed.py index 5f054fd05..bd7a09b8b 100644 --- a/tests/metrics/test_alive_components_distributed.py +++ b/tests/metrics/test_alive_components_distributed.py @@ -133,9 +133,7 @@ def 
_test_dead_components(): print(f"Rank {rank} n_batches_since_fired: {metric.n_batches_since_fired['layer1']}") result = metric.compute() # only components 0 and 1 alive - assert result["layer1"] == 2, ( - f"Expected 2 alive components, got {result['layer1']}" - ) + assert result["layer1"] == 2, f"Expected 2 alive components, got {result['layer1']}" if rank == 0: print(f"✓ Dead components test passed (n_alive={result['layer1']})") @@ -179,8 +177,7 @@ def _test_multiple_modules(): if rank == 0: print( - f"✓ Multiple modules test passed " - f"(layer1={result['layer1']}, layer2={result['layer2']})" + f"✓ Multiple modules test passed (layer1={result['layer1']}, layer2={result['layer2']})" ) diff --git a/uv.lock b/uv.lock index 18c3f02a6..9ae8b57da 100644 --- a/uv.lock +++ b/uv.lock @@ -287,37 +287,37 @@ wheels = [ [[package]] name = "coverage" -version = "7.10.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/26/d22c300112504f5f9a9fd2297ce33c35f3d353e4aeb987c8419453b2a7c2/coverage-7.10.7.tar.gz", hash = "sha256:f4ab143ab113be368a3e9b795f9cd7906c5ef407d6173fe9675a902e1fffc239", size = 827704, upload-time = "2025-09-21T20:03:56.815Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/94/b765c1abcb613d103b64fcf10395f54d69b0ef8be6a0dd9c524384892cc7/coverage-7.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:981a651f543f2854abd3b5fcb3263aac581b18209be49863ba575de6edf4c14d", size = 218320, upload-time = "2025-09-21T20:01:56.629Z" }, - { url = "https://files.pythonhosted.org/packages/72/4f/732fff31c119bb73b35236dd333030f32c4bfe909f445b423e6c7594f9a2/coverage-7.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:73ab1601f84dc804f7812dc297e93cd99381162da39c47040a827d4e8dafe63b", size = 218575, upload-time = "2025-09-21T20:01:58.203Z" }, - { url = 
"https://files.pythonhosted.org/packages/87/02/ae7e0af4b674be47566707777db1aa375474f02a1d64b9323e5813a6cdd5/coverage-7.10.7-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a8b6f03672aa6734e700bbcd65ff050fd19cddfec4b031cc8cf1c6967de5a68e", size = 249568, upload-time = "2025-09-21T20:01:59.748Z" }, - { url = "https://files.pythonhosted.org/packages/a2/77/8c6d22bf61921a59bce5471c2f1f7ac30cd4ac50aadde72b8c48d5727902/coverage-7.10.7-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10b6ba00ab1132a0ce4428ff68cf50a25efd6840a42cdf4239c9b99aad83be8b", size = 252174, upload-time = "2025-09-21T20:02:01.192Z" }, - { url = "https://files.pythonhosted.org/packages/b1/20/b6ea4f69bbb52dac0aebd62157ba6a9dddbfe664f5af8122dac296c3ee15/coverage-7.10.7-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c79124f70465a150e89340de5963f936ee97097d2ef76c869708c4248c63ca49", size = 253447, upload-time = "2025-09-21T20:02:02.701Z" }, - { url = "https://files.pythonhosted.org/packages/f9/28/4831523ba483a7f90f7b259d2018fef02cb4d5b90bc7c1505d6e5a84883c/coverage-7.10.7-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:69212fbccdbd5b0e39eac4067e20a4a5256609e209547d86f740d68ad4f04911", size = 249779, upload-time = "2025-09-21T20:02:04.185Z" }, - { url = "https://files.pythonhosted.org/packages/a7/9f/4331142bc98c10ca6436d2d620c3e165f31e6c58d43479985afce6f3191c/coverage-7.10.7-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7ea7c6c9d0d286d04ed3541747e6597cbe4971f22648b68248f7ddcd329207f0", size = 251604, upload-time = "2025-09-21T20:02:06.034Z" }, - { url = "https://files.pythonhosted.org/packages/ce/60/bda83b96602036b77ecf34e6393a3836365481b69f7ed7079ab85048202b/coverage-7.10.7-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b9be91986841a75042b3e3243d0b3cb0b2434252b977baaf0cd56e960fe1e46f", size = 249497, upload-time = 
"2025-09-21T20:02:07.619Z" }, - { url = "https://files.pythonhosted.org/packages/5f/af/152633ff35b2af63977edd835d8e6430f0caef27d171edf2fc76c270ef31/coverage-7.10.7-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:b281d5eca50189325cfe1f365fafade89b14b4a78d9b40b05ddd1fc7d2a10a9c", size = 249350, upload-time = "2025-09-21T20:02:10.34Z" }, - { url = "https://files.pythonhosted.org/packages/9d/71/d92105d122bd21cebba877228990e1646d862e34a98bb3374d3fece5a794/coverage-7.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:99e4aa63097ab1118e75a848a28e40d68b08a5e19ce587891ab7fd04475e780f", size = 251111, upload-time = "2025-09-21T20:02:12.122Z" }, - { url = "https://files.pythonhosted.org/packages/a2/9e/9fdb08f4bf476c912f0c3ca292e019aab6712c93c9344a1653986c3fd305/coverage-7.10.7-cp313-cp313-win32.whl", hash = "sha256:dc7c389dce432500273eaf48f410b37886be9208b2dd5710aaf7c57fd442c698", size = 220746, upload-time = "2025-09-21T20:02:13.919Z" }, - { url = "https://files.pythonhosted.org/packages/b1/b1/a75fd25df44eab52d1931e89980d1ada46824c7a3210be0d3c88a44aaa99/coverage-7.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:cac0fdca17b036af3881a9d2729a850b76553f3f716ccb0360ad4dbc06b3b843", size = 221541, upload-time = "2025-09-21T20:02:15.57Z" }, - { url = "https://files.pythonhosted.org/packages/14/3a/d720d7c989562a6e9a14b2c9f5f2876bdb38e9367126d118495b89c99c37/coverage-7.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:4b6f236edf6e2f9ae8fcd1332da4e791c1b6ba0dc16a2dc94590ceccb482e546", size = 220170, upload-time = "2025-09-21T20:02:17.395Z" }, - { url = "https://files.pythonhosted.org/packages/bb/22/e04514bf2a735d8b0add31d2b4ab636fc02370730787c576bb995390d2d5/coverage-7.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a0ec07fd264d0745ee396b666d47cef20875f4ff2375d7c4f58235886cc1ef0c", size = 219029, upload-time = "2025-09-21T20:02:18.936Z" }, - { url = 
"https://files.pythonhosted.org/packages/11/0b/91128e099035ece15da3445d9015e4b4153a6059403452d324cbb0a575fa/coverage-7.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd5e856ebb7bfb7672b0086846db5afb4567a7b9714b8a0ebafd211ec7ce6a15", size = 219259, upload-time = "2025-09-21T20:02:20.44Z" }, - { url = "https://files.pythonhosted.org/packages/8b/51/66420081e72801536a091a0c8f8c1f88a5c4bf7b9b1bdc6222c7afe6dc9b/coverage-7.10.7-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f57b2a3c8353d3e04acf75b3fed57ba41f5c0646bbf1d10c7c282291c97936b4", size = 260592, upload-time = "2025-09-21T20:02:22.313Z" }, - { url = "https://files.pythonhosted.org/packages/5d/22/9b8d458c2881b22df3db5bb3e7369e63d527d986decb6c11a591ba2364f7/coverage-7.10.7-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:1ef2319dd15a0b009667301a3f84452a4dc6fddfd06b0c5c53ea472d3989fbf0", size = 262768, upload-time = "2025-09-21T20:02:24.287Z" }, - { url = "https://files.pythonhosted.org/packages/f7/08/16bee2c433e60913c610ea200b276e8eeef084b0d200bdcff69920bd5828/coverage-7.10.7-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:83082a57783239717ceb0ad584de3c69cf581b2a95ed6bf81ea66034f00401c0", size = 264995, upload-time = "2025-09-21T20:02:26.133Z" }, - { url = "https://files.pythonhosted.org/packages/20/9d/e53eb9771d154859b084b90201e5221bca7674ba449a17c101a5031d4054/coverage-7.10.7-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:50aa94fb1fb9a397eaa19c0d5ec15a5edd03a47bf1a3a6111a16b36e190cff65", size = 259546, upload-time = "2025-09-21T20:02:27.716Z" }, - { url = "https://files.pythonhosted.org/packages/ad/b0/69bc7050f8d4e56a89fb550a1577d5d0d1db2278106f6f626464067b3817/coverage-7.10.7-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2120043f147bebb41c85b97ac45dd173595ff14f2a584f2963891cbcc3091541", size = 262544, upload-time = 
"2025-09-21T20:02:29.216Z" }, - { url = "https://files.pythonhosted.org/packages/ef/4b/2514b060dbd1bc0aaf23b852c14bb5818f244c664cb16517feff6bb3a5ab/coverage-7.10.7-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2fafd773231dd0378fdba66d339f84904a8e57a262f583530f4f156ab83863e6", size = 260308, upload-time = "2025-09-21T20:02:31.226Z" }, - { url = "https://files.pythonhosted.org/packages/54/78/7ba2175007c246d75e496f64c06e94122bdb914790a1285d627a918bd271/coverage-7.10.7-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:0b944ee8459f515f28b851728ad224fa2d068f1513ef6b7ff1efafeb2185f999", size = 258920, upload-time = "2025-09-21T20:02:32.823Z" }, - { url = "https://files.pythonhosted.org/packages/c0/b3/fac9f7abbc841409b9a410309d73bfa6cfb2e51c3fada738cb607ce174f8/coverage-7.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:4b583b97ab2e3efe1b3e75248a9b333bd3f8b0b1b8e5b45578e05e5850dfb2c2", size = 261434, upload-time = "2025-09-21T20:02:34.86Z" }, - { url = "https://files.pythonhosted.org/packages/ee/51/a03bec00d37faaa891b3ff7387192cef20f01604e5283a5fabc95346befa/coverage-7.10.7-cp313-cp313t-win32.whl", hash = "sha256:2a78cd46550081a7909b3329e2266204d584866e8d97b898cd7fb5ac8d888b1a", size = 221403, upload-time = "2025-09-21T20:02:37.034Z" }, - { url = "https://files.pythonhosted.org/packages/53/22/3cf25d614e64bf6d8e59c7c669b20d6d940bb337bdee5900b9ca41c820bb/coverage-7.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:33a5e6396ab684cb43dc7befa386258acb2d7fae7f67330ebb85ba4ea27938eb", size = 222469, upload-time = "2025-09-21T20:02:39.011Z" }, - { url = "https://files.pythonhosted.org/packages/49/a1/00164f6d30d8a01c3c9c48418a7a5be394de5349b421b9ee019f380df2a0/coverage-7.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:86b0e7308289ddde73d863b7683f596d8d21c7d8664ce1dee061d0bcf3fbb4bb", size = 220731, upload-time = "2025-09-21T20:02:40.939Z" }, - { url = 
"https://files.pythonhosted.org/packages/ec/16/114df1c291c22cac3b0c127a73e0af5c12ed7bbb6558d310429a0ae24023/coverage-7.10.7-py3-none-any.whl", hash = "sha256:f7941f6f2fe6dd6807a1208737b8a0cbcf1cc6d7b07d24998ad2d63590868260", size = 209952, upload-time = "2025-09-21T20:03:53.918Z" }, +version = "7.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/38/ee22495420457259d2f3390309505ea98f98a5eed40901cf62196abad006/coverage-7.11.0.tar.gz", hash = "sha256:167bd504ac1ca2af7ff3b81d245dfea0292c5032ebef9d66cc08a7d28c1b8050", size = 811905, upload-time = "2025-10-15T15:15:08.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/7f/85e4dfe65e400645464b25c036a26ac226cf3a69d4a50c3934c532491cdd/coverage-7.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cc3f49e65ea6e0d5d9bd60368684fe52a704d46f9e7fc413918f18d046ec40e1", size = 216129, upload-time = "2025-10-15T15:13:25.371Z" }, + { url = "https://files.pythonhosted.org/packages/96/5d/dc5fa98fea3c175caf9d360649cb1aa3715e391ab00dc78c4c66fabd7356/coverage-7.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f39ae2f63f37472c17b4990f794035c9890418b1b8cca75c01193f3c8d3e01be", size = 216380, upload-time = "2025-10-15T15:13:26.976Z" }, + { url = "https://files.pythonhosted.org/packages/b2/f5/3da9cc9596708273385189289c0e4d8197d37a386bdf17619013554b3447/coverage-7.11.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7db53b5cdd2917b6eaadd0b1251cf4e7d96f4a8d24e174bdbdf2f65b5ea7994d", size = 247375, upload-time = "2025-10-15T15:13:28.923Z" }, + { url = "https://files.pythonhosted.org/packages/65/6c/f7f59c342359a235559d2bc76b0c73cfc4bac7d61bb0df210965cb1ecffd/coverage-7.11.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10ad04ac3a122048688387828b4537bc9cf60c0bf4869c1e9989c46e45690b82", size = 249978, upload-time = "2025-10-15T15:13:30.525Z" }, + { url = 
"https://files.pythonhosted.org/packages/e7/8c/042dede2e23525e863bf1ccd2b92689692a148d8b5fd37c37899ba882645/coverage-7.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4036cc9c7983a2b1f2556d574d2eb2154ac6ed55114761685657e38782b23f52", size = 251253, upload-time = "2025-10-15T15:13:32.174Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a9/3c58df67bfa809a7bddd786356d9c5283e45d693edb5f3f55d0986dd905a/coverage-7.11.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7ab934dd13b1c5e94b692b1e01bd87e4488cb746e3a50f798cb9464fd128374b", size = 247591, upload-time = "2025-10-15T15:13:34.147Z" }, + { url = "https://files.pythonhosted.org/packages/26/5b/c7f32efd862ee0477a18c41e4761305de6ddd2d49cdeda0c1116227570fd/coverage-7.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59a6e5a265f7cfc05f76e3bb53eca2e0dfe90f05e07e849930fecd6abb8f40b4", size = 249411, upload-time = "2025-10-15T15:13:38.425Z" }, + { url = "https://files.pythonhosted.org/packages/76/b5/78cb4f1e86c1611431c990423ec0768122905b03837e1b4c6a6f388a858b/coverage-7.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:df01d6c4c81e15a7c88337b795bb7595a8596e92310266b5072c7e301168efbd", size = 247303, upload-time = "2025-10-15T15:13:40.464Z" }, + { url = "https://files.pythonhosted.org/packages/87/c9/23c753a8641a330f45f221286e707c427e46d0ffd1719b080cedc984ec40/coverage-7.11.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:8c934bd088eed6174210942761e38ee81d28c46de0132ebb1801dbe36a390dcc", size = 247157, upload-time = "2025-10-15T15:13:42.087Z" }, + { url = "https://files.pythonhosted.org/packages/c5/42/6e0cc71dc8a464486e944a4fa0d85bdec031cc2969e98ed41532a98336b9/coverage-7.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a03eaf7ec24078ad64a07f02e30060aaf22b91dedf31a6b24d0d98d2bba7f48", size = 248921, upload-time = "2025-10-15T15:13:43.715Z" }, + { url = 
"https://files.pythonhosted.org/packages/e8/1c/743c2ef665e6858cccb0f84377dfe3a4c25add51e8c7ef19249be92465b6/coverage-7.11.0-cp313-cp313-win32.whl", hash = "sha256:695340f698a5f56f795b2836abe6fb576e7c53d48cd155ad2f80fd24bc63a040", size = 218526, upload-time = "2025-10-15T15:13:45.336Z" }, + { url = "https://files.pythonhosted.org/packages/ff/d5/226daadfd1bf8ddbccefbd3aa3547d7b960fb48e1bdac124e2dd13a2b71a/coverage-7.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:2727d47fce3ee2bac648528e41455d1b0c46395a087a229deac75e9f88ba5a05", size = 219317, upload-time = "2025-10-15T15:13:47.401Z" }, + { url = "https://files.pythonhosted.org/packages/97/54/47db81dcbe571a48a298f206183ba8a7ba79200a37cd0d9f4788fcd2af4a/coverage-7.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:0efa742f431529699712b92ecdf22de8ff198df41e43aeaaadf69973eb93f17a", size = 217948, upload-time = "2025-10-15T15:13:49.096Z" }, + { url = "https://files.pythonhosted.org/packages/e5/8b/cb68425420154e7e2a82fd779a8cc01549b6fa83c2ad3679cd6c088ebd07/coverage-7.11.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:587c38849b853b157706407e9ebdca8fd12f45869edb56defbef2daa5fb0812b", size = 216837, upload-time = "2025-10-15T15:13:51.09Z" }, + { url = "https://files.pythonhosted.org/packages/33/55/9d61b5765a025685e14659c8d07037247de6383c0385757544ffe4606475/coverage-7.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b971bdefdd75096163dd4261c74be813c4508477e39ff7b92191dea19f24cd37", size = 217061, upload-time = "2025-10-15T15:13:52.747Z" }, + { url = "https://files.pythonhosted.org/packages/52/85/292459c9186d70dcec6538f06ea251bc968046922497377bf4a1dc9a71de/coverage-7.11.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:269bfe913b7d5be12ab13a95f3a76da23cf147be7fa043933320ba5625f0a8de", size = 258398, upload-time = "2025-10-15T15:13:54.45Z" }, + { url = 
"https://files.pythonhosted.org/packages/1f/e2/46edd73fb8bf51446c41148d81944c54ed224854812b6ca549be25113ee0/coverage-7.11.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:dadbcce51a10c07b7c72b0ce4a25e4b6dcb0c0372846afb8e5b6307a121eb99f", size = 260574, upload-time = "2025-10-15T15:13:56.145Z" }, + { url = "https://files.pythonhosted.org/packages/07/5e/1df469a19007ff82e2ca8fe509822820a31e251f80ee7344c34f6cd2ec43/coverage-7.11.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ed43fa22c6436f7957df036331f8fe4efa7af132054e1844918866cd228af6c", size = 262797, upload-time = "2025-10-15T15:13:58.635Z" }, + { url = "https://files.pythonhosted.org/packages/f9/50/de216b31a1434b94d9b34a964c09943c6be45069ec704bfc379d8d89a649/coverage-7.11.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9516add7256b6713ec08359b7b05aeff8850c98d357784c7205b2e60aa2513fa", size = 257361, upload-time = "2025-10-15T15:14:00.409Z" }, + { url = "https://files.pythonhosted.org/packages/82/1e/3f9f8344a48111e152e0fd495b6fff13cc743e771a6050abf1627a7ba918/coverage-7.11.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb92e47c92fcbcdc692f428da67db33337fa213756f7adb6a011f7b5a7a20740", size = 260349, upload-time = "2025-10-15T15:14:02.188Z" }, + { url = "https://files.pythonhosted.org/packages/65/9b/3f52741f9e7d82124272f3070bbe316006a7de1bad1093f88d59bfc6c548/coverage-7.11.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d06f4fc7acf3cabd6d74941d53329e06bab00a8fe10e4df2714f0b134bfc64ef", size = 258114, upload-time = "2025-10-15T15:14:03.907Z" }, + { url = "https://files.pythonhosted.org/packages/0b/8b/918f0e15f0365d50d3986bbd3338ca01178717ac5678301f3f547b6619e6/coverage-7.11.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:6fbcee1a8f056af07ecd344482f711f563a9eb1c2cad192e87df00338ec3cdb0", size = 256723, upload-time = "2025-10-15T15:14:06.324Z" }, + { url 
= "https://files.pythonhosted.org/packages/44/9e/7776829f82d3cf630878a7965a7d70cc6ca94f22c7d20ec4944f7148cb46/coverage-7.11.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dbbf012be5f32533a490709ad597ad8a8ff80c582a95adc8d62af664e532f9ca", size = 259238, upload-time = "2025-10-15T15:14:08.002Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b8/49cf253e1e7a3bedb85199b201862dd7ca4859f75b6cf25ffa7298aa0760/coverage-7.11.0-cp313-cp313t-win32.whl", hash = "sha256:cee6291bb4fed184f1c2b663606a115c743df98a537c969c3c64b49989da96c2", size = 219180, upload-time = "2025-10-15T15:14:09.786Z" }, + { url = "https://files.pythonhosted.org/packages/ac/e1/1a541703826be7ae2125a0fb7f821af5729d56bb71e946e7b933cc7a89a4/coverage-7.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a386c1061bf98e7ea4758e4313c0ab5ecf57af341ef0f43a0bf26c2477b5c268", size = 220241, upload-time = "2025-10-15T15:14:11.471Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d1/5ee0e0a08621140fd418ec4020f595b4d52d7eb429ae6a0c6542b4ba6f14/coverage-7.11.0-cp313-cp313t-win_arm64.whl", hash = "sha256:f9ea02ef40bb83823b2b04964459d281688fe173e20643870bb5d2edf68bc836", size = 218510, upload-time = "2025-10-15T15:14:13.46Z" }, + { url = "https://files.pythonhosted.org/packages/5f/04/642c1d8a448ae5ea1369eac8495740a79eb4e581a9fb0cbdce56bbf56da1/coverage-7.11.0-py3-none-any.whl", hash = "sha256:4b7589765348d78fb4e5fb6ea35d07564e387da2fc5efff62e0222971f155f68", size = 207761, upload-time = "2025-10-15T15:15:06.439Z" }, ] [[package]] @@ -638,7 +638,7 @@ wheels = [ [[package]] name = "ipykernel" -version = "7.0.0" +version = "7.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "appnope", marker = "sys_platform == 'darwin'" }, @@ -655,9 +655,9 @@ dependencies = [ { name = "tornado" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8a/81/50e62d30cee8e3035bcd46eb1e8768366df135240c0061846c7421ffcd06/ipykernel-7.0.0.tar.gz", hash = 
"sha256:06aef83f27adbce00b23345aa70f749f907dc4ac6f4a41fe7bf5f780dc506225", size = 173513, upload-time = "2025-10-13T11:23:53.977Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/4c/9f0024c8457286c6bfd5405a15d650ec5ea36f420ef9bbc58b301f66cfc5/ipykernel-7.0.1.tar.gz", hash = "sha256:2d3fd7cdef22071c2abbad78f142b743228c5d59cd470d034871ae0ac359533c", size = 171460, upload-time = "2025-10-14T16:17:07.325Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/52/20/99c49012535eec8ec797d049dec4f89c1d49b6793fa2f53e6e918f4d901f/ipykernel-7.0.0-py3-none-any.whl", hash = "sha256:28793cecaa6a669e3be80eb6d24803202388b6a955929b0a4e13404d8c92062b", size = 118921, upload-time = "2025-10-13T11:23:51.747Z" }, + { url = "https://files.pythonhosted.org/packages/b8/f7/761037905ffdec673533bfa43af8d4c31c859c778dfc3bbb71899875ec18/ipykernel-7.0.1-py3-none-any.whl", hash = "sha256:87182a8305e28954b6721087dec45b171712610111d494c17bb607befa1c4000", size = 118157, upload-time = "2025-10-14T16:17:05.606Z" }, ] [[package]] @@ -774,16 +774,15 @@ wheels = [ [[package]] name = "jupyter-core" -version = "5.8.1" +version = "5.9.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "platformdirs" }, - { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/1b/72906d554acfeb588332eaaa6f61577705e9ec752ddb486f302dafa292d9/jupyter_core-5.8.1.tar.gz", hash = "sha256:0a5f9706f70e64786b75acba995988915ebd4601c8a52e534a40b51c95f59941", size = 88923, upload-time = "2025-05-27T07:38:16.655Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/0c/7b01e93e054555cbadf614f2ff10ea77aecbc8867831914d8a2c5868481a/jupyter_core-5.9.0.tar.gz", hash = "sha256:5f8fba10cfc946fe1b4037e986458fc89430397207b21d741dc399d3d42951d4", size = 89804, upload-time = "2025-10-16T12:12:23.851Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880, upload-time = "2025-05-27T07:38:15.137Z" }, + { url = "https://files.pythonhosted.org/packages/96/f2/5efda2a70d98288f4d94baba8489cd782d53772233c77351864bc754a146/jupyter_core-5.9.0-py3-none-any.whl", hash = "sha256:bf13431d292ce34a25568586729a3b9deb07d112289b77350dc4c2340c2f34c1", size = 29024, upload-time = "2025-10-16T12:12:22.19Z" }, ] [[package]] @@ -1027,32 +1026,32 @@ wheels = [ [[package]] name = "numpy" -version = "2.3.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d0/19/95b3d357407220ed24c139018d2518fab0a61a948e68286a25f1a4d049ff/numpy-2.3.3.tar.gz", hash = "sha256:ddc7c39727ba62b80dfdbedf400d1c10ddfa8eefbd7ec8dcb118be8b56d31029", size = 20576648, upload-time = "2025-09-09T16:54:12.543Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/b9/984c2b1ee61a8b803bf63582b4ac4242cf76e2dbd663efeafcb620cc0ccb/numpy-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f5415fb78995644253370985342cd03572ef8620b934da27d77377a2285955bf", size = 20949588, upload-time = "2025-09-09T15:56:59.087Z" }, - { url = "https://files.pythonhosted.org/packages/a6/e4/07970e3bed0b1384d22af1e9912527ecbeb47d3b26e9b6a3bced068b3bea/numpy-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d00de139a3324e26ed5b95870ce63be7ec7352171bc69a4cf1f157a48e3eb6b7", size = 14177802, upload-time = "2025-09-09T15:57:01.73Z" }, - { url = "https://files.pythonhosted.org/packages/35/c7/477a83887f9de61f1203bad89cf208b7c19cc9fef0cebef65d5a1a0619f2/numpy-2.3.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:9dc13c6a5829610cc07422bc74d3ac083bd8323f14e2827d992f9e52e22cd6a6", size = 5106537, upload-time = "2025-09-09T15:57:03.765Z" }, - { url = 
"https://files.pythonhosted.org/packages/52/47/93b953bd5866a6f6986344d045a207d3f1cfbad99db29f534ea9cee5108c/numpy-2.3.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:d79715d95f1894771eb4e60fb23f065663b2298f7d22945d66877aadf33d00c7", size = 6640743, upload-time = "2025-09-09T15:57:07.921Z" }, - { url = "https://files.pythonhosted.org/packages/23/83/377f84aaeb800b64c0ef4de58b08769e782edcefa4fea712910b6f0afd3c/numpy-2.3.3-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:952cfd0748514ea7c3afc729a0fc639e61655ce4c55ab9acfab14bda4f402b4c", size = 14278881, upload-time = "2025-09-09T15:57:11.349Z" }, - { url = "https://files.pythonhosted.org/packages/9a/a5/bf3db6e66c4b160d6ea10b534c381a1955dfab34cb1017ea93aa33c70ed3/numpy-2.3.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5b83648633d46f77039c29078751f80da65aa64d5622a3cd62aaef9d835b6c93", size = 16636301, upload-time = "2025-09-09T15:57:14.245Z" }, - { url = "https://files.pythonhosted.org/packages/a2/59/1287924242eb4fa3f9b3a2c30400f2e17eb2707020d1c5e3086fe7330717/numpy-2.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b001bae8cea1c7dfdb2ae2b017ed0a6f2102d7a70059df1e338e307a4c78a8ae", size = 16053645, upload-time = "2025-09-09T15:57:16.534Z" }, - { url = "https://files.pythonhosted.org/packages/e6/93/b3d47ed882027c35e94ac2320c37e452a549f582a5e801f2d34b56973c97/numpy-2.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8e9aced64054739037d42fb84c54dd38b81ee238816c948c8f3ed134665dcd86", size = 18578179, upload-time = "2025-09-09T15:57:18.883Z" }, - { url = "https://files.pythonhosted.org/packages/20/d9/487a2bccbf7cc9d4bfc5f0f197761a5ef27ba870f1e3bbb9afc4bbe3fcc2/numpy-2.3.3-cp313-cp313-win32.whl", hash = "sha256:9591e1221db3f37751e6442850429b3aabf7026d3b05542d102944ca7f00c8a8", size = 6312250, upload-time = "2025-09-09T15:57:21.296Z" }, - { url = 
"https://files.pythonhosted.org/packages/1b/b5/263ebbbbcede85028f30047eab3d58028d7ebe389d6493fc95ae66c636ab/numpy-2.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:f0dadeb302887f07431910f67a14d57209ed91130be0adea2f9793f1a4f817cf", size = 12783269, upload-time = "2025-09-09T15:57:23.034Z" }, - { url = "https://files.pythonhosted.org/packages/fa/75/67b8ca554bbeaaeb3fac2e8bce46967a5a06544c9108ec0cf5cece559b6c/numpy-2.3.3-cp313-cp313-win_arm64.whl", hash = "sha256:3c7cf302ac6e0b76a64c4aecf1a09e51abd9b01fc7feee80f6c43e3ab1b1dbc5", size = 10195314, upload-time = "2025-09-09T15:57:25.045Z" }, - { url = "https://files.pythonhosted.org/packages/11/d0/0d1ddec56b162042ddfafeeb293bac672de9b0cfd688383590090963720a/numpy-2.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:eda59e44957d272846bb407aad19f89dc6f58fecf3504bd144f4c5cf81a7eacc", size = 21048025, upload-time = "2025-09-09T15:57:27.257Z" }, - { url = "https://files.pythonhosted.org/packages/36/9e/1996ca6b6d00415b6acbdd3c42f7f03ea256e2c3f158f80bd7436a8a19f3/numpy-2.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:823d04112bc85ef5c4fda73ba24e6096c8f869931405a80aa8b0e604510a26bc", size = 14301053, upload-time = "2025-09-09T15:57:30.077Z" }, - { url = "https://files.pythonhosted.org/packages/05/24/43da09aa764c68694b76e84b3d3f0c44cb7c18cdc1ba80e48b0ac1d2cd39/numpy-2.3.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:40051003e03db4041aa325da2a0971ba41cf65714e65d296397cc0e32de6018b", size = 5229444, upload-time = "2025-09-09T15:57:32.733Z" }, - { url = "https://files.pythonhosted.org/packages/bc/14/50ffb0f22f7218ef8af28dd089f79f68289a7a05a208db9a2c5dcbe123c1/numpy-2.3.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:6ee9086235dd6ab7ae75aba5662f582a81ced49f0f1c6de4260a78d8f2d91a19", size = 6738039, upload-time = "2025-09-09T15:57:34.328Z" }, - { url = 
"https://files.pythonhosted.org/packages/55/52/af46ac0795e09657d45a7f4db961917314377edecf66db0e39fa7ab5c3d3/numpy-2.3.3-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:94fcaa68757c3e2e668ddadeaa86ab05499a70725811e582b6a9858dd472fb30", size = 14352314, upload-time = "2025-09-09T15:57:36.255Z" }, - { url = "https://files.pythonhosted.org/packages/a7/b1/dc226b4c90eb9f07a3fff95c2f0db3268e2e54e5cce97c4ac91518aee71b/numpy-2.3.3-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da1a74b90e7483d6ce5244053399a614b1d6b7bc30a60d2f570e5071f8959d3e", size = 16701722, upload-time = "2025-09-09T15:57:38.622Z" }, - { url = "https://files.pythonhosted.org/packages/9d/9d/9d8d358f2eb5eced14dba99f110d83b5cd9a4460895230f3b396ad19a323/numpy-2.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2990adf06d1ecee3b3dcbb4977dfab6e9f09807598d647f04d385d29e7a3c3d3", size = 16132755, upload-time = "2025-09-09T15:57:41.16Z" }, - { url = "https://files.pythonhosted.org/packages/b6/27/b3922660c45513f9377b3fb42240bec63f203c71416093476ec9aa0719dc/numpy-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ed635ff692483b8e3f0fcaa8e7eb8a75ee71aa6d975388224f70821421800cea", size = 18651560, upload-time = "2025-09-09T15:57:43.459Z" }, - { url = "https://files.pythonhosted.org/packages/5b/8e/3ab61a730bdbbc201bb245a71102aa609f0008b9ed15255500a99cd7f780/numpy-2.3.3-cp313-cp313t-win32.whl", hash = "sha256:a333b4ed33d8dc2b373cc955ca57babc00cd6f9009991d9edc5ddbc1bac36bcd", size = 6442776, upload-time = "2025-09-09T15:57:45.793Z" }, - { url = "https://files.pythonhosted.org/packages/1c/3a/e22b766b11f6030dc2decdeff5c2fb1610768055603f9f3be88b6d192fb2/numpy-2.3.3-cp313-cp313t-win_amd64.whl", hash = "sha256:4384a169c4d8f97195980815d6fcad04933a7e1ab3b530921c3fef7a1c63426d", size = 12927281, upload-time = "2025-09-09T15:57:47.492Z" }, - { url = 
"https://files.pythonhosted.org/packages/7b/42/c2e2bc48c5e9b2a83423f99733950fbefd86f165b468a3d85d52b30bf782/numpy-2.3.3-cp313-cp313t-win_arm64.whl", hash = "sha256:75370986cc0bc66f4ce5110ad35aae6d182cc4ce6433c40ad151f53690130bf1", size = 10265275, upload-time = "2025-09-09T15:57:49.647Z" }, +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/f4/098d2270d52b41f1bd7db9fc288aaa0400cb48c2a3e2af6fa365d9720947/numpy-2.3.4.tar.gz", hash = "sha256:a7d018bfedb375a8d979ac758b120ba846a7fe764911a64465fd87b8729f4a6a", size = 20582187, upload-time = "2025-10-15T16:18:11.77Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/7e/b72610cc91edf138bc588df5150957a4937221ca6058b825b4725c27be62/numpy-2.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c090d4860032b857d94144d1a9976b8e36709e40386db289aaf6672de2a81966", size = 20950335, upload-time = "2025-10-15T16:16:10.304Z" }, + { url = "https://files.pythonhosted.org/packages/3e/46/bdd3370dcea2f95ef14af79dbf81e6927102ddf1cc54adc0024d61252fd9/numpy-2.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a13fc473b6db0be619e45f11f9e81260f7302f8d180c49a22b6e6120022596b3", size = 14179878, upload-time = "2025-10-15T16:16:12.595Z" }, + { url = "https://files.pythonhosted.org/packages/ac/01/5a67cb785bda60f45415d09c2bc245433f1c68dd82eef9c9002c508b5a65/numpy-2.3.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:3634093d0b428e6c32c3a69b78e554f0cd20ee420dcad5a9f3b2a63762ce4197", size = 5108673, upload-time = "2025-10-15T16:16:14.877Z" }, + { url = "https://files.pythonhosted.org/packages/c2/cd/8428e23a9fcebd33988f4cb61208fda832800ca03781f471f3727a820704/numpy-2.3.4-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:043885b4f7e6e232d7df4f51ffdef8c36320ee9d5f227b380ea636722c7ed12e", size = 6641438, upload-time = "2025-10-15T16:16:16.805Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/d1/913fe563820f3c6b079f992458f7331278dcd7ba8427e8e745af37ddb44f/numpy-2.3.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4ee6a571d1e4f0ea6d5f22d6e5fbd6ed1dc2b18542848e1e7301bd190500c9d7", size = 14281290, upload-time = "2025-10-15T16:16:18.764Z" }, + { url = "https://files.pythonhosted.org/packages/9e/7e/7d306ff7cb143e6d975cfa7eb98a93e73495c4deabb7d1b5ecf09ea0fd69/numpy-2.3.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fc8a63918b04b8571789688b2780ab2b4a33ab44bfe8ccea36d3eba51228c953", size = 16636543, upload-time = "2025-10-15T16:16:21.072Z" }, + { url = "https://files.pythonhosted.org/packages/47/6a/8cfc486237e56ccfb0db234945552a557ca266f022d281a2f577b98e955c/numpy-2.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:40cc556d5abbc54aabe2b1ae287042d7bdb80c08edede19f0c0afb36ae586f37", size = 16056117, upload-time = "2025-10-15T16:16:23.369Z" }, + { url = "https://files.pythonhosted.org/packages/b1/0e/42cb5e69ea901e06ce24bfcc4b5664a56f950a70efdcf221f30d9615f3f3/numpy-2.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ecb63014bb7f4ce653f8be7f1df8cbc6093a5a2811211770f6606cc92b5a78fd", size = 18577788, upload-time = "2025-10-15T16:16:27.496Z" }, + { url = "https://files.pythonhosted.org/packages/86/92/41c3d5157d3177559ef0a35da50f0cda7fa071f4ba2306dd36818591a5bc/numpy-2.3.4-cp313-cp313-win32.whl", hash = "sha256:e8370eb6925bb8c1c4264fec52b0384b44f675f191df91cbe0140ec9f0955646", size = 6282620, upload-time = "2025-10-15T16:16:29.811Z" }, + { url = "https://files.pythonhosted.org/packages/09/97/fd421e8bc50766665ad35536c2bb4ef916533ba1fdd053a62d96cc7c8b95/numpy-2.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:56209416e81a7893036eea03abcb91c130643eb14233b2515c90dcac963fe99d", size = 12784672, upload-time = "2025-10-15T16:16:31.589Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/df/5474fb2f74970ca8eb978093969b125a84cc3d30e47f82191f981f13a8a0/numpy-2.3.4-cp313-cp313-win_arm64.whl", hash = "sha256:a700a4031bc0fd6936e78a752eefb79092cecad2599ea9c8039c548bc097f9bc", size = 10196702, upload-time = "2025-10-15T16:16:33.902Z" }, + { url = "https://files.pythonhosted.org/packages/11/83/66ac031464ec1767ea3ed48ce40f615eb441072945e98693bec0bcd056cc/numpy-2.3.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:86966db35c4040fdca64f0816a1c1dd8dbd027d90fca5a57e00e1ca4cd41b879", size = 21049003, upload-time = "2025-10-15T16:16:36.101Z" }, + { url = "https://files.pythonhosted.org/packages/5f/99/5b14e0e686e61371659a1d5bebd04596b1d72227ce36eed121bb0aeab798/numpy-2.3.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:838f045478638b26c375ee96ea89464d38428c69170360b23a1a50fa4baa3562", size = 14302980, upload-time = "2025-10-15T16:16:39.124Z" }, + { url = "https://files.pythonhosted.org/packages/2c/44/e9486649cd087d9fc6920e3fc3ac2aba10838d10804b1e179fb7cbc4e634/numpy-2.3.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d7315ed1dab0286adca467377c8381cd748f3dc92235f22a7dfc42745644a96a", size = 5231472, upload-time = "2025-10-15T16:16:41.168Z" }, + { url = "https://files.pythonhosted.org/packages/3e/51/902b24fa8887e5fe2063fd61b1895a476d0bbf46811ab0c7fdf4bd127345/numpy-2.3.4-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:84f01a4d18b2cc4ade1814a08e5f3c907b079c847051d720fad15ce37aa930b6", size = 6739342, upload-time = "2025-10-15T16:16:43.777Z" }, + { url = "https://files.pythonhosted.org/packages/34/f1/4de9586d05b1962acdcdb1dc4af6646361a643f8c864cef7c852bf509740/numpy-2.3.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:817e719a868f0dacde4abdfc5c1910b301877970195db9ab6a5e2c4bd5b121f7", size = 14354338, upload-time = "2025-10-15T16:16:46.081Z" }, + { url = 
"https://files.pythonhosted.org/packages/1f/06/1c16103b425de7969d5a76bdf5ada0804b476fed05d5f9e17b777f1cbefd/numpy-2.3.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85e071da78d92a214212cacea81c6da557cab307f2c34b5f85b628e94803f9c0", size = 16702392, upload-time = "2025-10-15T16:16:48.455Z" }, + { url = "https://files.pythonhosted.org/packages/34/b2/65f4dc1b89b5322093572b6e55161bb42e3e0487067af73627f795cc9d47/numpy-2.3.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2ec646892819370cf3558f518797f16597b4e4669894a2ba712caccc9da53f1f", size = 16134998, upload-time = "2025-10-15T16:16:51.114Z" }, + { url = "https://files.pythonhosted.org/packages/d4/11/94ec578896cdb973aaf56425d6c7f2aff4186a5c00fac15ff2ec46998b46/numpy-2.3.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:035796aaaddfe2f9664b9a9372f089cfc88bd795a67bd1bfe15e6e770934cf64", size = 18651574, upload-time = "2025-10-15T16:16:53.429Z" }, + { url = "https://files.pythonhosted.org/packages/62/b7/7efa763ab33dbccf56dade36938a77345ce8e8192d6b39e470ca25ff3cd0/numpy-2.3.4-cp313-cp313t-win32.whl", hash = "sha256:fea80f4f4cf83b54c3a051f2f727870ee51e22f0248d3114b8e755d160b38cfb", size = 6413135, upload-time = "2025-10-15T16:16:55.992Z" }, + { url = "https://files.pythonhosted.org/packages/43/70/aba4c38e8400abcc2f345e13d972fb36c26409b3e644366db7649015f291/numpy-2.3.4-cp313-cp313t-win_amd64.whl", hash = "sha256:15eea9f306b98e0be91eb344a94c0e630689ef302e10c2ce5f7e11905c704f9c", size = 12928582, upload-time = "2025-10-15T16:16:57.943Z" }, + { url = "https://files.pythonhosted.org/packages/67/63/871fad5f0073fc00fbbdd7232962ea1ac40eeaae2bba66c76214f7954236/numpy-2.3.4-cp313-cp313t-win_arm64.whl", hash = "sha256:b6c231c9c2fadbae4011ca5e7e83e12dc4a5072f1a1d85a0a7b3ed754d145a40", size = 10266691, upload-time = "2025-10-15T16:17:00.048Z" }, ] [[package]] @@ -1358,16 +1357,17 @@ wheels = [ [[package]] name = "protobuf" -version = "6.32.1" +version = "6.33.0" source = { registry = 
"https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fa/a4/cc17347aa2897568beece2e674674359f911d6fe21b0b8d6268cd42727ac/protobuf-6.32.1.tar.gz", hash = "sha256:ee2469e4a021474ab9baafea6cd070e5bf27c7d29433504ddea1a4ee5850f68d", size = 440635, upload-time = "2025-09-11T21:38:42.935Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/ff/64a6c8f420818bb873713988ca5492cba3a7946be57e027ac63495157d97/protobuf-6.33.0.tar.gz", hash = "sha256:140303d5c8d2037730c548f8c7b93b20bb1dc301be280c378b82b8894589c954", size = 443463, upload-time = "2025-10-15T20:39:52.159Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/98/645183ea03ab3995d29086b8bf4f7562ebd3d10c9a4b14ee3f20d47cfe50/protobuf-6.32.1-cp310-abi3-win32.whl", hash = "sha256:a8a32a84bc9f2aad712041b8b366190f71dde248926da517bde9e832e4412085", size = 424411, upload-time = "2025-09-11T21:38:27.427Z" }, - { url = "https://files.pythonhosted.org/packages/8c/f3/6f58f841f6ebafe076cebeae33fc336e900619d34b1c93e4b5c97a81fdfa/protobuf-6.32.1-cp310-abi3-win_amd64.whl", hash = "sha256:b00a7d8c25fa471f16bc8153d0e53d6c9e827f0953f3c09aaa4331c718cae5e1", size = 435738, upload-time = "2025-09-11T21:38:30.959Z" }, - { url = "https://files.pythonhosted.org/packages/10/56/a8a3f4e7190837139e68c7002ec749190a163af3e330f65d90309145a210/protobuf-6.32.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d8c7e6eb619ffdf105ee4ab76af5a68b60a9d0f66da3ea12d1640e6d8dab7281", size = 426454, upload-time = "2025-09-11T21:38:34.076Z" }, - { url = "https://files.pythonhosted.org/packages/3f/be/8dd0a927c559b37d7a6c8ab79034fd167dcc1f851595f2e641ad62be8643/protobuf-6.32.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:2f5b80a49e1eb7b86d85fcd23fe92df154b9730a725c3b38c4e43b9d77018bf4", size = 322874, upload-time = "2025-09-11T21:38:35.509Z" }, - { url = 
"https://files.pythonhosted.org/packages/5c/f6/88d77011b605ef979aace37b7703e4eefad066f7e84d935e5a696515c2dd/protobuf-6.32.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:b1864818300c297265c83a4982fd3169f97122c299f56a56e2445c3698d34710", size = 322013, upload-time = "2025-09-11T21:38:37.017Z" }, - { url = "https://files.pythonhosted.org/packages/97/b7/15cc7d93443d6c6a84626ae3258a91f4c6ac8c0edd5df35ea7658f71b79c/protobuf-6.32.1-py3-none-any.whl", hash = "sha256:2601b779fc7d32a866c6b4404f9d42a3f67c5b9f3f15b4db3cccabe06b95c346", size = 169289, upload-time = "2025-09-11T21:38:41.234Z" }, + { url = "https://files.pythonhosted.org/packages/7e/ee/52b3fa8feb6db4a833dfea4943e175ce645144532e8a90f72571ad85df4e/protobuf-6.33.0-cp310-abi3-win32.whl", hash = "sha256:d6101ded078042a8f17959eccd9236fb7a9ca20d3b0098bbcb91533a5680d035", size = 425593, upload-time = "2025-10-15T20:39:40.29Z" }, + { url = "https://files.pythonhosted.org/packages/7b/c6/7a465f1825872c55e0341ff4a80198743f73b69ce5d43ab18043699d1d81/protobuf-6.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:9a031d10f703f03768f2743a1c403af050b6ae1f3480e9c140f39c45f81b13ee", size = 436882, upload-time = "2025-10-15T20:39:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/e1/a9/b6eee662a6951b9c3640e8e452ab3e09f117d99fc10baa32d1581a0d4099/protobuf-6.33.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:905b07a65f1a4b72412314082c7dbfae91a9e8b68a0cc1577515f8df58ecf455", size = 427521, upload-time = "2025-10-15T20:39:43.803Z" }, + { url = "https://files.pythonhosted.org/packages/10/35/16d31e0f92c6d2f0e77c2a3ba93185130ea13053dd16200a57434c882f2b/protobuf-6.33.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e0697ece353e6239b90ee43a9231318302ad8353c70e6e45499fa52396debf90", size = 324445, upload-time = "2025-10-15T20:39:44.932Z" }, + { url = "https://files.pythonhosted.org/packages/e6/eb/2a981a13e35cda8b75b5585aaffae2eb904f8f351bdd3870769692acbd8a/protobuf-6.33.0-cp39-abi3-manylinux2014_s390x.whl", hash = 
"sha256:e0a1715e4f27355afd9570f3ea369735afc853a6c3951a6afe1f80d8569ad298", size = 339159, upload-time = "2025-10-15T20:39:46.186Z" }, + { url = "https://files.pythonhosted.org/packages/21/51/0b1cbad62074439b867b4e04cc09b93f6699d78fd191bed2bbb44562e077/protobuf-6.33.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:35be49fd3f4fefa4e6e2aacc35e8b837d6703c37a2168a55ac21e9b1bc7559ef", size = 323172, upload-time = "2025-10-15T20:39:47.465Z" }, + { url = "https://files.pythonhosted.org/packages/07/d1/0a28c21707807c6aacd5dc9c3704b2aa1effbf37adebd8caeaf68b17a636/protobuf-6.33.0-py3-none-any.whl", hash = "sha256:25c9e1963c6734448ea2d308cfa610e692b801304ba0908d7bfa564ac5132995", size = 170477, upload-time = "2025-10-15T20:39:51.311Z" }, ] [[package]] @@ -1437,7 +1437,7 @@ wheels = [ [[package]] name = "pydantic" -version = "2.12.1" +version = "2.12.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-types" }, @@ -1445,39 +1445,39 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3c/a7/d0d7b3c128948ece6676a6a21b9036e3ca53765d35052dbcc8c303886a44/pydantic-2.12.1.tar.gz", hash = "sha256:0af849d00e1879199babd468ec9db13b956f6608e9250500c1a9d69b6a62824e", size = 815997, upload-time = "2025-10-13T21:00:41.219Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8d/35/d319ed522433215526689bad428a94058b6dd12190ce7ddd78618ac14b28/pydantic-2.12.2.tar.gz", hash = "sha256:7b8fa15b831a4bbde9d5b84028641ac3080a4ca2cbd4a621a661687e741624fd", size = 816358, upload-time = "2025-10-14T15:02:21.842Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f5/69/ce4e60e5e67aa0c339a5dc3391a02b4036545efb6308c54dc4aa9425386f/pydantic-2.12.1-py3-none-any.whl", hash = "sha256:665931f5b4ab40c411439e66f99060d631d1acc58c3d481957b9123343d674d1", size = 460511, upload-time = "2025-10-13T21:00:38.935Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/98/468cb649f208a6f1279448e6e5247b37ae79cf5e4041186f1e2ef3d16345/pydantic-2.12.2-py3-none-any.whl", hash = "sha256:25ff718ee909acd82f1ff9b1a4acfd781bb23ab3739adaa7144f19a6a4e231ae", size = 460628, upload-time = "2025-10-14T15:02:19.623Z" }, ] [[package]] name = "pydantic-core" -version = "2.41.3" +version = "2.41.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/00/e9/3916abb671bffb00845408c604ff03480dc8dc273310d8268547a37be0fb/pydantic_core-2.41.3.tar.gz", hash = "sha256:cdebb34b36ad05e8d77b4e797ad38a2a775c2a07a8fa386d4f6943b7778dcd39", size = 457489, upload-time = "2025-10-13T19:34:51.666Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9f/a6/7533cba20b8b66e209d8d2acbb9ccc0bc1b883b0654776d676e02696ef5d/pydantic_core-2.41.3-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:a8596700fdd3ee12b0d9c1f2395f4c32557e7ebfbfacdc08055b0bcbe7d2827e", size = 2105686, upload-time = "2025-10-13T19:31:57.675Z" }, - { url = "https://files.pythonhosted.org/packages/84/d7/2d15cb9dfb9f94422fb4a8820cbfeb397e3823087c2361ef46df5c172000/pydantic_core-2.41.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:624503f918e472c0eed6935020c01b6a6b4bcdb7955a848da5c8805d40f15c0f", size = 1910554, upload-time = "2025-10-13T19:32:00.037Z" }, - { url = "https://files.pythonhosted.org/packages/4c/fc/cbd1caa19e88fd64df716a37b49e5864c1ac27dbb9eb870b8977a584fa42/pydantic_core-2.41.3-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:36388958d0c614df9f5de1a5f88f4b79359016b9ecdfc352037788a628616aa2", size = 1957559, upload-time = "2025-10-13T19:32:02.603Z" }, - { url = "https://files.pythonhosted.org/packages/3b/fe/da942ae51f602173556c627304dc24b9fa8bd04423bce189bf397ba0419e/pydantic_core-2.41.3-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:3c50eba144add9104cf43ef9a3d81c37ebf48bfd0924b584b78ec2e03ec91daf", size = 2051084, upload-time = "2025-10-13T19:32:05.056Z" }, - { url = "https://files.pythonhosted.org/packages/c8/62/0abd59a7107d1ef502b9cfab68145c6bb87115c2d9e883afbf18b98fe6db/pydantic_core-2.41.3-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c6ea2102958eb5ad560d570c49996e215a6939d9bffd0e9fd3b9e808a55008cc", size = 2218098, upload-time = "2025-10-13T19:32:06.837Z" }, - { url = "https://files.pythonhosted.org/packages/72/b1/93a36aa119b70126f3f0d06b6f9a81ca864115962669d8a85deb39c82ecc/pydantic_core-2.41.3-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd0d26f1e4335d5f84abfc880da0afa080c8222410482f9ee12043bb05f55ec8", size = 2341954, upload-time = "2025-10-13T19:32:08.583Z" }, - { url = "https://files.pythonhosted.org/packages/0f/be/7c2563b53b71ff3e41950b0ffa9eeba3d702091c6d59036fff8a39050528/pydantic_core-2.41.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41c38700094045b12c0cff35c8585954de66cf6dd63909fed1c2e6b8f38e1e1e", size = 2069474, upload-time = "2025-10-13T19:32:10.808Z" }, - { url = "https://files.pythonhosted.org/packages/ba/ac/2394004db9f6e03712c1e52f40f0979750fa87721f6baf5f76ad92b8be46/pydantic_core-2.41.3-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4061cc82d7177417fdb90e23e67b27425ecde2652cfd2053b5b4661a489ddc19", size = 2190633, upload-time = "2025-10-13T19:32:12.731Z" }, - { url = "https://files.pythonhosted.org/packages/7d/31/7b70c2d1fe41f450f8022f5523edaaea19c17a2d321fab03efd03aea1fe8/pydantic_core-2.41.3-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:b1d9699a4dae10a7719951cca1e30b591ef1dd9cdda9fec39282a283576c0241", size = 2137097, upload-time = "2025-10-13T19:32:14.634Z" }, - { url = "https://files.pythonhosted.org/packages/4e/ae/f872198cffc8564f52c4ef83bcd3e324e5ac914e168c6b812f5ce3f80aab/pydantic_core-2.41.3-cp313-cp313-musllinux_1_1_armv7l.whl", hash 
= "sha256:d5099f1b97e79f0e45cb6a236a5bd1a20078ed50b1b28f3d17f6c83ff3585baa", size = 2316771, upload-time = "2025-10-13T19:32:16.586Z" }, - { url = "https://files.pythonhosted.org/packages/23/50/f0fce3a9a7554ced178d943e1eada58b15fca896e9eb75d50244fc12007c/pydantic_core-2.41.3-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:b5ff0467a8c1b6abb0ab9c9ea80e2e3a9788592e44c726c2db33fdaf1b5e7d0b", size = 2319449, upload-time = "2025-10-13T19:32:18.503Z" }, - { url = "https://files.pythonhosted.org/packages/15/1f/86a6948408e8388604c02ffde651a2e39b711bd1ab6eeaff376094553a10/pydantic_core-2.41.3-cp313-cp313-win32.whl", hash = "sha256:edfe9b4cee4a91da7247c25732f24504071f3e101c050694d18194b7d2d320bf", size = 1995352, upload-time = "2025-10-13T19:32:20.5Z" }, - { url = "https://files.pythonhosted.org/packages/1f/4b/6dac37c3f62684dc459a31623d8ae97ee433fd68bb827e5c64dd831a5087/pydantic_core-2.41.3-cp313-cp313-win_amd64.whl", hash = "sha256:44af3276c0c2c14efde6590523e4d7e04bcd0e46e0134f0dbef1be0b64b2d3e3", size = 2031894, upload-time = "2025-10-13T19:32:23.11Z" }, - { url = "https://files.pythonhosted.org/packages/fd/75/3d9ba041a3fcb147279fbb37d2468efe62606809fec97b8de78174335ef4/pydantic_core-2.41.3-cp313-cp313-win_arm64.whl", hash = "sha256:59aeed341f92440d51fdcc82c8e930cfb234f1843ed1d4ae1074f5fb9789a64b", size = 1974036, upload-time = "2025-10-13T19:32:25.219Z" }, - { url = "https://files.pythonhosted.org/packages/50/68/45842628ccdb384df029f884ef915306d195c4f08b66ca4d99867edc6338/pydantic_core-2.41.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ef37228238b3a280170ac43a010835c4a7005742bc8831c2c1a9560de4595dbe", size = 1876856, upload-time = "2025-10-13T19:32:27.504Z" }, - { url = "https://files.pythonhosted.org/packages/99/73/336a82910c6a482a0ba9a255c08dcc456ebca9735df96d7a82dffe17626a/pydantic_core-2.41.3-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5cb19f36253152c509abe76c1d1b185436e0c75f392a82934fe37f4a1264449", size = 1884665, 
upload-time = "2025-10-13T19:32:29.567Z" }, - { url = "https://files.pythonhosted.org/packages/34/87/ec610a7849561e0ef7c25b74ef934d154454c3aac8fb595b899557f3c6ab/pydantic_core-2.41.3-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:91be4756e05367ce19a70e1db3b77f01f9e40ca70d26fb4cdfa993e53a08964a", size = 2043067, upload-time = "2025-10-13T19:32:31.506Z" }, - { url = "https://files.pythonhosted.org/packages/db/b4/5f2b0cf78752f9111177423bd5f2bc0815129e587c13401636b8900a417e/pydantic_core-2.41.3-cp313-cp313t-win_amd64.whl", hash = "sha256:ce7d8f4353f82259b55055bd162bbaf599f6c40cd0c098e989eeb95f9fdc022f", size = 1996799, upload-time = "2025-10-13T19:32:33.612Z" }, - { url = "https://files.pythonhosted.org/packages/49/7f/07e7f19a6a44a52abd48846e348e11fa1b3de5ed7c0231d53f055ffb365f/pydantic_core-2.41.3-cp313-cp313t-win_arm64.whl", hash = "sha256:f06a9e81da60e5a0ef584f6f4790f925c203880ae391bf363d97126fd1790b21", size = 1969574, upload-time = "2025-10-13T19:32:35.533Z" }, +sdist = { url = "https://files.pythonhosted.org/packages/df/18/d0944e8eaaa3efd0a91b0f1fc537d3be55ad35091b6a87638211ba691964/pydantic_core-2.41.4.tar.gz", hash = "sha256:70e47929a9d4a1905a67e4b687d5946026390568a8e952b92824118063cee4d5", size = 457557, upload-time = "2025-10-14T10:23:47.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/13/d0/c20adabd181a029a970738dfe23710b52a31f1258f591874fcdec7359845/pydantic_core-2.41.4-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:85e050ad9e5f6fe1004eec65c914332e52f429bc0ae12d6fa2092407a462c746", size = 2105688, upload-time = "2025-10-14T10:20:54.448Z" }, + { url = "https://files.pythonhosted.org/packages/00/b6/0ce5c03cec5ae94cca220dfecddc453c077d71363b98a4bbdb3c0b22c783/pydantic_core-2.41.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7393f1d64792763a48924ba31d1e44c2cfbc05e3b1c2c9abb4ceeadd912cced", size = 1910807, upload-time = "2025-10-14T10:20:56.115Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/3e/800d3d02c8beb0b5c069c870cbb83799d085debf43499c897bb4b4aaff0d/pydantic_core-2.41.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94dab0940b0d1fb28bcab847adf887c66a27a40291eedf0b473be58761c9799a", size = 1956669, upload-time = "2025-10-14T10:20:57.874Z" }, + { url = "https://files.pythonhosted.org/packages/60/a4/24271cc71a17f64589be49ab8bd0751f6a0a03046c690df60989f2f95c2c/pydantic_core-2.41.4-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:de7c42f897e689ee6f9e93c4bec72b99ae3b32a2ade1c7e4798e690ff5246e02", size = 2051629, upload-time = "2025-10-14T10:21:00.006Z" }, + { url = "https://files.pythonhosted.org/packages/68/de/45af3ca2f175d91b96bfb62e1f2d2f1f9f3b14a734afe0bfeff079f78181/pydantic_core-2.41.4-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:664b3199193262277b8b3cd1e754fb07f2c6023289c815a1e1e8fb415cb247b1", size = 2224049, upload-time = "2025-10-14T10:21:01.801Z" }, + { url = "https://files.pythonhosted.org/packages/af/8f/ae4e1ff84672bf869d0a77af24fd78387850e9497753c432875066b5d622/pydantic_core-2.41.4-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d95b253b88f7d308b1c0b417c4624f44553ba4762816f94e6986819b9c273fb2", size = 2342409, upload-time = "2025-10-14T10:21:03.556Z" }, + { url = "https://files.pythonhosted.org/packages/18/62/273dd70b0026a085c7b74b000394e1ef95719ea579c76ea2f0cc8893736d/pydantic_core-2.41.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a1351f5bbdbbabc689727cb91649a00cb9ee7203e0a6e54e9f5ba9e22e384b84", size = 2069635, upload-time = "2025-10-14T10:21:05.385Z" }, + { url = "https://files.pythonhosted.org/packages/30/03/cf485fff699b4cdaea469bc481719d3e49f023241b4abb656f8d422189fc/pydantic_core-2.41.4-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1affa4798520b148d7182da0615d648e752de4ab1a9566b7471bc803d88a062d", size = 2194284, 
upload-time = "2025-10-14T10:21:07.122Z" }, + { url = "https://files.pythonhosted.org/packages/f9/7e/c8e713db32405dfd97211f2fc0a15d6bf8adb7640f3d18544c1f39526619/pydantic_core-2.41.4-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:7b74e18052fea4aa8dea2fb7dbc23d15439695da6cbe6cfc1b694af1115df09d", size = 2137566, upload-time = "2025-10-14T10:21:08.981Z" }, + { url = "https://files.pythonhosted.org/packages/04/f7/db71fd4cdccc8b75990f79ccafbbd66757e19f6d5ee724a6252414483fb4/pydantic_core-2.41.4-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:285b643d75c0e30abda9dc1077395624f314a37e3c09ca402d4015ef5979f1a2", size = 2316809, upload-time = "2025-10-14T10:21:10.805Z" }, + { url = "https://files.pythonhosted.org/packages/76/63/a54973ddb945f1bca56742b48b144d85c9fc22f819ddeb9f861c249d5464/pydantic_core-2.41.4-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:f52679ff4218d713b3b33f88c89ccbf3a5c2c12ba665fb80ccc4192b4608dbab", size = 2311119, upload-time = "2025-10-14T10:21:12.583Z" }, + { url = "https://files.pythonhosted.org/packages/f8/03/5d12891e93c19218af74843a27e32b94922195ded2386f7b55382f904d2f/pydantic_core-2.41.4-cp313-cp313-win32.whl", hash = "sha256:ecde6dedd6fff127c273c76821bb754d793be1024bc33314a120f83a3c69460c", size = 1981398, upload-time = "2025-10-14T10:21:14.584Z" }, + { url = "https://files.pythonhosted.org/packages/be/d8/fd0de71f39db91135b7a26996160de71c073d8635edfce8b3c3681be0d6d/pydantic_core-2.41.4-cp313-cp313-win_amd64.whl", hash = "sha256:d081a1f3800f05409ed868ebb2d74ac39dd0c1ff6c035b5162356d76030736d4", size = 2030735, upload-time = "2025-10-14T10:21:16.432Z" }, + { url = "https://files.pythonhosted.org/packages/72/86/c99921c1cf6650023c08bfab6fe2d7057a5142628ef7ccfa9921f2dda1d5/pydantic_core-2.41.4-cp313-cp313-win_arm64.whl", hash = "sha256:f8e49c9c364a7edcbe2a310f12733aad95b022495ef2a8d653f645e5d20c1564", size = 1973209, upload-time = "2025-10-14T10:21:18.213Z" }, + { url = 
"https://files.pythonhosted.org/packages/36/0d/b5706cacb70a8414396efdda3d72ae0542e050b591119e458e2490baf035/pydantic_core-2.41.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:ed97fd56a561f5eb5706cebe94f1ad7c13b84d98312a05546f2ad036bafe87f4", size = 1877324, upload-time = "2025-10-14T10:21:20.363Z" }, + { url = "https://files.pythonhosted.org/packages/de/2d/cba1fa02cfdea72dfb3a9babb067c83b9dff0bbcb198368e000a6b756ea7/pydantic_core-2.41.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a870c307bf1ee91fc58a9a61338ff780d01bfae45922624816878dce784095d2", size = 1884515, upload-time = "2025-10-14T10:21:22.339Z" }, + { url = "https://files.pythonhosted.org/packages/07/ea/3df927c4384ed9b503c9cc2d076cf983b4f2adb0c754578dfb1245c51e46/pydantic_core-2.41.4-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d25e97bc1f5f8f7985bdc2335ef9e73843bb561eb1fa6831fdfc295c1c2061cf", size = 2042819, upload-time = "2025-10-14T10:21:26.683Z" }, + { url = "https://files.pythonhosted.org/packages/6a/ee/df8e871f07074250270a3b1b82aad4cd0026b588acd5d7d3eb2fcb1471a3/pydantic_core-2.41.4-cp313-cp313t-win_amd64.whl", hash = "sha256:d405d14bea042f166512add3091c1af40437c2e7f86988f3915fabd27b1e9cd2", size = 1995866, upload-time = "2025-10-14T10:21:28.951Z" }, + { url = "https://files.pythonhosted.org/packages/fc/de/b20f4ab954d6d399499c33ec4fafc46d9551e11dc1858fb7f5dca0748ceb/pydantic_core-2.41.4-cp313-cp313t-win_arm64.whl", hash = "sha256:19f3684868309db5263a11bace3c45d93f6f24afa2ffe75a647583df22a2ff89", size = 1970034, upload-time = "2025-10-14T10:21:30.869Z" }, ] [[package]] @@ -1584,16 +1584,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] -[[package]] -name = "pywin32" -version = 
"311" -source = { registry = "https://pypi.org/simple" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d", size = 8705700, upload-time = "2025-07-14T20:13:26.471Z" }, - { url = "https://files.pythonhosted.org/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d", size = 9494700, upload-time = "2025-07-14T20:13:28.243Z" }, - { url = "https://files.pythonhosted.org/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a", size = 8709318, upload-time = "2025-07-14T20:13:30.348Z" }, -] - [[package]] name = "pyyaml" version = "6.0.3" @@ -1827,15 +1817,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.41.0" +version = "2.42.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/68/47/aea50a61d85bc07a34e6e7145aad7bd96c5671a86a32618059bad0cbc73b/sentry_sdk-2.41.0.tar.gz", hash = "sha256:e7af3f4d7f8bac4c56fbaf95adb0d111f061cce58d5df91cfcd4e69782759b10", size = 343942, upload-time = "2025-10-09T14:12:21.132Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/b2/7481156cf42b7f66cffb371e504b7ace12b4f016b8872ffcf0873ae9534b/sentry_sdk-2.42.0.tar.gz", hash = "sha256:91c69c9372fb5fb4df0ac39456ccf7286f0428b3ee1cdd389f9dd36c04e0f5c9", size = 351242, upload-time = "2025-10-15T07:41:15.577Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/71/58/175d0e4d93f62075a01f8aebe904b412c34a94a4517e5045d0a1d512aad0/sentry_sdk-2.41.0-py2.py3-none-any.whl", hash = 
"sha256:343cde6540574113d13d178d1b2093e011ac21dd55abd3a1ec7e540f0d18a5bd", size = 370606, upload-time = "2025-10-09T14:12:19.003Z" }, + { url = "https://files.pythonhosted.org/packages/58/4a/9810a246ec5d1df2ae066efefeecfa91d3c548fa2bd5390184e016112887/sentry_sdk-2.42.0-py2.py3-none-any.whl", hash = "sha256:1a7986e638306ff158f52dd47d9480a4055e6c289388caa90628acb2563fe7bd", size = 379496, upload-time = "2025-10-15T07:41:13.802Z" }, ] [[package]] @@ -1950,7 +1940,7 @@ requires-dist = [ { name = "streamlit-antd-components" }, { name = "sympy" }, { name = "torch", specifier = ">=2.6" }, - { name = "torchvision" }, + { name = "torchvision", specifier = ">=0.23,<0.24" }, { name = "tqdm" }, { name = "transformers" }, { name = "wandb", specifier = ">=0.20.1" }, @@ -2212,7 +2202,7 @@ wheels = [ [[package]] name = "transformers" -version = "4.57.0" +version = "4.57.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -2226,9 +2216,9 @@ dependencies = [ { name = "tokenizers" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f3/5c/a22c39dac2687f3fe2a6b97e2c1ae516e91cd4d3976a7a2b7c24ff2fae48/transformers-4.57.0.tar.gz", hash = "sha256:d045753f3d93f9216e693cdb168698dfd2e9d3aad1bb72579a5d60ebf1545a8b", size = 10142956, upload-time = "2025-10-03T17:03:47.177Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d6/68/a39307bcc4116a30b2106f2e689130a48de8bd8a1e635b5e1030e46fcd9e/transformers-4.57.1.tar.gz", hash = "sha256:f06c837959196c75039809636cd964b959f6604b75b8eeec6fdfc0440b89cc55", size = 10142511, upload-time = "2025-10-14T15:39:26.18Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/2b/4d2708ac1ff5cd708b6548f4c5812d0ae40d1c28591c4c1c762b6dbdef2d/transformers-4.57.0-py3-none-any.whl", hash = "sha256:9d7c6d098c026e40d897e017ed1f481ab803cbac041021dbc6ae6100e4949b55", size = 11990588, upload-time = "2025-10-03T17:03:43.629Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/d3/c16c3b3cf7655a67db1144da94b021c200ac1303f82428f2beef6c2e72bb/transformers-4.57.1-py3-none-any.whl", hash = "sha256:b10d05da8fa67dc41644dbbf9bc45a44cb86ae33da6f9295f5fbf5b7890bd267", size = 11990925, upload-time = "2025-10-14T15:39:23.085Z" }, ] [[package]] From fddb323b61dfc9f4f0cca13310f44985036b1b14 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 16 Oct 2025 16:38:04 +0100 Subject: [PATCH 09/77] that worked...? --- spd/clustering/activations.py | 4 ++-- tests/clustering/scripts/cluster_ss.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py index d0b1a51a6..cd6a2b742 100644 --- a/spd/clustering/activations.py +++ b/spd/clustering/activations.py @@ -29,12 +29,12 @@ def component_activations( cache_type="input", ) - # TODO: !!!IMPORTANT!!! unclear whether pre_sigmoid is the right thing to use here + # TODO: !!!IMPORTANT!!! unclear what the right thing from CIOutputs is causal_importances = model.calc_causal_importances( pre_weight_acts=model_output.cache, sampling="continuous", detach_inputs=False, - ).pre_sigmoid + ).upper_leaky return causal_importances diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 43399d039..7e0e12a89 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -52,6 +52,7 @@ task_name="lm", n_batches=1, batch_size=2, + dataset_streaming=True, # no effect since we do this manually ) BATCHES, _ = split_dataset( @@ -139,4 +140,4 @@ # Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown # see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 # ============================================================ -os._exit(0) +# os._exit(0) From 8a56e12b740f4ac4dbd3d1237ecfdf6bb822113d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 16 Oct 2025 16:43:31 +0100 Subject: [PATCH 10/77] 
dont assert positive coacts? --- spd/clustering/plotting/activations.py | 7 +++++-- tests/clustering/scripts/cluster_ss.py | 6 ------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index eb7a86b01..c2c3c6bd1 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -11,6 +11,7 @@ import wandb import wandb.sdk.wandb_run from jaxtyping import Float, Int +from muutils.dbg import dbg_tensor from torch import Tensor from spd.clustering.activations import ProcessedActivations, compute_coactivatons @@ -50,7 +51,9 @@ def plot_activations( act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw act_concat: ActivationsTensor = processed_activations.activations + dbg_tensor(act_concat) coact: ClusterCoactivationShaped = compute_coactivatons(act_concat) + dbg_tensor(coact) labels: ComponentLabels = ComponentLabels(processed_activations.labels) n_samples: int = act_concat.shape[0] @@ -209,8 +212,8 @@ def plot_activations( fig4_log: plt.Figure ax4_log: plt.Axes fig4_log, ax4_log = plt.subplots(figsize=figsize_coact) - assert np.all(coact_data >= 0) - coact_log_data: np.ndarray = np.log10(coact_data + 1e-6) + # assert np.all(coact_data >= 0) # TODO: why does this fail? 
+ coact_log_data: np.ndarray = np.log10(coact_data + 1e-6 + coact_data.min()) im4_log = ax4_log.matshow( coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max() ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 7e0e12a89..12dfdae4d 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -135,9 +135,3 @@ distances=DISTANCES, mode="points", ) - -# %% -# Exit cleanly to avoid CUDA thread GIL issues during interpreter shutdown -# see https://github.com/goodfire-ai/spd/issues/201#issue-3503138939 -# ============================================================ -# os._exit(0) From d18470b9ac27e34bf3102df1d6013fa66475ef93 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 16 Oct 2025 16:46:28 +0100 Subject: [PATCH 11/77] get rid of long-running merge pair sampler on GPU test well... somewhere in the range of 0.5 to 1s, but its the longest running test out of all the standard tests --- tests/clustering/test_merge_pair_samplers.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/clustering/test_merge_pair_samplers.py b/tests/clustering/test_merge_pair_samplers.py index e400b0dd3..66c59cb66 100644 --- a/tests/clustering/test_merge_pair_samplers.py +++ b/tests/clustering/test_merge_pair_samplers.py @@ -221,23 +221,6 @@ def test_registry_samplers_callable(self): class TestSamplerIntegration: """Integration tests for samplers with edge cases.""" - def test_samplers_with_gpu_tensors(self): - """Test samplers work with GPU tensors if available.""" - if not torch.cuda.is_available(): - pytest.skip("CUDA not available") - - k = 4 - costs = torch.randn(k, k, device="cuda") - costs = (costs + costs.T) / 2 - costs.fill_diagonal_(float("inf")) - - # Both samplers should work with GPU tensors - pair_range = range_sampler(costs, threshold=0.5) - pair_mcmc = mcmc_sampler(costs, temperature=1.0) - - assert isinstance(pair_range, tuple) 
- assert isinstance(pair_mcmc, tuple) - def test_samplers_deterministic_with_seed(self): """Test that samplers are deterministic with fixed seed.""" k = 5 From 83d4288e210944fd0716f69e6626837b8608a2d5 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 11:01:02 +0100 Subject: [PATCH 12/77] wip CI decision trees w/ random data --- spd/clustering/ci_dt.py | 320 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 320 insertions(+) create mode 100644 spd/clustering/ci_dt.py diff --git a/spd/clustering/ci_dt.py b/spd/clustering/ci_dt.py new file mode 100644 index 000000000..9dd735ad0 --- /dev/null +++ b/spd/clustering/ci_dt.py @@ -0,0 +1,320 @@ +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Literal + +import matplotlib.pyplot as plt +import numpy as np +from jaxtyping import Bool, Float +from sklearn.base import ClassifierMixin +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + balanced_accuracy_score, +) +from sklearn.multioutput import MultiOutputClassifier +from sklearn.tree import DecisionTreeClassifier, plot_tree + +# ----------------------- library code ----------------------- + + +@dataclass +class LayerModel: + """Holds a trained per-layer model.""" + + layer_index: int + model: ClassifierMixin + feature_dim: int + target_dim: int + + +def concat_cols( + Xs: Sequence[Bool[np.ndarray, "n_samples n_features"]], +) -> Bool[np.ndarray, "n_samples n_concat"]: + """Column-concat a sequence or return empty (n,0).""" + n_samples: int = Xs[0].shape[0] if len(Xs) else 0 + return np.concatenate(Xs, axis=1) if len(Xs) else np.zeros((n_samples, 0), bool) + + +def build_xy( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], +) -> list[ + tuple[ + Bool[np.ndarray, "n_samples n_features"], + Bool[np.ndarray, "n_samples n_targets"], + ] +]: + """Return (X_k,Y_k) for k=1..L-1 with X_k=concat(layers[:k]).""" + XYs: list[tuple[np.ndarray, np.ndarray]] = [] + for k in 
range(1, len(layers)): + X_k: np.ndarray = concat_cols(layers[:k]) + Y_k: np.ndarray = layers[k] + XYs.append((X_k, Y_k)) + return XYs + + +def train_trees( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + *, + strategy: Literal["one_vs_all", "single_tree"] = "one_vs_all", + max_depth: int | None = None, + min_samples_leaf: int = 1, + random_state: int | None = 0, +) -> list[LayerModel]: + """Train one model per target layer using previous layers as features.""" + XYs = build_xy(layers) + models: list[LayerModel] = [] + for k, (X_k, Y_k) in enumerate(XYs, start=1): + base = DecisionTreeClassifier( + max_depth=max_depth, + min_samples_leaf=min_samples_leaf, + random_state=random_state, + ) + model: ClassifierMixin = MultiOutputClassifier(base) if strategy == "one_vs_all" else base + _ = model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) + models.append(LayerModel(k, model, int(X_k.shape[1]), int(Y_k.shape[1]))) + return models + + +def predict_k( + models: Sequence[LayerModel], + prefix_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + k: int, + *, + threshold: float = 0.5, +) -> Bool[np.ndarray, "n_samples n_components_k"]: + """Predict layer k activations from layers[:k].""" + lm: LayerModel = next(m for m in models if m.layer_index == k) + X: np.ndarray = concat_cols(prefix_layers) + proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + if isinstance(proba, list): + P: np.ndarray = np.stack([p[:, 1] for p in proba], axis=1) + else: + P = proba[..., 1] # type: ignore + Y_hat: np.ndarray = (float(threshold) <= P).astype(bool) + return Y_hat + + +def predict_all( + models: Sequence[LayerModel], + seed_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + *, + thresholds: Sequence[float] | None = None, +) -> list[Bool[np.ndarray, "n_samples n_components"]]: + """Sequentially predict layers 1.. 
using layer 0 as seed.""" + out: list[np.ndarray] = [seed_layers[0].copy()] + ths: list[float] = list(thresholds) if thresholds is not None else [] + for i, lm in enumerate(sorted(models, key=lambda m: m.layer_index)): + thr: float = ths[i] if i < len(ths) else 0.5 + out.append(predict_k(models, out, lm.layer_index, threshold=thr)) + return out + + +# ----------------------- random data ----------------------- + +rng: np.random.Generator = np.random.default_rng(2) +n: int = 250 +sizes: list[int] = [15, 9, 22, 6] + +# base probs per component +base_probs: list[np.ndarray] = [rng.uniform(0.05, 0.5, size=s) for s in sizes] + +layers_true: list[np.ndarray] = [ + (rng.uniform(size=(n, s)) < p).astype(bool) for s, p in zip(sizes, base_probs, strict=True) +] + +# ----------------------- fit and predict ----------------------- + +models: list[LayerModel] = train_trees(layers_true, max_depth=8, random_state=7) +layers_pred: list[np.ndarray] = predict_all(models, [layers_true[0]]) + +# ----------------------- metrics ----------------------- + + +def layer_metrics( + Y_true: Bool[np.ndarray, "n t"], + Y_prob: Float[np.ndarray, "n t"], + Y_pred: Bool[np.ndarray, "n t"], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return per-target AP, acc, bacc, prevalence.""" + T: int = Y_true.shape[1] + ap: np.ndarray = np.zeros(T) + acc: np.ndarray = np.zeros(T) + bacc: np.ndarray = np.zeros(T) + prev: np.ndarray = np.zeros(T) + for j in range(T): + y: np.ndarray = Y_true[:, j].astype(int) + p: np.ndarray = Y_prob[:, j] + yhat: np.ndarray = Y_pred[:, j].astype(int) + prev[j] = float(y.mean()) + try: + ap[j] = average_precision_score(y, p) + except Exception: + ap[j] = np.nan + try: + acc[j] = accuracy_score(y, yhat) + except Exception: + acc[j] = np.nan + try: + bacc[j] = balanced_accuracy_score(y, yhat) + except Exception: + bacc[j] = np.nan + return ap, acc, bacc, prev + + +# get probabilities for each layer +def proba_for_layer(lm: LayerModel, X: np.ndarray) -> 
np.ndarray: + """Return P(y=1) per target column.""" + pr = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + if isinstance(pr, list): + return np.stack([p[:, 1] for p in pr], axis=1) + return pr[..., 1] # type: ignore + + +XYs_demo = build_xy(layers_true) +per_layer_stats: list[dict[str, Any]] = [] +all_triplets: list[tuple[int, int, float]] = [] # (layer, target_idx, AP) + +for lm, (Xk, Yk) in zip(models, XYs_demo, strict=True): + Pk: np.ndarray = proba_for_layer(lm, Xk) + Yhat_k: np.ndarray = Pk >= 0.5 + ap, acc, bacc, prev = layer_metrics(Yk, Pk, Yhat_k) + per_layer_stats.append( + { + "ap": ap, + "acc": acc, + "bacc": bacc, + "prev": prev, + "mean_ap": float(np.nanmean(ap)), + "mean_acc": float(np.nanmean(acc)), + "mean_bacc": float(np.nanmean(bacc)), + } + ) + for j, apj in enumerate(ap): + all_triplets.append((lm.layer_index, j, float(apj))) + +# identify best and worst trees across all outputs by AP +sorted_triplets = sorted(all_triplets, key=lambda t: (np.isnan(t[2]), t[2])) +worst_list = [t for t in sorted_triplets if not np.isnan(t[2])][:2] +best_list = [t for t in sorted_triplets if not np.isnan(t[2])][-2:] + + +# pull corresponding estimators (MultiOutputClassifier -> estimators_ list) +def get_estimator_for( + models: list[LayerModel], layer_idx: int, target_idx: int +) -> DecisionTreeClassifier: + """Fetch the per-output estimator for a given layer and column.""" + lm = next(m for m in models if m.layer_index == layer_idx) + if isinstance(lm.model, MultiOutputClassifier): + return lm.model.estimators_[target_idx] # type: ignore + return lm.model # type: ignore + + +# ----------------------- plotting ----------------------- + + +# 1) Single fig showing activations across all layers (true vs predicted stacked) +def plot_activations(layers_true: list[np.ndarray], layers_pred: list[np.ndarray]) -> None: + """Show true and predicted activations as heatmaps.""" + A_true: np.ndarray = np.concatenate(layers_true, axis=1) + A_pred: np.ndarray = 
np.concatenate([layers_pred[0]] + layers_pred[1:], axis=1) + fig1 = plt.figure(figsize=(10, 6)) + ax1 = fig1.add_subplot(2, 1, 1) + ax1.set_title("Activations (True)") + ax1.imshow(A_true, aspect="auto", interpolation="nearest") + ax1.set_xlabel("components (all layers concatenated)") + ax1.set_ylabel("samples") + ax2 = fig1.add_subplot(2, 1, 2) + ax2.set_title("Activations (Predicted)") + ax2.imshow(A_pred, aspect="auto", interpolation="nearest") + ax2.set_xlabel("components (all layers concatenated)") + ax2.set_ylabel("samples") + fig1.tight_layout() + + +# 2) Covariance matrix of all components +def plot_covariance(layers_true: list[np.ndarray]) -> None: + """Plot covariance between all components across layers.""" + A: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) + C: np.ndarray = np.cov(A, rowvar=False) + fig2 = plt.figure(figsize=(6, 6)) + ax = fig2.add_subplot(1, 1, 1) + ax.set_title("Covariance of components (all layers)") + _im = ax.imshow(C, aspect="auto", interpolation="nearest") + ax.set_xlabel("component index") + ax.set_ylabel("component index") + fig2.tight_layout() + + +# 3) Accuracy ideas: bar of mean metrics per layer; scatter of prevalence vs AP +def plot_layer_metrics(per_layer_stats: list[dict[str, Any]]) -> None: + """Plot summary metrics per layer and per-target AP vs prevalence.""" + L: int = len(per_layer_stats) + mean_ap: np.ndarray = np.array([d["mean_ap"] for d in per_layer_stats]) + mean_acc: np.ndarray = np.array([d["mean_acc"] for d in per_layer_stats]) + mean_bacc: np.ndarray = np.array([d["mean_bacc"] for d in per_layer_stats]) + + # bar: mean AP, ACC, BACC per layer (three separate figures to respect one-plot rule) + fig3 = plt.figure(figsize=(8, 3)) + ax3 = fig3.add_subplot(1, 1, 1) + ax3.set_title("Mean Average Precision per layer") + ax3.bar(np.arange(1, L + 1), mean_ap) + ax3.set_xlabel("layer index (target)") + ax3.set_ylabel("mean AP") + fig3.tight_layout() + + fig4 = plt.figure(figsize=(8, 3)) + ax4 = 
fig4.add_subplot(1, 1, 1) + ax4.set_title("Mean Accuracy per layer") + ax4.bar(np.arange(1, L + 1), mean_acc) + ax4.set_xlabel("layer index (target)") + ax4.set_ylabel("mean accuracy") + fig4.tight_layout() + + fig5 = plt.figure(figsize=(8, 3)) + ax5 = fig5.add_subplot(1, 1, 1) + ax5.set_title("Mean Balanced Accuracy per layer") + ax5.bar(np.arange(1, L + 1), mean_bacc) + ax5.set_xlabel("layer index (target)") + ax5.set_ylabel("mean balanced accuracy") + fig5.tight_layout() + + # scatter: prevalence vs AP for all targets across layers + fig6 = plt.figure(figsize=(6, 5)) + ax6 = fig6.add_subplot(1, 1, 1) + ax6.set_title("Per-target AP vs prevalence") + x_list: list[float] = [] + y_list: list[float] = [] + for d in per_layer_stats: + x_list.extend(list(d["prev"])) + y_list.extend(list(d["ap"])) + ax6.scatter(x_list, y_list, alpha=0.6) + ax6.set_xlabel("prevalence") + ax6.set_ylabel("average precision") + fig6.tight_layout() + + +# 4) Display a couple decision trees (worst and best by AP) +def plot_selected_trees( + picks: list[tuple[int, int, float]], + title_prefix: str, + models: list[LayerModel], + feature_dims_prefix: list[int], +) -> None: + """Plot a list of selected trees by (layer, target_idx, score).""" + for layer_idx, target_idx, score in picks: + est = get_estimator_for(models, layer_idx, target_idx) + fig = plt.figure(figsize=(10, 6)) + ax = fig.add_subplot(1, 1, 1) + ax.set_title(f"{title_prefix}: layer {layer_idx}, target {target_idx}, AP={score:.3f}") + plot_tree(est, ax=ax, filled=False) # default styling + fig.tight_layout() + + +# Run the plots +plot_activations(layers_true, layers_pred) +plot_covariance(layers_true) +plot_layer_metrics(per_layer_stats) +plot_selected_trees(worst_list, "Worst", models, []) +plot_selected_trees(best_list, "Best", models, []) + +print("Plots generated.") From e2e8c5c8f076a8794eb67e7afea632d6d4de56fe Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 11:38:13 +0100 Subject: [PATCH 13/77] wip --- 
pyproject.toml | 1 + spd/clustering/ci_dt.py | 127 ++++++++++++++++++++++++++++++++++++---- uv.lock | 58 +++++++++++++++--- 3 files changed, 169 insertions(+), 17 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a7f7cb6a7..f55ba3150 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev", "scipy>=1.14.1", "muutils", + "scikit-learn", ] [dependency-groups] diff --git a/spd/clustering/ci_dt.py b/spd/clustering/ci_dt.py index 9dd735ad0..272e2bf76 100644 --- a/spd/clustering/ci_dt.py +++ b/spd/clustering/ci_dt.py @@ -1,9 +1,12 @@ +# %% + from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Literal import matplotlib.pyplot as plt import numpy as np +import torch from jaxtyping import Bool, Float from sklearn.base import ClassifierMixin from sklearn.metrics import ( @@ -13,6 +16,33 @@ ) from sklearn.multioutput import MultiOutputClassifier from sklearn.tree import DecisionTreeClassifier, plot_tree +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.registry import EXPERIMENT_REGISTRY + +# ----------------------- config ----------------------- + + +@dataclass +class CIDTConfig: + """Configuration for causal importance decision tree training.""" + + experiment_key: str = "ss_emb" # Key from EXPERIMENT_REGISTRY + n_samples: int = 250 + activation_threshold: float = 0.01 # Threshold for boolean conversion + filter_dead_threshold: float = 0.001 # Threshold for filtering dead components + max_depth: int = 8 # Maximum depth for decision trees + random_state: int = 7 # Random state for reproducibility + 
# ----------------------- library code ----------------------- @@ -109,22 +139,99 @@ def predict_all( return out -# ----------------------- random data ----------------------- +# ----------------------- configuration ----------------------- + +config = CIDTConfig() +device: str = "cuda" if torch.cuda.is_available() else "cpu" + +# ----------------------- load model ----------------------- + +# Load SPD run info and model +exp_config = EXPERIMENT_REGISTRY[config.experiment_key] +assert exp_config.canonical_run is not None, f"No canonical run found for {config.experiment_key}" +assert exp_config.task_name == "lm", f"Only 'lm' task supported, got {exp_config.task_name}" -rng: np.random.Generator = np.random.default_rng(2) -n: int = 250 -sizes: list[int] = [15, 9, 22, 6] +spd_run: SPDRunInfo = SPDRunInfo.from_path(exp_config.canonical_run) +model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) +model.to(device) +cfg: Config = spd_run.config -# base probs per component -base_probs: list[np.ndarray] = [rng.uniform(0.05, 0.5, size=s) for s in sizes] +print(f"Loaded model from {exp_config.canonical_run}") +print(f"Task: {exp_config.task_name}") -layers_true: list[np.ndarray] = [ - (rng.uniform(size=(n, s)) < p).astype(bool) for s, p in zip(sizes, base_probs, strict=True) -] +# ----------------------- load dataset ----------------------- + +# Create LM dataset and dataloader +assert isinstance(cfg.task_config, LMTaskConfig) +pretrained_model_name = cfg.pretrained_model_name +assert pretrained_model_name is not None + +dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + column_name=cfg.task_config.column_name, + is_tokenized=False, + streaming=False, + seed=0, +) +dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=config.n_samples, + buffer_size=cfg.task_config.buffer_size, + 
global_seed=cfg.seed, + ddp_rank=0, + ddp_world_size=1, +) +batch_data = next(iter(dataloader)) +batch: Tensor = batch_data["input_ids"] +print(f"Created LM dataset with {cfg.task_config.dataset_name}, batch shape: {batch.shape}") + +# ----------------------- get activations ----------------------- + +# Get component activations (on device) +print("Computing component activations...") +component_acts: dict[str, Tensor] = component_activations( + model=model, + device=device, + batch=batch, +) + +# Process activations (filter dead components, concatenate) +print("Processing activations...") +processed_acts: ProcessedActivations = process_activations( + component_acts, + filter_dead_threshold=config.filter_dead_threshold, + seq_mode="seq_mean", # LM task needs seq_mean +) + +print(f"Total components (before filtering): {processed_acts.n_components_original}") +print(f"Alive components: {processed_acts.n_components_alive}") +print(f"Dead components: {processed_acts.n_components_dead}") +print(f"Module keys: {processed_acts.module_keys}") + +# ----------------------- convert to layers ----------------------- + +# Move to CPU and convert to numpy for sklearn +# Group by module to create "layers" for decision trees +print("\nConverting to boolean layers...") +layers_true: list[np.ndarray] = [] +for module_key in processed_acts.module_keys: + # Get the activations for this module from activations_raw, move to CPU + module_acts_cpu = processed_acts.activations_raw[module_key].cpu().numpy() + module_acts_bool = (module_acts_cpu >= config.activation_threshold).astype(bool) + layers_true.append(module_acts_bool) + print(f"Layer {len(layers_true) - 1} ({module_key}): {module_acts_bool.shape[1]} components") + +print(f"\nCreated {len(layers_true)} layers for decision tree training") # ----------------------- fit and predict ----------------------- -models: list[LayerModel] = train_trees(layers_true, max_depth=8, random_state=7) +print("\nTraining decision trees...") +models: 
list[LayerModel] = train_trees( + layers_true, max_depth=config.max_depth, random_state=config.random_state +) layers_pred: list[np.ndarray] = predict_all(models, [layers_true[0]]) # ----------------------- metrics ----------------------- diff --git a/uv.lock b/uv.lock index 9ae8b57da..f0872b510 100644 --- a/uv.lock +++ b/uv.lock @@ -729,6 +729,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -1091,7 +1100,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1102,7 
+1111,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1129,9 +1138,9 @@ name = "nvidia-cusolver-cu12" version = "11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12" }, - { name = "nvidia-cusparse-cu12" }, - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1142,7 +1151,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12" }, + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = 
"2025-03-07T01:48:13.779Z" }, @@ -1784,6 +1793,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382, upload-time = "2025-09-09T08:20:54.731Z" }, + { url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042, upload-time = "2025-09-09T08:20:57.313Z" }, + { url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180, upload-time = "2025-09-09T08:20:59.671Z" }, + { url = 
"https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660, upload-time = "2025-09-09T08:21:01.71Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057, upload-time = "2025-09-09T08:21:04.234Z" }, + { url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731, upload-time = "2025-09-09T08:21:06.381Z" }, + { url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852, upload-time = "2025-09-09T08:21:08.628Z" }, + { url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094, upload-time = "2025-09-09T08:21:11.486Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436, upload-time = "2025-09-09T08:21:13.602Z" }, + { url = 
"https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749, upload-time = "2025-09-09T08:21:15.96Z" }, +] + [[package]] name = "scipy" version = "1.16.2" @@ -1899,6 +1932,7 @@ dependencies = [ { name = "numpy" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "scikit-learn" }, { name = "scipy" }, { name = "simple-stories-train" }, { name = "streamlit" }, @@ -1934,6 +1968,7 @@ requires-dist = [ { name = "numpy" }, { name = "pydantic", specifier = ">=2" }, { name = "python-dotenv" }, + { name = "scikit-learn" }, { name = "scipy", specifier = ">=1.14.1" }, { name = "simple-stories-train", git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev" }, { name = "streamlit" }, @@ -2041,6 +2076,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/bd/de8d508070629b6d84a30d01d57e4a65c69aa7f5abe7560b8fad3b50ea59/termcolor-3.1.0-py3-none-any.whl", hash = "sha256:591dd26b5c2ce03b9e43f391264626557873ce1d379019786f99b0c2bee140aa", size = 7684, upload-time = "2025-04-30T11:37:52.382Z" }, ] +[[package]] +name = "threadpoolctl" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, +] + [[package]] name = "tiktoken" version = "0.12.0" @@ 
-2226,7 +2270,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools" }, + { name = "setuptools", marker = "sys_platform == 'linux'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, From 17d3a13aec3572ed19a06995b6b356c7fb482b2d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 11:49:10 +0100 Subject: [PATCH 14/77] wip --- spd/clustering/ci_dt.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/spd/clustering/ci_dt.py b/spd/clustering/ci_dt.py index 272e2bf76..3ae33aa07 100644 --- a/spd/clustering/ci_dt.py +++ b/spd/clustering/ci_dt.py @@ -27,7 +27,6 @@ from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig from spd.models.component_model import ComponentModel, SPDRunInfo -from spd.registry import EXPERIMENT_REGISTRY # ----------------------- config ----------------------- @@ -146,18 +145,14 @@ def predict_all( # ----------------------- load model ----------------------- -# Load SPD run info and model -exp_config = EXPERIMENT_REGISTRY[config.experiment_key] -assert exp_config.canonical_run is not None, f"No canonical run found for {config.experiment_key}" -assert exp_config.task_name == "lm", f"Only 'lm' task supported, got {exp_config.task_name}" +wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" -spd_run: SPDRunInfo = SPDRunInfo.from_path(exp_config.canonical_run) +spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) model.to(device) cfg: Config = spd_run.config -print(f"Loaded model from 
{exp_config.canonical_run}") -print(f"Task: {exp_config.task_name}") +print(f"Loaded model from {wandb_run_path}") # ----------------------- load dataset ----------------------- From 8dfea5ee430f211420ca6f84b9487653d69c7cfe Mon Sep 17 00:00:00 2001 From: Dan Braun Date: Mon, 20 Oct 2025 11:53:35 +0100 Subject: [PATCH 15/77] [clustering] Refactor to two-stage process (#203) * Have clustering be submitted via slurm jobs * Functional run_clustering.py script * Functional submitter (though workspace broken) * Refactors * Move pipeline_config.yaml to configs/ * Add non-used config entries to spd/clustering/configs/example.yaml * --n_runs --> --n-runs * Update run paths * Add calc_distances script * Fix run_pipeline wandb workspace * Have MergeConfig inherit from BaseConfig * Remove unused tensor_stats.py * Improve script docs and output file naming * Use existing random id generator which uses secrets * run_clustering -> main * Get tests passing (and remove some old tests) * Remove REFACTOR.md * Avoid loading run info twice * Add comment to test_run_clustering_happy_path * remove matplotlib related type ignore comments causing warns likely related to [f9a2a97](https://github.com/goodfire-ai/spd/commit/f9a2a97c95fe189388e53682f3818e8466950b3a) / https://github.com/goodfire-ai/spd/pull/210 * re-add --dist-worksteal in CI see https://github.com/goodfire-ai/spd/pull/203/files#r2425944229 * pdf_prefix -> figure_prefix (minimizing diff) * Add calc_distances to pipeline * Explicitly state distance plot file in pipeline * move ClusteringRunConfig to separate file run_clustering.py is very long, plus this helps us minimize the diff * [temp] rename {merge,clustering}_run_config for easier diff * re-add dataset_streaming option for fast CI tests * remove spd/clustering/utils it was empty * add `make clean` recipe * [wip] add back storage.py * wip * wip * make format * `Command` object continuing refactor, `make test` passing * try to fix some tests * working on fixing tests, some 
config issues, wip * ExecutionStamp reworking, it will manage output dirs * CI install fix -- python version used was 3.12, now we are on 3.13 * sigmoid typing fixes !!! unsure abt wht to use from CIOutputs * catch exception when checking git repo is clean * remove merge pair sampler GPU test * allow errors when calling repo_is_clean * wip debugging * fix?? * try/except workspace new view creation * fix create_clustering_workspace_view * just dont pass a wandb project to tests * config woes * wip * noneable slurm config * fix cli * wip * move `Command` to own class * storage refactor * wip * ensemble id keys must be unique if registered in db, pass None * snapshot creation, cli stuff * logging cleanup * Remove Command and use shlex * Remove Storage properties --------- Co-authored-by: Michael Ivanitskiy --- .vscode/launch.json | 32 ++ Makefile | 8 + pyproject.toml | 2 +- spd/base_config.py | 23 +- spd/clustering/configs/example.toml | 36 -- spd/clustering/configs/example.yaml | 31 +- .../configs/pipeline-test-resid_mlp1.yaml | 9 + .../configs/pipeline-test-simplestories.yaml | 9 + spd/clustering/configs/pipeline_config.yaml | 9 + spd/clustering/configs/simplestories_dev.json | 2 +- spd/clustering/configs/test-resid_mlp1.json | 6 +- .../configs/test-simplestories.json | 9 +- spd/clustering/dataset.py | 137 ++++++ spd/clustering/ensemble_registry.py | 72 +++ spd/clustering/math/tensor_stats.py | 160 ------- spd/clustering/merge.py | 11 +- spd/clustering/merge_history.py | 3 +- spd/clustering/merge_run_config.py | 289 ++++-------- spd/clustering/pipeline/__init__.py | 0 .../pipeline/clustering_pipeline.py | 118 ----- spd/clustering/pipeline/dist_utils.py | 313 ------------- spd/clustering/pipeline/s1_split_dataset.py | 164 ------- spd/clustering/pipeline/s2_clustering.py | 409 ---------------- .../pipeline/s3_normalize_histories.py | 32 -- .../pipeline/s4_compute_distances.py | 92 ---- spd/clustering/pipeline/storage.py | 300 ------------ 
spd/clustering/plotting/activations.py | 43 +- spd/clustering/scripts/calc_distances.py | 120 +++++ spd/clustering/scripts/main.py | 92 ---- spd/clustering/scripts/run_clustering.py | 435 ++++++++++++++++++ spd/clustering/scripts/run_pipeline.py | 375 +++++++++++++++ spd/clustering/storage.py | 19 + spd/experiments/ih/ih_decomposition.py | 2 +- spd/identity_insertion.py | 6 +- spd/scripts/run.py | 136 +++--- spd/spd_types.py | 7 +- spd/utils/command_utils.py | 37 ++ spd/utils/git_utils.py | 37 +- spd/utils/run_utils.py | 96 +++- spd/utils/slurm_utils.py | 170 +++++-- tests/clustering/scripts/cluster_resid_mlp.py | 4 +- tests/clustering/scripts/cluster_ss.py | 24 +- tests/clustering/test_calc_distances.py | 32 ++ .../clustering/test_clustering_experiments.py | 6 +- tests/clustering/test_merge_integration.py | 22 +- .../test_run_clustering_happy_path.py | 40 ++ tests/clustering/test_storage.py | 351 -------------- tests/clustering/test_wandb_integration.py | 153 ------ tests/scripts_run/test_main.py | 57 ++- uv.lock | 50 +- 50 files changed, 1858 insertions(+), 2732 deletions(-) delete mode 100644 spd/clustering/configs/example.toml create mode 100644 spd/clustering/configs/pipeline-test-resid_mlp1.yaml create mode 100644 spd/clustering/configs/pipeline-test-simplestories.yaml create mode 100644 spd/clustering/configs/pipeline_config.yaml create mode 100644 spd/clustering/dataset.py create mode 100644 spd/clustering/ensemble_registry.py delete mode 100644 spd/clustering/math/tensor_stats.py delete mode 100644 spd/clustering/pipeline/__init__.py delete mode 100644 spd/clustering/pipeline/clustering_pipeline.py delete mode 100644 spd/clustering/pipeline/dist_utils.py delete mode 100644 spd/clustering/pipeline/s1_split_dataset.py delete mode 100644 spd/clustering/pipeline/s2_clustering.py delete mode 100644 spd/clustering/pipeline/s3_normalize_histories.py delete mode 100644 spd/clustering/pipeline/s4_compute_distances.py delete mode 100644 
spd/clustering/pipeline/storage.py create mode 100644 spd/clustering/scripts/calc_distances.py delete mode 100644 spd/clustering/scripts/main.py create mode 100644 spd/clustering/scripts/run_clustering.py create mode 100644 spd/clustering/scripts/run_pipeline.py create mode 100644 spd/clustering/storage.py create mode 100644 spd/utils/command_utils.py create mode 100644 tests/clustering/test_calc_distances.py create mode 100644 tests/clustering/test_run_clustering_happy_path.py delete mode 100644 tests/clustering/test_storage.py delete mode 100644 tests/clustering/test_wandb_integration.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 616cf64ce..226feb719 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -230,6 +230,38 @@ "--model_path", "wandb:goodfire/spd/runs/ioprgffh" ] + }, + { + "name": "run_clustering example", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/clustering/scripts/run_clustering.py", + "args": [ + "--config", + "${workspaceFolder}/spd/clustering/configs/example.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + }, + { + "name": "clustering pipeline", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/clustering/scripts/run_pipeline.py", + "args": [ + "--config", + "${workspaceFolder}/spd/clustering/configs/pipeline_config.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } } ] } \ No newline at end of file diff --git a/Makefile b/Makefile index 45bc29feb..85a0a4d8a 100644 --- a/Makefile +++ b/Makefile @@ -75,3 +75,11 @@ coverage: mkdir -p $(COVERAGE_DIR) uv run python -m coverage report -m > $(COVERAGE_DIR)/coverage.txt uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/ + + +.PHONY: clean +clean: + 
@echo "Cleaning Python cache and build artifacts..." + find . -type d -name "__pycache__" -exec rm -rf {} + + find . -type d -name "*.egg-info" -exec rm -rf {} + + rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a7f7cb6a7..ac30280e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ dev = [ [project.scripts] spd-run = "spd.scripts.run:cli" -spd-cluster = "spd.clustering.scripts.main:cli" +spd-cluster = "spd.clustering.scripts.run_pipeline:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/base_config.py b/spd/base_config.py index c9b488e19..860898907 100644 --- a/spd/base_config.py +++ b/spd/base_config.py @@ -6,6 +6,14 @@ from pydantic import BaseModel, ConfigDict +class FileTypeError(ValueError): + """Error raised when a file has an unsupported type/extension.""" + + +class ConfigValidationError(ValueError): + """Error raised when a config file fails pydantic validation.""" + + class BaseConfig(BaseModel): """Pydantic BaseModel suited for configs. 
@@ -15,6 +23,8 @@ class BaseConfig(BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + # TODO: add a "config_type" field, which is set to the class name, so that when loading a config we can check whether the config type matches the expected class + @classmethod def from_file(cls, path: Path | str) -> Self: """Load config from path to a JSON or YAML file.""" @@ -27,9 +37,16 @@ def from_file(cls, path: Path | str) -> Self: case Path() if path.suffix in [".yaml", ".yml"]: data = yaml.safe_load(path.read_text()) case _: - raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + + try: + cfg = cls.model_validate(data) + except Exception as e: + raise ConfigValidationError( + f"Error validating config {cls=} from path `{path.as_posix()}`\n{data = }" + ) from e - return cls.model_validate(data) + return cfg def to_file(self, path: Path | str) -> None: """Save config to file (format inferred from extension).""" @@ -43,4 +60,4 @@ def to_file(self, path: Path | str) -> None: case ".yaml" | ".yml": path.write_text(yaml.dump(self.model_dump(mode="json"))) case _: - raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}") diff --git a/spd/clustering/configs/example.toml b/spd/clustering/configs/example.toml deleted file mode 100644 index 98053576b..000000000 --- a/spd/clustering/configs/example.toml +++ /dev/null @@ -1,36 +0,0 @@ -# Example MergeRunConfig in TOML format - -# Run configuration -model_path = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" # WandB path to the decomposed model -task_name = "lm" # Task name (must be explicit: tms, resid_mlp, lm, ih) -# experiment_key = "tms_5-2" # Alternative: use experiment key from EXPERIMENT_REGISTRY -n_batches = 10 # Ensemble size -batch_size = 64 # Batch size for processing -- 
number of samples for each run in the ensemble - -# WandB configuration -wandb_enabled = false # Enable WandB logging -wandb_project = "spd-cluster" # WandB project name - -[intervals] -stat = 1 # for k_groups, merge_pair_cost, mdl_loss -tensor = 100 # for wandb_log_tensor and fraction_* calculations -plot = 100 # for calling the plotting callback -artifact = 100 # for calling the artifact callback - -# Optional: Override defaults (typically set via CLI args) -# base_path = ".data/clustering/" # defaults to .data/clustering/ -# workers_per_device = 1 # defaults to 1 -# devices = ["cpu"] # defaults to ["cpu"], CLI will override with ["cuda"] if available - -# Merge algorithm parameters (wrapped in merge_config) -[merge_config] -activation_threshold = 0.01 # set to null to use scalar activations for cost calculation -alpha = 1.0 # rank penalty term -iters = 100 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? -pop_component_prob = 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway -filter_dead_threshold = 0.001 # Threshold for filtering dead components -module_name_filter = "__NULL__" # Can be a string prefix like "model.layers.0." 
if you want to do only some modules -merge_pair_sampling_method = "range" # Method for sampling merge pairs: 'range' or 'mcmc' - -[merge_config.merge_pair_sampling_kwargs] -threshold = 0.05 # For range sampler: fraction of the range of costs to sample from diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/example.yaml index 259f1597c..efa36d693 100644 --- a/spd/clustering/configs/example.yaml +++ b/spd/clustering/configs/example.yaml @@ -1,10 +1,14 @@ -# Example MergeRunConfig in YAML format +model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model +batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble +dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# idx_in_ensemble: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) -# Merge algorithm parameters (wrapped in merge_config) merge_config: activation_threshold: 0.01 # set to null to use scalar activations for cost calculation alpha: 1.0 # rank penalty term - iters: 100 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? + iters: 10 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from @@ -12,23 +16,10 @@ merge_config: filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules -# Run configuration -model_path: wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh # WandB path to the decomposed model -task_name: lm # Task name (must be explicit: tms, resid_mlp, lm, ih) -# experiment_key: tms_5-2 # Alternative: use experiment key from EXPERIMENT_REGISTRY -n_batches: 10 # Ensemble size -batch_size: 64 # Batch size for processing -- number of samples for each run in the ensemble - -# WandB configuration -wandb_enabled: false # Enable WandB logging -wandb_project: spd-cluster # WandB project name -intervals: +wandb_project: spd-cluster +wandb_entity: goodfire +logging_intervals: stat: 1 # for k_groups, merge_pair_cost, mdl_loss tensor: 100 # for wandb_log_tensor and fraction_* calculations plot: 100 # for calling the plotting callback - artifact: 100 # for calling the artifact callback - -# Optional: Override defaults (typically set via CLI args) -# base_path: .data/clustering/ # defaults to .data/clustering/ -# workers_per_device: 1 # defaults to 1 -# devices: ["cpu"] # defaults to ["cpu"], CLI will override with ["cuda"] if available \ No newline at end of file + artifact: 100 # for calling the artifact callback \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml new file mode 100644 index 000000000..e6680b8d0 --- /dev/null +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -0,0 +1,9 @@ +run_clustering_config_path: "spd/clustering/configs/test-resid_mlp1.json" +n_runs: 2 +distances_method: "perm_invariant_hamming" +base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: null # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml new file mode 100644 index 000000000..a2fc9ec9c --- 
/dev/null +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -0,0 +1,9 @@ +run_clustering_config_path: "spd/clustering/configs/test-simplestories.json" +n_runs: 2 +distances_method: "perm_invariant_hamming" +base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: null # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false \ No newline at end of file diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml new file mode 100644 index 000000000..3d2085c6b --- /dev/null +++ b/spd/clustering/configs/pipeline_config.yaml @@ -0,0 +1,9 @@ +run_clustering_config_path: "spd/clustering/configs/example.yaml" +n_runs: 2 +distances_method: "perm_invariant_hamming" +base_output_dir: "/mnt/polished-lake/spd/clustering" +slurm_job_name_prefix: "spd" +slurm_partition: "h100-reserved" +wandb_project: "spd-cluster" +wandb_entity: "goodfire" +create_git_snapshot: true \ No newline at end of file diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/simplestories_dev.json index 552309465..89cbfde06 100644 --- a/spd/clustering/configs/simplestories_dev.json +++ b/spd/clustering/configs/simplestories_dev.json @@ -9,7 +9,7 @@ "filter_dead_threshold": 0.1, "module_name_filter": null }, - "model_path": "wandb:goodfire/spd/runs/lxs77xye", + "model_path": "wandb:goodfire/spd/runs/rn9klzfs", "task_name": "lm", "distances_method": "jaccard", "n_batches": 1, diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json index fbacff53a..6dd7fb12b 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -10,11 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp1", - "n_batches": 2, "batch_size": 100, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { + "wandb_project": null, + "logging_intervals": { 
"stat": 1, "tensor": 5, "plot": 10, diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/test-simplestories.json index 377eb6af1..891177ab1 100644 --- a/spd/clustering/configs/test-simplestories.json +++ b/spd/clustering/configs/test-simplestories.json @@ -9,13 +9,10 @@ "filter_dead_threshold": 0.9, "module_name_filter": "model.layers.0" }, - "model_path": "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh", - "task_name": "lm", - "n_batches": 1, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", "batch_size": 1, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { + "wandb_project": null, + "logging_intervals": { "stat": 1, "tensor": 2, "plot": 3, diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py new file mode 100644 index 000000000..c514aa69f --- /dev/null +++ b/spd/clustering/dataset.py @@ -0,0 +1,137 @@ +"""Dataset loading utilities for clustering runs. + +Each clustering run loads its own dataset batch, seeded by the run index. +""" + +from typing import Any + +from spd.clustering.consts import BatchTensor +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.experiments.resid_mlp.configs import ResidMLPTaskConfig +from spd.experiments.resid_mlp.models import ResidMLP +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName + + +def load_dataset( + model_path: str, + task_name: TaskName, + batch_size: int, + seed: int, + **kwargs: Any, +) -> BatchTensor: + """Load a single batch for clustering. + + Each run gets its own dataset batch, seeded by idx_in_ensemble. 
+ + Args: + model_path: Path to decomposed model + task_name: Task type + batch_size: Batch size + seed: Random seed for dataset + + Returns: + Single batch of data + """ + match task_name: + case "lm": + return _load_lm_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case "resid_mlp": + return _load_resid_mlp_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case _: + raise ValueError(f"Unsupported task: {task_name}") + + +def _load_lm_batch( + model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None +) -> BatchTensor: + """Load a batch for language model task.""" + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + + assert isinstance(cfg.task_config, LMTaskConfig), ( + f"Expected task_config to be of type LMTaskConfig, but got {type(cfg.task_config) = }" + ) + + try: + pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] + assert pretrained_model_name is not None + except Exception as e: + raise AttributeError("Could not find 'pretrained_model_name' in the SPD Run config") from e + + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + streaming=False, + ), + **(config_kwargs or {}), + } + + dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + seed=seed, # Use run-specific seed + column_name=cfg.task_config.column_name, + **config_kwargs_, + ) + + dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=seed, # Use run-specific seed + ddp_rank=0, + ddp_world_size=1, + ) + + # Get first batch + batch = next(iter(dataloader)) + return batch["input_ids"] + + +def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: + """Load 
a batch for ResidMLP task.""" + from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path) + + assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( + f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" + ) + assert isinstance(component_model.target_model, ResidMLP), ( + f"Expected target_model to be of type ResidMLP, but got {type(component_model.target_model) = }" + ) + + # Create dataset with run-specific seed + dataset = ResidMLPDataset( + n_features=component_model.target_model.config.n_features, + feature_probability=cfg.task_config.feature_probability, + device="cpu", + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=seed, # Use run-specific seed + label_coeffs=None, + data_generation_type=cfg.task_config.data_generation_type, + ) + + # Generate batch + dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) + batch, _ = next(iter(dataloader)) + return batch diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py new file mode 100644 index 000000000..7756877d8 --- /dev/null +++ b/spd/clustering/ensemble_registry.py @@ -0,0 +1,72 @@ +"""Ensemble registry for tracking which clustering runs belong to which pipeline ensemble. + +Uses SQLite to maintain a mapping of (pipeline_run_id, idx, clustering_run_id). 
+""" + +import sqlite3 +from contextlib import contextmanager + +from spd.settings import SPD_CACHE_DIR + +# SQLite database path +_ENSEMBLE_REGISTRY_DB = SPD_CACHE_DIR / "clustering_ensemble_registry.db" + + +@contextmanager +def _get_connection(): + """Context manager for SQLite connection, ensures table exists.""" + _ENSEMBLE_REGISTRY_DB.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(_ENSEMBLE_REGISTRY_DB) + + try: + # Create table if not exists + conn.execute(""" + CREATE TABLE IF NOT EXISTS ensemble_runs ( + pipeline_run_id TEXT NOT NULL, + idx INTEGER NOT NULL, + clustering_run_id TEXT NOT NULL, + PRIMARY KEY (pipeline_run_id, idx) + ) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_pipeline_run_id + ON ensemble_runs (pipeline_run_id) + """) + conn.commit() + + yield conn + finally: + conn.close() + + +def register_clustering_run(pipeline_run_id: str, idx: int, clustering_run_id: str) -> None: + """Register a clustering run as part of a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + idx: Index of this run in the ensemble + clustering_run_id: The individual clustering run ID + """ + with _get_connection() as conn: + conn.execute( + "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", + (pipeline_run_id, idx, clustering_run_id), + ) + conn.commit() + + +def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: + """Get all clustering runs for a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + + Returns: + List of (idx, clustering_run_id) tuples, sorted by idx + """ + with _get_connection() as conn: + cursor = conn.execute( + "SELECT idx, clustering_run_id FROM ensemble_runs WHERE pipeline_run_id = ? 
ORDER BY idx", + (pipeline_run_id,), + ) + return cursor.fetchall() diff --git a/spd/clustering/math/tensor_stats.py b/spd/clustering/math/tensor_stats.py deleted file mode 100644 index 4080b9795..000000000 --- a/spd/clustering/math/tensor_stats.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import Literal - -import torch -from jaxtyping import Float -from torch import Tensor - -StatsKey = Literal[ - "mean", - "std", - "median", - "min", - "max", - "q01", - "q05", - "q10", - "q25", - "q50", - "q75", - "q90", - "q95", - "q99", - "chosen_pair", -] - - -def _flatten_if_needed(x: Tensor) -> Tensor: - """Make x 1D without copy when possible.""" - x_flat: Tensor = x.reshape(-1) - return x_flat - - -def _approx_quantile( - x: Tensor, - qs: Float[Tensor, " n_quantiles"], - *, - max_elems: int = 5_000_000, - generator: torch.Generator | None = None, -) -> Float[Tensor, " n_quantiles"]: - """Approximate quantiles by subsampling if needed, else exact. - - If x.numel() > max_elems, draws a random subset of size max_elems (with replacement) - on the same device as x, then computes torch.quantile once for all qs. 
- """ - x1d: Tensor = _flatten_if_needed(x) - n: int = x1d.numel() - if n == 0: - raise ValueError("Empty tensor.") - if n > max_elems: - # sample with replacement to avoid materializing a giant permutation - g: torch.Generator | None = generator - idx: Tensor = torch.randint(0, n, (max_elems,), device=x1d.device, generator=g) - x_used: Tensor = x1d[idx] - else: - x_used = x1d - # Compute all quantiles in one shot to reuse the sort - q: Tensor = torch.quantile(x_used, qs, interpolation="linear") - return q - - -def _exact_quantile_all_at_once( - x: Tensor, qs: Float[Tensor, " n_quantiles"] -) -> Float[Tensor, " n_quantiles"]: - """Exact quantiles without repeated sorts.""" - x1d: Tensor = _flatten_if_needed(x) - q: Float[Tensor, " n_quantiles"] = torch.quantile(x1d, qs, interpolation="linear") - return q - - -def stats_dict( - data: Tensor, - *, - approx_if_large: bool = True, - max_elems_for_quantile: int = 5_000_000, - rng: torch.Generator | None = None, -) -> dict[StatsKey, float]: - """summary - - Compute common stats plus a set of quantiles. Uses a single quantile() call - for all requested quantiles; optionally switches to an approximate method - by subsampling when the input is very large to avoid RuntimeError. - - # Parameters: - - `data : Tensor` - Input tensor of any shape and dtype convertible to floating for stats. - - `approx_if_large : bool` - If True, use subsampling for quantiles when data is huge. (defaults to True) - - `max_elems_for_quantile : int` - Max elements before triggering approximate mode. (defaults to 5_000_000) - - `rng : torch.Generator | None` - Optional torch generator for reproducible subsampling. - - # Returns: - - `dict[StatsKey, float]` - Mapping from stat name to Python float. 
- - # Modifies: - - None - - # Usage: - - ```python - >>> x = torch.randn(50_000_000, device="cuda") - >>> out = stats_dict(x, approx_if_large=True, max_elems_for_quantile=5_000_000) - >>> out["q95"] - 1.64 - ``` - - # Raises: - - `ValueError` : if `data` is empty - """ - x: Tensor = data - if x.numel() == 0: - raise ValueError("Empty tensor.") - # Work in float for numerics, but keep device - xf: Tensor = x.float() - - # Fast exact ops that do not need the full sort - # std_mean does mean and std in one pass; aminmax does min and max together - std: Tensor - mean: Tensor - std, mean = torch.std_mean(xf) - mn: Tensor - mx: Tensor - mn, mx = torch.aminmax(xf) - - # median is a quantile; we can either reuse below or do .median() directly. - # We will get it from the quantiles call to avoid extra work. - q_values: Float[Tensor, " 9"] = torch.tensor( - [0.01, 0.05, 0.10, 0.25, 0.50, 0.75, 0.90, 0.95, 0.99], - device=xf.device, - dtype=xf.dtype, - ) - qs_all: Float[Tensor, " 9"] - if approx_if_large: - qs_all = _approx_quantile( - xf, - q_values, - max_elems=max_elems_for_quantile, - generator=rng, - ) - else: - qs_all = _exact_quantile_all_at_once(xf, q_values) - - out: dict[StatsKey, float] = { - "mean": float(mean.item()), - "std": float(std.item()), - "median": float(qs_all[4].item()), # median is at index 4 - "min": float(mn.item()), - "max": float(mx.item()), - "q01": float(qs_all[0].item()), - "q05": float(qs_all[1].item()), - "q10": float(qs_all[2].item()), - "q25": float(qs_all[3].item()), - "q50": float(qs_all[4].item()), # median again - "q75": float(qs_all[5].item()), - "q90": float(qs_all[6].item()), - "q95": float(qs_all[7].item()), - "q99": float(qs_all[8].item()), - } - return out diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index 3692e1687..fd982b83f 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -29,8 +29,6 @@ from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import 
MergeHistory -_BATCH_PREFIX_FMT: str = "\033[38;5;208m[{batch_id}]\033[0m" - class LogCallback(Protocol): def __call__( @@ -54,7 +52,6 @@ def merge_iteration( activations: ActivationsTensor, component_labels: ComponentLabels, log_callback: LogCallback | None = None, - batch_id: str = "unk", ) -> MergeHistory: """ Merge iteration with optional logging/plotting callbacks. @@ -63,10 +60,6 @@ def merge_iteration( the same core algorithm logic. """ - # setup - # ================================================== - pbar_prefix: str = _BATCH_PREFIX_FMT.format(batch_id=batch_id) - # compute coactivations # -------------------------------------------------- activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = ( @@ -200,9 +193,7 @@ def merge_iteration( merge_pair_cost: float = float(costs[merge_pair].item()) # Update progress bar - pbar.set_description( - f"{pbar_prefix} k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}" - ) + pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") if log_callback is not None: log_callback( diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py index 39247d0b7..5ba3226ce 100644 --- a/spd/clustering/merge_history.py +++ b/spd/clustering/merge_history.py @@ -214,7 +214,7 @@ def initial_k_groups(self) -> int: return int(self.merges.k_groups[0].item()) @override - def save(self, path: Path, wandb_url: str | None = None) -> None: + def save(self, path: Path) -> None: zf: zipfile.ZipFile with zipfile.ZipFile(path, "w") as zf: # save arrays @@ -234,7 +234,6 @@ def save(self, path: Path, wandb_url: str | None = None) -> None: json.dumps( dict( merge_config=self.merge_config.model_dump(mode="json"), - wandb_url=wandb_url, c_components=self.c_components, n_iters_current=self.n_iters_current, labels=self.labels, diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index b6b8d6ab6..60a5244d6 100644 --- 
a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -1,263 +1,128 @@ -"""Configuration for merge clustering runs that combines merge config with run parameters.""" +"""ClusteringRunConfig""" -import hashlib -import tomllib -import warnings from pathlib import Path -from typing import Any, Literal, Self +from typing import Any, Self -from muutils.misc.numerical import shorten_numerical_to_str from pydantic import Field, PositiveInt, model_validator from spd.base_config import BaseConfig -from spd.clustering.consts import DistancesMethod from spd.clustering.merge_config import MergeConfig -from spd.registry import EXPERIMENT_REGISTRY, ExperimentConfig -from spd.spd_types import TaskName +from spd.registry import EXPERIMENT_REGISTRY +from spd.settings import SPD_CACHE_DIR -# Define interval types and defaults -IntervalKey = Literal["stat", "tensor", "plot", "artifact"] -IntervalsDict = dict[IntervalKey, PositiveInt] -"""Type alias for intervals dictionary +class LoggingIntervals(BaseConfig): + """Intervals in which to log each type of output.""" -- `stat`: logging statistics (e.g., k_groups, merge_pair_cost, mdl_loss) -- `tensor`: logging tensors (e.g., wandb_log_tensor, fraction calculations) -- `plot`: generating plots -- `artifact`: creating artifacts (checkpoints) - -""" - -_DEFAULT_INTERVALS: IntervalsDict = { - "stat": 1, - "tensor": 100, - "plot": 100, - "artifact": 100, -} - - -def toml_read_file_with_none(path: Path, null_sentinel: str = "__NULL__") -> dict[str, Any]: - """Read a TOML file and recursively convert sentinel values to None. - - TOML doesn't support null/None values natively, so we use a sentinel string - that gets converted to None after parsing. 
- - Args: - path: Path to the TOML file - null_sentinel: String value to be converted to None (default: "__NULL__") - - Returns: - Dictionary with sentinel values replaced by None - """ - - def replace_sentinel_recursive(obj: Any) -> Any: - """Recursively replace sentinel values with None.""" - if isinstance(obj, dict): - return {key: replace_sentinel_recursive(value) for key, value in obj.items()} - elif isinstance(obj, list): - return [replace_sentinel_recursive(item) for item in obj] - elif isinstance(obj, str) and obj == null_sentinel: - return None - else: - return obj - - with path.open("rb") as f: - data = tomllib.load(f) - return replace_sentinel_recursive(data) + stat: PositiveInt = Field( + default=1, description="Logging statistics (e.g., k_groups, merge_pair_cost, mdl_loss)" + ) + tensor: PositiveInt = Field( + default=100, description="Logging tensors (e.g., wandb_log_tensor, fraction calculations)" + ) + plot: PositiveInt = Field( + default=100, description="Generating plots (e.g., plot_merge_iteration)" + ) + artifact: PositiveInt = Field( + default=100, description="Creating artifacts (e.g., merge_history)" + ) class ClusteringRunConfig(BaseConfig): - """Configuration for a complete merge clustering run. + """Configuration for a single clustering run. - Extends MergeConfig with parameters for model, dataset, and batch configuration. - CLI parameters (base_path, devices, workers_per_device, dataset_streaming) have defaults but will always be overridden + This config specifies the clustering algorithm parameters and data processing settings. + Deployment concerns (where to save, WandB settings, ensemble configuration) are handled + by ClusteringSubmitConfig. 
""" - merge_config: MergeConfig = Field( - description="Merge configuration", - ) - + # TODO: Handle both wandb strings and local file paths model_path: str = Field( - description="WandB path to the model (format: wandb:entity/project/run_id)", + description="WandB path to the decomposed model (format: wandb:entity/project/run_id)" ) - task_name: TaskName = Field( - description="Task name for the model (must be explicit)", + + batch_size: PositiveInt = Field(..., description="Batch size for processing") + dataset_seed: int = Field(0, description="Seed for dataset generation/loading") + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering", + description="Base directory to save clustering runs", ) - experiment_key: str | None = Field( + ensemble_id: str | None = Field( default=None, - description="Original experiment key if created from spd_exp registry", - ) - n_batches: PositiveInt = Field( - default=10, - description="Number of batches to split the dataset into (ensemble size)", + description="Ensemble identifier for WandB grouping", ) - batch_size: PositiveInt = Field( - default=64, - description="Size of each batch for processing", + idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") + + merge_config: MergeConfig = Field(description="Merge algorithm configuration") + logging_intervals: LoggingIntervals = Field( + default_factory=LoggingIntervals, + description="Logging intervals", ) - distances_method: DistancesMethod = Field( - default="perm_invariant_hamming", - description="Method to use for computing distances between clusterings", + + wandb_project: str | None = Field( + default=None, + description="WandB project name (None to disable WandB logging)", ) + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") dataset_streaming: bool = Field( default=False, description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", ) - # Implementation details - # note that these are *always* overriden by CLI args in `spd/clustering/scripts/main.py`, but we have to have defaults here - # to avoid type issues with pydantic. however, these defaults should match the defaults in the CLI args. - base_path: Path = Field( - default_factory=lambda: Path(".data/clustering/"), - description="Base path for saving clustering outputs", - ) - workers_per_device: int = Field( - default=1, - description="Maximum number of concurrent clustering processes per device", - ) - devices: list[str] = Field( - default_factory=lambda: ["cpu"], - description="Devices to use for clustering", - ) - - # WandB configuration - wandb_enabled: bool = Field( - default=False, - description="Enable WandB logging for clustering runs", - ) - wandb_project: str = Field( - default="spd-cluster", - description="WandB project name for clustering runs", - ) - intervals: dict[IntervalKey, PositiveInt] = Field( - default_factory=lambda: _DEFAULT_INTERVALS.copy(), - description="Intervals for different logging operations", - ) - - @model_validator(mode="after") - def validate_model_path(self) -> Self: - """Validate that model_path is a proper WandB path.""" - if not self.model_path.startswith("wandb:"): - raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}") - - assert self.task_name in TaskName.__args__, ( - f"Invalid task_name: {self.task_name = }, must be in {TaskName.__args__ = }" - ) - return self + # TODO: no way to check this without knowing task + # @model_validator(mode="after") + # def validate_streaming_compatibility(self) -> Self: + # """Ensure dataset_streaming is only enabled for compatible tasks.""" + # if self.dataset_streaming and self.task_name != "lm": + # raise ValueError( + # f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" + # ) + # return self @model_validator(mode="before") - @classmethod - def 
validate_intervals(cls, data: dict[str, Any]) -> dict[str, Any]: - """Ensure all required interval keys are present.""" - - data_intervals: dict[IntervalKey, Any] = data.get("intervals", {}) - # warning if any keys are missing - missing_keys: set[IntervalKey] = set(_DEFAULT_INTERVALS.keys()) - set(data_intervals.keys()) - if missing_keys: - warnings.warn( - f"Missing interval keys in {data_intervals = }: {missing_keys}. Using defaults for those.", - UserWarning, - stacklevel=1, + def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: + experiment_key: str | None = values.get("experiment_key") + if experiment_key: + model_path_given: str | None = values.get("model_path") + model_path_from_experiment: str | None = EXPERIMENT_REGISTRY[ + experiment_key + ].canonical_run + assert model_path_from_experiment is not None, ( + f"Experiment '{experiment_key}' has no canonical_run defined in the EXPERIMENT_REGISTRY" ) + if model_path_given and model_path_given != model_path_from_experiment: + raise ValueError( + f"Both experiment_key '{experiment_key}' and model_path '{model_path_given}' given in config data, but they disagree: {model_path_from_experiment=}" + ) - data["intervals"] = { - **_DEFAULT_INTERVALS, - **data_intervals, - } + values["model_path"] = model_path_from_experiment + del values["experiment_key"] - return data + return values @model_validator(mode="after") - def validate_streaming_compatibility(self) -> Self: - """Ensure dataset_streaming is only enabled for compatible tasks.""" - if self.dataset_streaming and self.task_name != "lm": - raise ValueError( - f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" - ) + def validate_model_path(self) -> Self: + """Validate that model_path is a proper WandB path.""" + if not self.model_path.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}") return self - @model_validator(mode="before") - @classmethod - def 
handle_experiment_key(cls, data: dict[str, Any]) -> dict[str, Any]: - """handle passing experiment key instead of model_path and task_name. - - if we provide an experiment_key, then: - 1. use the `EXPERIMENT_REGISTRY` to fill in model_path and task_name - 2. check it's consistent with model_path and task_name from the file if those are provided - - """ - experiment_key: str | None = data.get("experiment_key") - model_path: str | None = data.get("model_path") - task_name: str | None = data.get("task_name") - if experiment_key is not None: - exp_config: ExperimentConfig = EXPERIMENT_REGISTRY[experiment_key] - - # Enforce consistency if explicit fields present - if model_path is not None: - assert model_path == exp_config.canonical_run, ( - f"Inconsistent model_path for {experiment_key}, version from file ({model_path}) does not match registry ({exp_config.canonical_run})" - ) - if task_name is not None: - assert task_name == exp_config.task_name, ( - f"Inconsistent task_name for {experiment_key}, version from file ({task_name}) does not match registry ({exp_config.task_name})" - ) - - # overwrite in data dict - data["model_path"] = exp_config.canonical_run - data["task_name"] = exp_config.task_name - - return data - @property def wandb_decomp_model(self) -> str: - """Extract the WandB run ID of the source decomposition from the model_path - - Format: wandb:entity/project/run_id or wandb:entity/project/runs/run_id - """ - parts: list[str] = self.model_path.replace("wandb:", "").split("/") + """Extract the WandB run ID of the source decomposition.""" + parts = self.model_path.replace("wandb:", "").split("/") if len(parts) >= 3: - # Handle both formats: with and without 'runs' in path - return parts[-1] if parts[-1] != "runs" else parts[-2] if len(parts) > 3 else parts[-1] - else: - raise ValueError(f"Invalid wandb path format: {self.model_path}") - - @property - def wandb_group(self) -> str: - """Generate WandB group name based on parent model""" - return 
f"model-{self.wandb_decomp_model}" - - @property - def _iters_str(self) -> str: - """Shortened string representation of iterations for run ID""" - if self.merge_config.iters is None: - return "_auto" - return shorten_numerical_to_str(self.merge_config.iters) - - @property - def config_identifier(self) -> str: - """Unique identifier for this specific config on this specific model. - - Format: model_abc123-a0.1-i1k-b64-n10-h_12ab - Allows filtering in WandB for all runs with this exact config and model. - """ - return f"task_{self.task_name}-w_{self.wandb_decomp_model}-a{self.merge_config.alpha:g}-i{self._iters_str}-b{self.batch_size}-n{self.n_batches}-h_{self.stable_hash}" - - @property - def stable_hash(self) -> str: - """Generate a stable hash including all config parameters.""" - return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6] + return parts[-1] if parts[-1] != "runs" else parts[-2] + raise ValueError(f"Invalid wandb path format: {self.model_path}") def model_dump_with_properties(self) -> dict[str, Any]: """Serialize config including computed properties for WandB logging.""" - base_dump: dict[str, Any] = self.model_dump() + base_dump: dict[str, Any] = self.model_dump(mode="json") # Add computed properties base_dump.update( { "wandb_decomp_model": self.wandb_decomp_model, - "wandb_group": self.wandb_group, - "config_identifier": self.config_identifier, - "stable_hash": self.stable_hash, } ) diff --git a/spd/clustering/pipeline/__init__.py b/spd/clustering/pipeline/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/spd/clustering/pipeline/clustering_pipeline.py b/spd/clustering/pipeline/clustering_pipeline.py deleted file mode 100644 index 8c6b72f9d..000000000 --- a/spd/clustering/pipeline/clustering_pipeline.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Orchestration layer - clustering pipeline coordination""" - -import os -from collections.abc import Iterator -from pathlib import Path -from typing import Any - -from 
spd.clustering.merge_run_config import ClusteringRunConfig -from spd.log import logger - -os.environ["WANDB_QUIET"] = "True" - - -def main(config: ClusteringRunConfig) -> None: - """Run the complete clustering pipeline. - - Args: - config: ClusteringRunConfig containing all pipeline parameters - """ - logger.section("setup") - - from spd.clustering.consts import BatchTensor, DistancesArray, DistancesMethod, MergesArray - from spd.clustering.math.merge_distances import compute_distances - from spd.clustering.pipeline.dist_utils import distribute_clustering - from spd.clustering.pipeline.s1_split_dataset import split_dataset - from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save - from spd.clustering.pipeline.s4_compute_distances import create_clustering_report - from spd.clustering.pipeline.storage import ClusteringStorage - - logger.info("Imports complete") - - # Initialize storage - storage: ClusteringStorage = ClusteringStorage( - base_path=config.base_path, run_identifier=config.config_identifier - ) - logger.info(f"Initialized storage at: {storage.run_path}") - - # Save run configuration - storage.save_run_config(config) - logger.info(f"Run record saved to: {storage.run_config_file}") - - # Save config to a path that can be passed to subprocess - config_path: Path = storage.run_path / "config.json" - config_path.write_text(config.model_dump_json(indent=2)) - logger.info(f"Config saved to: {config_path}") - - # Split dataset into batches - logger.info(f"Splitting dataset into {config.n_batches} batches...") - split_dataset_kwargs: dict[str, Any] = dict() - if config.dataset_streaming: - logger.info("Using streaming dataset loading") - split_dataset_kwargs["config_kwargs"] = dict(streaming=True) - # check this here as well as the model validator because we edit `config.dataset_streaming` after init in main() after the CLI args are parsed - # not sure if this is actually a problem though - assert config.task_name == "lm", ( - f"Streaming 
dataset loading only supported for 'lm' task, got '{config.task_name = }'. Remove dataset_streaming=True from config or use a different task." - ) - batches: Iterator[BatchTensor] - dataset_config: dict[str, Any] - batches, dataset_config = split_dataset( - config=config, - **split_dataset_kwargs, - ) - storage.save_batches(batches=batches, config=dataset_config) - batch_paths: list[Path] = storage.get_batch_paths() - n_batch_paths: int = len(batch_paths) - logger.info(f"Dataset split complete. Saved {n_batch_paths} batches to: {storage._batches_dir}") - - # Process batches in parallel via subprocess shell-out - logger.section("computing clusterings") - logger.info( - f"Processing {n_batch_paths} batches with {config.workers_per_device} workers per device on {config.devices}..." - ) - distribute_prefix: str = "\033[92m[spd-cluster]\033[0m" - - from spd.clustering.pipeline.dist_utils import ClusteringBatchResult - - results: list[ClusteringBatchResult] = distribute_clustering( - config_path=config_path, - data_files=batch_paths, - devices=config.devices, - base_path=config.base_path, - run_identifier=config.config_identifier, - workers_per_device=config.workers_per_device, - log_fn=lambda msg: logger.info(f"{distribute_prefix} {msg}"), - log_fn_error=lambda msg: logger.error(f"{distribute_prefix} {msg}"), - ) - logger.info(f"Batch processing complete. 
Processed {len(results)} batches") - - logger.section("computing distances") - - # Normalize and save ensemble - logger.info("Normalizing merge histories across ensemble...") - normalized_merge_array: MergesArray = normalize_and_save(storage=storage) - logger.info( - f"Normalized merge array saved: shape={normalized_merge_array.shape}, dtype={normalized_merge_array.dtype}" - ) - - # Compute distances - distances_method: DistancesMethod = config.distances_method - logger.info(f"Computing distances using method: {distances_method}") - distances: DistancesArray = compute_distances( - normalized_merge_array=normalized_merge_array, - method=distances_method, - ) - storage.save_distances(distances=distances, method=distances_method) - logger.info(f"Distances computed and saved: shape={distances.shape}") - - # Create clustering report - wandb_urls: list[str] = [r["wandb_url"] for r in results if r["wandb_url"] is not None] - logger.info(f"Creating clustering report with {len(wandb_urls)} WandB URLs") - create_clustering_report( - distances=distances, - method=distances_method, - wandb_urls=wandb_urls, - config_identifier=config.config_identifier, - ) - logger.info("Clustering report created successfully") diff --git a/spd/clustering/pipeline/dist_utils.py b/spd/clustering/pipeline/dist_utils.py deleted file mode 100644 index 5f7d8f7fd..000000000 --- a/spd/clustering/pipeline/dist_utils.py +++ /dev/null @@ -1,313 +0,0 @@ -"""Distribution utilities for parallel clustering via subprocess shell-out.""" - -import json -import os -import selectors -import subprocess -from collections.abc import Callable -from dataclasses import dataclass -from pathlib import Path -from typing import IO, TypedDict - -from spd.log import logger -from spd.settings import REPO_ROOT - - -class ClusteringBatchResult(TypedDict): - """Result from clustering a single batch.""" - - hist_save_path: str - wandb_url: str | None - batch_name: str - config_identifier: str - - -# Module-global cache for JSON 
writer in child processes -_JSON_WRITER: IO[str] | None = None - - -@dataclass -class ActiveProcess: - """Tracks an active subprocess and its associated metadata.""" - - proc: subprocess.Popen[bytes] - json_fd: IO[bytes] - dataset_path: Path - device: str - - -def launch_child_with_json_fd(cmd: list[str]) -> tuple[subprocess.Popen[bytes], IO[bytes]]: - """Launch child process with JSON fd via environment variable. - - This allows the child to write structured JSON output to a dedicated file descriptor - while still allowing stdout/stderr to stream normally to the console. - - Args: - cmd: Command and arguments to execute - - Returns: - Tuple of (subprocess handle, read file descriptor for JSON results) - """ - # get the pipes - json_fd_rw: tuple[int, int] = os.pipe() # (read_fd, write_fd) - os.set_inheritable(json_fd_rw[1], True) - os.set_inheritable(json_fd_rw[0], False) - - # Pass the fd number via environment variable - env: dict[str, str] = dict(os.environ) - env["JSON_FD"] = str(json_fd_rw[1]) - - # launch the child process - proc: subprocess.Popen[bytes] = subprocess.Popen( - cmd, - env=env, - stdout=None, # Let stdout stream to console - stderr=None, # Let stderr stream to console - pass_fds=(json_fd_rw[1],), - close_fds=True, - ) - - # In parent process: close the write fd (child has it) and return read fd - os.close(json_fd_rw[1]) - json_r: IO[bytes] = os.fdopen(json_fd_rw[0], "rb", buffering=0) - return proc, json_r - - -def _open_json_fd() -> IO[str]: - """Open file descriptor for JSON output from environment variable. - - Called by child processes to get the fd for emitting structured results. - Caches the writer globally to avoid re-wrapping the same FD. 
- - Returns: - IO[str]: Text-mode writer (utf-8), line-buffered - """ - global _JSON_WRITER - if _JSON_WRITER is None: - fd_num: int = int(os.environ["JSON_FD"]) - # Use utf-8 explicitly; line-buffered - _JSON_WRITER = os.fdopen(fd_num, "w", buffering=1, encoding="utf-8") # pyright: ignore[reportConstantRedefinition] - return _JSON_WRITER - - -def emit_result(obj: ClusteringBatchResult) -> None: - """Emit result JSON via environment fd. - - Called by child processes to return structured results to the parent. - - Args: - obj: Result dictionary to serialize and emit - """ - out: IO[str] = _open_json_fd() - print(json.dumps(obj, separators=(",", ":")), file=out, flush=True) - - -def _read_json_result(json_r: IO[bytes], dataset_path: Path) -> ClusteringBatchResult: - """Read JSON result from file descriptor. - - Args: - json_r: Read file descriptor for JSON data - dataset_path: Path to dataset being processed (for error messages) - - Returns: - Parsed JSON result dictionary - - Raises: - RuntimeError: If no JSON result was received - ValueError: If JSON parsing failed - """ - json_line: bytes = json_r.readline() - if not json_line: - raise RuntimeError(f"No JSON result received from {dataset_path}") - - json_str: str = json_line.decode("utf-8", errors="strict").strip() - try: - result: ClusteringBatchResult = json.loads(json_str) - return result - except json.JSONDecodeError as e: - raise ValueError( - f"Failed to parse JSON result from {dataset_path}: {e}\nJSON string: {json_str}" - ) from e - - -def _collect_one_ready( - active: list[ActiveProcess], - log_fn: Callable[[str], None], -) -> tuple[ClusteringBatchResult, ActiveProcess]: - """Block until any active process has JSON ready, then collect it. - - Uses selectors to wait on multiple FDs simultaneously, avoiding head-of-line blocking. 
- - Args: - active: Currently active processes - log_fn: Logger for info messages - - Returns: - Tuple of (parsed JSON result, the corresponding ActiveProcess) - - Raises: - RuntimeError: If subprocess exits with non-zero code - """ - sel: selectors.BaseSelector = selectors.DefaultSelector() - try: - for ap in active: - sel.register(ap.json_fd, selectors.EVENT_READ, ap) - key: selectors.SelectorKey - key, _mask = sel.select()[0] # select() -> list[(SelectorKey, int)] - ap: ActiveProcess = key.data # type: ignore[assignment] - finally: - sel.close() - - result: ClusteringBatchResult = _read_json_result(ap.json_fd, ap.dataset_path) - rc: int | None = ap.proc.wait() - try: # noqa: SIM105 - ap.json_fd.close() - except Exception: - pass - - if rc != 0: - raise RuntimeError( - f"Subprocess {ap.proc.pid} on {ap.device} exited with code {rc} for dataset {ap.dataset_path}" - ) - - log_fn(f"Process {ap.proc.pid} finished, freeing slot on {ap.device}") - return result, ap - - -def distribute_clustering( - config_path: Path, - data_files: list[Path], - devices: list[str], - base_path: Path, - run_identifier: str, - workers_per_device: int = 1, - log_fn: Callable[[str], None] | None = None, - log_fn_error: Callable[[str], None] | None = None, -) -> list[ClusteringBatchResult]: - """Distribute clustering tasks across multiple devices via subprocess. - - Launches clustering processes using shell-out approach with JSON fd for structured - results. Manages concurrency based on workers_per_device and available devices. 
- - The concurrency model: - - Total concurrency = workers_per_device x len(devices) - - Uses round-robin device assignment starting point - - If target device is full, uses any available device - - If all devices are full, waits for ANY process to finish (whichever is ready first) - - Args: - config_path: Path to clustering configuration file - data_files: List of batch data files to process - devices: List of device strings (e.g., ['cuda:0', 'cuda:1']) - base_path: Base directory for clustering outputs - run_identifier: Unique identifier for this clustering run - workers_per_device: Maximum concurrent workers per device - log_fn: Optional logging function for info messages - log_fn_error: Optional logging function for error messages - - Returns: - List of result dictionaries from each batch processing - - Raises: - ValueError: If devices list is empty - RuntimeError: If subprocess fails or doesn't return results - """ - # setup logger - if log_fn is None: - log_fn = logger.info - if log_fn_error is None: - log_fn_error = lambda msg: logger.error(msg) - - # validate parameters - if workers_per_device < 1: - raise ValueError("workers_per_device must be >= 1") - - n_devices: int = len(devices) - if n_devices == 0: - raise ValueError("devices must be non-empty") - - # Track active processes per device to enforce workers_per_device limit - device_active_counts: dict[str, int] = {device: 0 for device in devices} - active: list[ActiveProcess] = [] - results: list[ClusteringBatchResult] = [] - - n_files: int = len(data_files) - try: - for idx, dataset in enumerate(data_files): - # Find a device with capacity, starting from round-robin position - device_idx = idx % n_devices - - # Check if we need to wait for a device to free up - while all(count >= workers_per_device for count in device_active_counts.values()): - # All devices are at capacity - wait for ANY process to finish - log_fn( - f"All devices at capacity ({workers_per_device} workers each). 
Waiting for any process to finish..." - ) - - # Wait for whichever process is ready first - result_i, finished_ap = _collect_one_ready(active, log_fn) - results.append(result_i) - device_active_counts[finished_ap.device] -= 1 - active.remove(finished_ap) - - # Now find a device with capacity, starting from our round-robin position - for i in range(n_devices): - check_idx = (device_idx + i) % n_devices - if device_active_counts[devices[check_idx]] < workers_per_device: - device_idx = check_idx - break - - device: str = devices[device_idx] - - cmd: list[str] = [ - "uv", - "run", - "python", - str(REPO_ROOT / "spd/clustering/pipeline/s2_clustering.py"), - "--config", - str(config_path), - "--dataset-path", - str(dataset), - "--base-path", - str(base_path), - "--run-identifier", - run_identifier, - "--device", - device, - ] - log_fn("[cmd] " + " ".join(cmd)) - - proc, json_r = launch_child_with_json_fd(cmd) - active_proc = ActiveProcess( - proc=proc, json_fd=json_r, dataset_path=dataset, device=device - ) - active.append(active_proc) - device_active_counts[device] += 1 - log_fn( - f"Started clustering {idx + 1}/{n_files} on {device} (pid={proc.pid}, active on device: {device_active_counts[device]}/{workers_per_device})\n\t{dataset}" - ) - - # Wait for remaining processes - while active: - result_i, finished_ap = _collect_one_ready(active, log_fn) - results.append(result_i) - device_active_counts[finished_ap.device] -= 1 - active.remove(finished_ap) - log_fn(f"Process {finished_ap.proc.pid} finished on {finished_ap.device}") - - except BaseException as e: - # this means we probably got a KeyboardInterrupt, so kill the child processes - log_fn_error(f"An error occurred: {e}") - for active_proc in active: - try: # noqa: SIM105 - active_proc.proc.kill() - except Exception: - pass - try: # noqa: SIM105 - active_proc.json_fd.close() - except Exception: - pass - log_fn_error(f"Killed process {active_proc.proc.pid} due to error") - raise - - return results diff --git 
a/spd/clustering/pipeline/s1_split_dataset.py b/spd/clustering/pipeline/s1_split_dataset.py deleted file mode 100644 index 94ac1a8bf..000000000 --- a/spd/clustering/pipeline/s1_split_dataset.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Loads and splits dataset into batches, returning them as an iterator. -""" - -from collections.abc import Generator, Iterator -from typing import Any - -import torch -from muutils.spinner import SpinnerContext -from torch import Tensor -from torch.utils.data import DataLoader -from tqdm import tqdm - -from spd.clustering.consts import BatchTensor -from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.configs import Config -from spd.data import DatasetConfig, create_data_loader -from spd.experiments.lm.configs import LMTaskConfig -from spd.experiments.resid_mlp.configs import ResidMLPModelConfig, ResidMLPTaskConfig -from spd.experiments.resid_mlp.models import ResidMLP -from spd.models.component_model import ComponentModel, SPDRunInfo - - -def split_dataset( - config: ClusteringRunConfig, - **kwargs: Any, -) -> tuple[Iterator[BatchTensor], dict[str, Any]]: - """Split a dataset into n_batches of batch_size, returning iterator and config""" - ds: Generator[BatchTensor] - ds_config_dict: dict[str, Any] - match config.task_name: - case "lm": - ds, ds_config_dict = _get_dataloader_lm( - model_path=config.model_path, - batch_size=config.batch_size, - **kwargs, - ) - case "resid_mlp": - ds, ds_config_dict = _get_dataloader_resid_mlp( - model_path=config.model_path, - batch_size=config.batch_size, - **kwargs, - ) - case name: - raise ValueError( - f"Unsupported task name '{name}'. Supported tasks are 'lm' and 'resid_mlp'. 
{config.model_path=}, {name=}" - ) - - # Limit iterator to n_batches - def limited_iterator() -> Iterator[BatchTensor]: - batch_idx: int - batch: BatchTensor - for batch_idx, batch in tqdm(enumerate(ds), total=config.n_batches, unit="batch"): - if batch_idx >= config.n_batches: - break - yield batch - - return limited_iterator(), ds_config_dict - - -def _get_dataloader_lm( - model_path: str, - batch_size: int, - config_kwargs: dict[str, Any] | None = None, -) -> tuple[Generator[BatchTensor], dict[str, Any]]: - """split up a SS dataset into n_batches of batch_size, returned the saved paths - - 1. load the config for a SimpleStories SPD Run given by model_path - 2. create a DataLoader for the dataset - 3. iterate over the DataLoader and save each batch to a file - - - """ - with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): - spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) - cfg: Config = spd_run.config - - try: - pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] - assert pretrained_model_name is not None - except Exception as e: - raise AttributeError( - "Could not find 'pretrained_model_name' in the SPD Run config, but called `_get_dataloader_lm`" - ) from e - - assert isinstance(cfg.task_config, LMTaskConfig), ( - f"Expected task_config to be of type LMTaskConfig since using `_get_dataloader_lm`, but got {type(cfg.task_config) = }" - ) - - config_kwargs_: dict[str, Any] = { - **dict( - is_tokenized=False, - streaming=False, - seed=0, - ), - **(config_kwargs or {}), - } - - dataset_config: DatasetConfig = DatasetConfig( - name=cfg.task_config.dataset_name, - hf_tokenizer_path=pretrained_model_name, - split=cfg.task_config.train_data_split, - n_ctx=cfg.task_config.max_seq_len, - column_name=cfg.task_config.column_name, - **config_kwargs_, - ) - - with SpinnerContext(message="getting dataloader..."): - dataloader: DataLoader[dict[str, torch.Tensor]] - dataloader, _tokenizer = 
create_data_loader( - dataset_config=dataset_config, - batch_size=batch_size, - buffer_size=cfg.task_config.buffer_size, - global_seed=cfg.seed, - ddp_rank=0, - ddp_world_size=1, - ) - - return (batch["input_ids"] for batch in dataloader), dataset_config.model_dump(mode="json") - - -def _get_dataloader_resid_mlp( - model_path: str, - batch_size: int, -) -> tuple[Generator[torch.Tensor], dict[str, Any]]: - """Split a ResidMLP dataset into n_batches of batch_size and save the batches.""" - from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset - from spd.utils.data_utils import DatasetGeneratedDataLoader - - with SpinnerContext(message=f"Loading SPD Run Config for '{model_path}'"): - spd_run: SPDRunInfo = SPDRunInfo.from_path(model_path) - # SPD_RUN = SPDRunInfo.from_path(EXPERIMENT_REGISTRY["resid_mlp3"].canonical_run) - component_model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) - cfg: Config = spd_run.config - - with SpinnerContext(message="Creating ResidMLPDataset..."): - assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( - f"Expected task_config to be of type ResidMLPTaskConfig since using `_get_dataloader_resid_mlp`, but got {type(cfg.task_config) = }" - ) - assert isinstance(component_model.target_model, ResidMLP), ( - f"Expected patched_model to be of type ResidMLP since using `_get_dataloader_resid_mlp`, but got {type(component_model.patched_model) = }" - ) - - assert isinstance(component_model.target_model.config, ResidMLPModelConfig), ( - f"Expected patched_model.config to be of type ResidMLPModelConfig since using `_get_dataloader_resid_mlp`, but got {type(component_model.target_model.config) = }" - ) - resid_mlp_dataset_kwargs: dict[str, Any] = dict( - n_features=component_model.target_model.config.n_features, - feature_probability=cfg.task_config.feature_probability, - device="cpu", - calc_labels=False, - label_type=None, - act_fn_name=None, - label_fn_seed=None, - label_coeffs=None, - 
data_generation_type=cfg.task_config.data_generation_type, - ) - dataset: ResidMLPDataset = ResidMLPDataset(**resid_mlp_dataset_kwargs) - - dataloader: DatasetGeneratedDataLoader[tuple[Tensor, Tensor]] = DatasetGeneratedDataLoader( - dataset, batch_size=batch_size, shuffle=False - ) - - return (batch[0] for batch in dataloader), resid_mlp_dataset_kwargs diff --git a/spd/clustering/pipeline/s2_clustering.py b/spd/clustering/pipeline/s2_clustering.py deleted file mode 100644 index bfeeadfbe..000000000 --- a/spd/clustering/pipeline/s2_clustering.py +++ /dev/null @@ -1,409 +0,0 @@ -"""Stage 2: Run clustering on individual batches (CLI script interface).""" - -import argparse -import os -import tempfile -from collections.abc import Callable -from dataclasses import dataclass -from functools import partial -from pathlib import Path - -import matplotlib.pyplot as plt -import torch -import wandb -from jaxtyping import Float, Int -from matplotlib.figure import Figure -from torch import Tensor -from wandb.sdk.wandb_run import Run - -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, -) -from spd.clustering.consts import ( - ActivationsTensor, - BatchTensor, - ClusterCoactivationShaped, - ComponentLabels, -) -from spd.clustering.math.merge_matrix import GroupMerge -from spd.clustering.math.semilog import semilog -from spd.clustering.merge import _BATCH_PREFIX_FMT, merge_iteration -from spd.clustering.merge_history import MergeHistory -from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.dist_utils import emit_result -from spd.clustering.pipeline.storage import ClusteringStorage -from spd.clustering.plotting.activations import plot_activations -from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration -from spd.clustering.wandb_tensor_info import wandb_log_tensor -from spd.log import logger -from spd.models.component_model import 
ComponentModel, SPDRunInfo - -os.environ["WANDB_QUIET"] = "True" - -LogCallback = Callable[ - [ - ClusterCoactivationShaped, - ComponentLabels, - GroupMerge, - ClusterCoactivationShaped, - MergeHistory, - int, - int, - float, - float, - float, - Float[Tensor, " k_groups"], - ], - None, -] - - -@dataclass -class ClusteringResult: - history_save_path: Path - wandb_url: str | None - - -def run_clustering( - config: ClusteringRunConfig, - data_path: Path, - base_path: Path, - run_identifier: str, - device: str, -) -> ClusteringResult: - """Run clustering on a single batch. - - Args: - config: Clustering configuration - data_path: Path to batch data file - base_path: Base directory for storage - run_identifier: Unique identifier for this clustering run - device: Device to run on (e.g., 'cuda:0', 'cpu') - - Returns: - ClusteringResult with save path and optional WandB URL - """ - batch_id: str = data_path.stem - prefix: str = _BATCH_PREFIX_FMT.format(batch_id=batch_id) - - def logger_call(msg: str) -> None: - logger.info(f"{prefix} {msg}") - - logger_call("starting batch") - storage: ClusteringStorage = ClusteringStorage( - base_path=base_path, run_identifier=run_identifier - ) - - run: Run | None = ( - _setup_wandb(batch_id=batch_id, config=config) if config.wandb_enabled else None - ) - logger_call("wandb setup complete") - - this_merge_plots_dir: Path = storage.history_path(batch_id).parent / "plots" - - spd_run: SPDRunInfo = SPDRunInfo.from_path(config.model_path) - logger_call("loaded spd run info") - - model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path).to(device) - logger_call("loaded model") - - batch: BatchTensor = storage.load_batch(data_path).to(device) - logger_call(f"loaded batch {batch_id} with shape {batch.shape}") - - activations_dict: ( - dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] - ) = component_activations( - model=model, - batch=batch, - device=device, - ) - logger_call("computed 
activations") - - processed_activations: ProcessedActivations = process_activations( - activations=activations_dict, - filter_dead_threshold=config.merge_config.filter_dead_threshold, - seq_mode="concat" if config.task_name == "lm" else None, - filter_modules=config.merge_config.filter_modules, - ) - logger_call("processed activations") - - wandb_url: str | None - if run is not None: - wandb_log_tensor( - run=run, - data=processed_activations.activations, - name="processed_activations", - step=0, - single=True, - ) - wandb_url = run.url - else: - wandb_url = None - - # Use original activations for raw plots, but filtered data for concat/coact/histograms - logger_call("plotting") - plot_activations( - processed_activations=processed_activations, - save_dir=this_merge_plots_dir, - n_samples_max=256, # TODO: make this configurable? - wandb_run=run, - ) - logger_call(f"plots saved to {this_merge_plots_dir}") - - logger_call("cleaning up memory") - activations: ActivationsTensor = processed_activations.activations - component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) - del processed_activations # we copied what we needed - del activations_dict # processed already - del model # already did the forward pass - del batch # already did the forward pass - - log_callback: LogCallback | None = ( - partial(_log_callback, run=run, batch_id=batch_id, config=config) - if run is not None - else None - ) - - logger_call("starting merging") - history: MergeHistory = merge_iteration( - merge_config=config.merge_config, - activations=activations, - component_labels=component_labels, - log_callback=log_callback, - batch_id=batch_id, - ) - logger_call("merging complete") - - history_save_path: Path = storage.history_path(batch_id) - - history.save(history_save_path, wandb_url=wandb_url) - - if run is not None: - _log_merge_history_plots_to_wandb(run, history) - _save_merge_history_to_wandb( - run, history_save_path, batch_id, config.config_identifier, 
history - ) - - run.finish() - - logger_call("batch complete") - - return ClusteringResult(history_save_path=history_save_path, wandb_url=wandb_url) - - -def _setup_wandb( - batch_id: str, - config: ClusteringRunConfig, -) -> Run: - run: Run = wandb.init( - project=config.wandb_project, - name=f"{config.config_identifier}-{batch_id}", - group=config.wandb_group, - config=config.model_dump_with_properties(), - tags=[ - "cluster-run", - f"model:{config.wandb_decomp_model}", - f"task:{config.task_name}", - f"batch:{batch_id}", - f"config:{config.config_identifier}", - ], - ) - logger.info( - f"{_BATCH_PREFIX_FMT.format(batch_id=batch_id)} Initialized WandB run: {run.name} in group {config.wandb_group}" - ) - return run - - -def _log_merge_history_plots_to_wandb(run: Run, history: MergeHistory) -> None: - fig_cs: Figure = plot_merge_history_cluster_sizes(history=history) - run.log( - {"plots/merge_history_cluster_sizes": wandb.Image(fig_cs)}, - step=history.n_iters_current, - ) - plt.close(fig_cs) - - -def _save_merge_history_to_wandb( - run: Run, - history_path: Path, - batch_id: str, - config_identifier: str, - history: MergeHistory, -) -> None: - artifact: wandb.Artifact = wandb.Artifact( - name=f"merge_history_{batch_id}", - type="merge_history", - description=f"Merge history for batch {batch_id}", - metadata={ - "batch_name": batch_id, - "config_identifier": config_identifier, - "n_iters_current": history.n_iters_current, - "filename": history_path, - }, - ) - artifact.add_file(str(history_path)) - run.log_artifact(artifact) - - -def _log_callback( - run: Run, - batch_id: str, - current_coact: ClusterCoactivationShaped, - component_labels: ComponentLabels, - current_merge: GroupMerge, - config: ClusteringRunConfig, - costs: ClusterCoactivationShaped, - merge_history: MergeHistory, - iter_idx: int, - k_groups: int, - merge_pair_cost: float, - mdl_loss: float, - mdl_loss_norm: float, - diag_acts: Float[Tensor, " k_groups"], -) -> None: - if iter_idx % 
config.intervals["stat"] == 0: - run.log( - { - "k_groups": int(k_groups), - "merge_pair_cost": merge_pair_cost, - "merge_pair_cost_semilog[1e-3]": semilog(merge_pair_cost, epsilon=1e-3), - "mdl_loss": float(mdl_loss), - "mdl_loss_norm": float(mdl_loss_norm), - }, - step=iter_idx, - ) - - if iter_idx % config.intervals["tensor"] == 0: - group_sizes: Int[Tensor, " k_groups"] = current_merge.components_per_group - - tensor_data: dict[str, Tensor] = { - "coactivation": current_coact, - "costs": costs, - "group_sizes": group_sizes, - "group_activations": diag_acts, - "group_activations_over_sizes": ( - diag_acts / group_sizes.to(device=diag_acts.device).float() - ), - } - - fraction_singleton_groups: float = (group_sizes == 1).float().mean().item() - if fraction_singleton_groups > 0: - tensor_data["group_sizes.log1p"] = torch.log1p(group_sizes.float()) - - fraction_zero_coacts: float = (current_coact == 0).float().mean().item() - if fraction_zero_coacts > 0: - tensor_data["coactivation.log1p"] = torch.log1p(current_coact.float()) - - wandb_log_tensor(run, tensor_data, name="iters", step=iter_idx) - - run.log( - { - "fraction_singleton_groups": float(fraction_singleton_groups), - "num_nonsingleton_groups": int((group_sizes > 1).sum().item()), - "fraction_zero_coacts": float(fraction_zero_coacts), - }, - step=iter_idx, - ) - - if iter_idx > 0 and iter_idx % config.intervals["artifact"] == 0: - with tempfile.NamedTemporaryFile() as tmp_file: - file: Path = Path(tmp_file.name) - merge_history.save(file) - artifact: wandb.Artifact = wandb.Artifact( - name=f"merge_hist_iter.{batch_id}.iter_{iter_idx}", - type="merge_hist_iter", - description=f"Group indices for batch {batch_id} at iteration {iter_idx}", - metadata={ - "batch_name": batch_id, - "iteration": iter_idx, - "config": merge_history.merge_config.model_dump(mode="json"), - # TODO: had to remove identifiers on config due to MergeConfig <--> ClusteringRunConfig (formerly MergeRunConfig) split - # "config_identifier": 
merge_history.merge_config.config_identifier, - }, - ) - artifact.add_file(str(file)) - run.log_artifact(artifact) - - if iter_idx % config.intervals["plot"] == 0: - fig: Figure = plot_merge_iteration( - current_merge=current_merge, - current_coact=current_coact, - costs=costs, - iteration=iter_idx, - component_labels=component_labels, - show=False, - ) - run.log({"plots/merges": wandb.Image(fig)}, step=iter_idx) - plt.close(fig) - - -def cli() -> None: - """Command-line interface for running clustering on a single batch.""" - parser: argparse.ArgumentParser = argparse.ArgumentParser( - description="Run clustering on a single batch of data" - ) - parser.add_argument( - "--config", - "-c", - type=Path, - required=True, - help="Path to the clustering run config JSON/YAML file", - ) - parser.add_argument( - "--dataset-path", - "-d", - type=Path, - required=True, - help="Path to the dataset batch file (e.g., batch_00.npz)", - ) - parser.add_argument( - "--base-path", - "-b", - type=Path, - required=True, - help="Base directory for clustering outputs", - ) - parser.add_argument( - "--run-identifier", - "-r", - type=str, - required=True, - help="Unique identifier for this clustering run", - ) - parser.add_argument( - "--device", - "-D", - type=str, - default="cuda" if torch.cuda.is_available() else "cpu", - help="Device to run on (e.g., 'cuda:0', 'cpu')", - ) - - args: argparse.Namespace = parser.parse_args() - - # Load config - config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) - - # Run clustering - result: ClusteringResult = run_clustering( - config=config, - data_path=args.dataset_path, - base_path=args.base_path, - run_identifier=args.run_identifier, - device=args.device, - ) - - # Emit structured result for parent process - emit_result( - { - "hist_save_path": str(result.history_save_path), - "wandb_url": result.wandb_url, - "batch_name": args.dataset_path.stem, - "config_identifier": config.config_identifier, - } - ) - - -if __name__ == 
"__main__": - cli() diff --git a/spd/clustering/pipeline/s3_normalize_histories.py b/spd/clustering/pipeline/s3_normalize_histories.py deleted file mode 100644 index b09733d44..000000000 --- a/spd/clustering/pipeline/s3_normalize_histories.py +++ /dev/null @@ -1,32 +0,0 @@ -from pathlib import Path -from typing import Any - -from spd.clustering.consts import MergesArray -from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.pipeline.storage import ClusteringStorage, NormalizedEnsemble -from spd.log import logger - - -def normalize_and_save(storage: ClusteringStorage) -> MergesArray: - """Load merge histories from storage, normalize, and save ensemble""" - # load - histories: list[MergeHistory] = storage.load_histories() - ensemble: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories) - - # normalize - normalized_merge_array: MergesArray - normalized_merge_meta: dict[str, Any] - normalized_merge_array, normalized_merge_meta = ensemble.normalized() - - # save - ensemble_data: NormalizedEnsemble = NormalizedEnsemble( - merge_array=normalized_merge_array, - metadata=normalized_merge_meta, - ) - metadata_path: Path - array_path: Path - metadata_path, array_path = storage.save_ensemble(ensemble_data) - logger.info(f"metadata saved to {metadata_path}") - logger.info(f"merge array saved to {array_path}") - - return normalized_merge_array diff --git a/spd/clustering/pipeline/s4_compute_distances.py b/spd/clustering/pipeline/s4_compute_distances.py deleted file mode 100644 index 9e1de4974..000000000 --- a/spd/clustering/pipeline/s4_compute_distances.py +++ /dev/null @@ -1,92 +0,0 @@ -import wandb -from matplotlib import pyplot as plt -from matplotlib.axes import Axes - -from spd.clustering.consts import ( - DistancesArray, - DistancesMethod, -) -from spd.clustering.plotting.merge import plot_dists_distribution -from spd.log import logger - - -def create_clustering_report( - distances: DistancesArray, - method: 
DistancesMethod, - wandb_urls: list[str], - config_identifier: str, -) -> None: - """Create a WandB report with clustering results and distances plot""" - - # Extract entity/project from first URL for the report - first_url: str = wandb_urls[0] - entity: str - project: str - - if first_url.startswith("wandb:"): - run_path_parts: list[str] = first_url.replace("wandb:", "").split("/") - entity, project = run_path_parts[0], run_path_parts[1] - else: - # Parse full URL - parts: list[str] = first_url.split("/") - if "runs" in parts: - run_idx: int = parts.index("runs") + 1 - entity, project = parts[run_idx - 3], parts[run_idx - 2] - else: - logger.warning(f"Could not parse WandB URL: {first_url}") - return - - # Initialize WandB run for the summary report - with wandb.init( - project=project, - entity=entity, - name=f"clustering-summary-{config_identifier}", - tags=["clustering-summary", f"config:{config_identifier}", f"method:{method}"], - job_type="clustering-analysis", - config=dict(config_identifier=config_identifier, method=method), - ) as run: - # Create and log the distances distribution plot - ax: Axes = plot_dists_distribution( - distances=distances, mode="points", label=f"{method} distances" - ) - plt.title(f"Distance Distribution ({method})") - - # Only add legend if there are labeled artists - handles, _labels = ax.get_legend_handles_labels() - if handles: - plt.legend() - - # Get the figure from the axes - fig = ax.get_figure(root=True) - assert fig is not None - - # Log the plot - run.log( - { - f"distances/{method}": wandb.Image(fig), - "clustering/config_identifier": config_identifier, - } - ) - - plt.close(fig) - - # Log metadata about the batch runs - run.log( - { - "batch_runs/urls": wandb_urls, - } - ) - - # Create a summary table of run information - run_ids: list[str] = [] - for url in wandb_urls: - if "runs/" in url: - run_id: str = url.split("runs/")[-1] - run_ids.append(run_id) - - if run_ids: - run.log({"batch_runs/run_ids": run_ids}) - - 
logger.info( - f"Created wandb clustering summary report with {len(wandb_urls)} batch runs from config {config_identifier}:\n{run.url}/overview" - ) diff --git a/spd/clustering/pipeline/storage.py b/spd/clustering/pipeline/storage.py deleted file mode 100644 index febe9bcc8..000000000 --- a/spd/clustering/pipeline/storage.py +++ /dev/null @@ -1,300 +0,0 @@ -"""Storage layer for clustering pipeline - handles all persistence operations.""" - -import json -from collections.abc import Iterator -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import numpy as np -import torch -from torch import Tensor - -from spd.clustering.consts import BatchTensor, DistancesArray, DistancesMethod, MergesArray -from spd.clustering.merge_run_config import ClusteringRunConfig - -if TYPE_CHECKING: - from spd.clustering.merge_history import MergeHistory - - -@dataclass -class DatasetBatches: - """Container for dataset batches and their configuration.""" - - batches: list[Tensor] - config: dict[str, Any] - - -@dataclass -class NormalizedEnsemble: - """Container for normalized merge array and metadata.""" - - merge_array: MergesArray - metadata: dict[str, Any] - - -def _write_text_to_path_and_return(path: Path, data: str) -> Path: - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text(data) - return path - - -class ClusteringStorage: - """Handles all file I/O operations for the clustering pipeline. - - This class provides a clean separation between data transformations and persistence, - making the pipeline more testable and flexible. - - Filesystem structure: - / - └── / # Optional run-specific subdirectory - ├── run_config.json # Run configuration parameters - ├── dataset/ # Dataset and batch storage - │ ├── dataset_config.json # Dataset configuration metadata - │ └── batches/ # Individual data batches - │ ├── batch_00.npz # Batch 0 (input_ids array) - │ ├── batch_01.npz # Batch 1 - │ └── ... 
- ├── merge_histories/ # Merge history per batch - │ ├── data_/ # Per-batch history directory - │ │ └── merge_history.zip # Compressed merge history - │ └── ... - ├── ensemble/ # Normalized ensemble results - │ ├── ensemble_meta.json # Ensemble metadata - │ └── ensemble_merge_array.npz # Normalized merge array - └── distances/ # Distance matrices - ├── distances..npz # Distance array for each method - └── ... - """ - - # Directory structure constants - _DATASET_DIR: str = "dataset" - _BATCHES_DIR: str = "batches" - _HISTORIES_DIR: str = "merge_histories" - _ENSEMBLE_DIR: str = "ensemble" - _DISTANCES_DIR: str = "distances" - _DASHBOARD_DIR: str = "dashboard" - - # File naming constants - _RUN_CONFIG_FILE: str = "run_config.json" - _DATASET_CONFIG_FILE: str = "dataset_config.json" - _ENSEMBLE_META_FILE: str = "ensemble_meta.json" - _ENSEMBLE_ARRAY_FILE: str = "ensemble_merge_array.npz" - _BATCH_FILE_FMT: str = "batch_{batch_idx:02d}.npz" - _HISTORY_FILE_FMT: str = "{batch_id}" - _MERGE_HISTORY_FILE: str = "merge_history.zip" - _DISTANCES_FILE_FMT: str = "distances.{method}.npz" - _MODEL_INFO_FILE: str = "model_info.json" - _MAX_ACTIVATIONS_FILE_FMT: str = "max_activations_i{iteration}_n{n_samples}.json" - - def __init__(self, base_path: Path, run_identifier: str | None = None): - """Initialize storage with base path and optional run identifier. 
- - Args: - base_path: Root directory for all storage operations - run_identifier: Optional identifier to create a subdirectory for this run - """ - self._base_path: Path = base_path - if run_identifier: - self._run_path = base_path / run_identifier - else: - self._run_path = base_path - - # Ensure base directory exists - self._run_path.mkdir(parents=True, exist_ok=True) - - # directories - - # make base and run path properties so we don't accidentally modify them - @property - def base_path(self) -> Path: - return self._base_path - - @property - def run_path(self) -> Path: - return self._run_path - - @property - def _dataset_dir(self) -> Path: - return self.run_path / self._DATASET_DIR - - # directories themselves private, use the storage/read methods to interact with them - @property - def _batches_dir(self) -> Path: - return self._dataset_dir / self._BATCHES_DIR - - @property - def _histories_dir(self) -> Path: - return self.run_path / self._HISTORIES_DIR - - @property - def _ensemble_dir(self) -> Path: - return self.run_path / self._ENSEMBLE_DIR - - @property - def _distances_dir(self) -> Path: - return self.run_path / self._DISTANCES_DIR - - @property - def _dashboard_dir(self) -> Path: - return self.run_path / self._DASHBOARD_DIR - - # files - @property - def run_config_file(self) -> Path: - return self.run_path / self._RUN_CONFIG_FILE - - @property - def dataset_config_file(self) -> Path: - return self._dataset_dir / self._DATASET_CONFIG_FILE - - @property - def ensemble_meta_file(self) -> Path: - return self._ensemble_dir / self._ENSEMBLE_META_FILE - - @property - def ensemble_array_file(self) -> Path: - return self._ensemble_dir / self._ENSEMBLE_ARRAY_FILE - - @property - def model_info_file(self) -> Path: - return self.run_path / self._MODEL_INFO_FILE - - @property - def dashboard_model_info_file(self) -> Path: - return self._dashboard_dir / self._MODEL_INFO_FILE - - # dynamic - - def batch_path(self, batch_idx: int) -> Path: - return self._batches_dir / 
self._BATCH_FILE_FMT.format(batch_idx=batch_idx) - - def history_path(self, batch_id: str) -> Path: - return ( - self._histories_dir - / self._HISTORY_FILE_FMT.format(batch_id=batch_id) - / self._MERGE_HISTORY_FILE - ) - - def max_activations_path(self, iteration: int, n_samples: int) -> Path: - return self._dashboard_dir / self._MAX_ACTIVATIONS_FILE_FMT.format( - iteration=iteration, n_samples=n_samples - ) - - # Batch storage methods - - def save_dataset_config(self, config: dict[str, Any]) -> Path: - return _write_text_to_path_and_return( - self.dataset_config_file, json.dumps(config, indent=2) - ) - - def save_batch(self, batch: BatchTensor, batch_idx: int) -> Path: - batch_path: Path = self.batch_path(batch_idx) - batch_path.parent.mkdir(parents=True, exist_ok=True) - - np.savez_compressed(batch_path, input_ids=batch.cpu().numpy()) - return batch_path - - def save_batches(self, batches: Iterator[BatchTensor], config: dict[str, Any]) -> list[Path]: - paths: list[Path] = [] - - self.save_dataset_config(config) - - for idx, batch in enumerate(batches): - path: Path = self.save_batch(batch, idx) - paths.append(path) - - return paths - - def load_batch(self, batch_path: Path) -> BatchTensor: - data: dict[str, np.ndarray] = np.load(batch_path) - return torch.tensor(data["input_ids"]) - - def get_batch_paths(self) -> list[Path]: - return sorted(self._batches_dir.glob("batch_*.npz")) - - # History storage methods - - def save_history(self, history: "MergeHistory", batch_id: str) -> Path: - history_path: Path = self.history_path(batch_id) - history_path.parent.mkdir(parents=True, exist_ok=True) - history.save(history_path) - return history_path - - def load_history(self, batch_id: str) -> "MergeHistory": - # Import only at runtime to avoid circular dependencies - from spd.clustering.merge_history import MergeHistory - - return MergeHistory.read(self.history_path(batch_id)) - - def get_history_paths(self) -> list[Path]: - return 
sorted(self._histories_dir.glob(f"*/{self._MERGE_HISTORY_FILE}")) - - def load_histories(self) -> list["MergeHistory"]: - # Import only at runtime to avoid circular dependencies - from spd.clustering.merge_history import MergeHistory - - return [MergeHistory.read(path) for path in self.get_history_paths()] - - # Ensemble related storage methods - - def save_ensemble(self, ensemble: NormalizedEnsemble) -> tuple[Path, Path]: - """Save normalized ensemble data""" - self._ensemble_dir.mkdir(parents=True, exist_ok=True) - - # Save metadata - metadata_path: Path = self.ensemble_meta_file - metadata_path.write_text(json.dumps(ensemble.metadata, indent=2)) - - # Save merge array - array_path: Path = self.ensemble_array_file - np.savez_compressed(array_path, merges=ensemble.merge_array) - - return metadata_path, array_path - - def save_distances(self, distances: DistancesArray, method: DistancesMethod) -> Path: - self._distances_dir.mkdir(parents=True, exist_ok=True) - - distances_path: Path = self._distances_dir / self._DISTANCES_FILE_FMT.format(method=method) - np.savez_compressed(distances_path, distances=distances) - return distances_path - - def load_distances(self, method: DistancesMethod) -> DistancesArray: - distances_path: Path = self._distances_dir / self._DISTANCES_FILE_FMT.format(method=method) - data: dict[str, np.ndarray] = np.load(distances_path) - return data["distances"] - - def save_run_config(self, config: ClusteringRunConfig) -> Path: - return _write_text_to_path_and_return( - self.run_config_file, config.model_dump_json(indent=2) - ) - - def load_run_config(self) -> ClusteringRunConfig: - return ClusteringRunConfig.from_file(self.run_config_file) - - # Dashboard storage methods - - def save_max_activations( - self, data: dict[int, dict[str, list[dict[str, Any]]]], iteration: int, n_samples: int - ) -> Path: - """Save max activations data to dashboard directory.""" - max_act_path: Path = self.max_activations_path(iteration, n_samples) - return 
_write_text_to_path_and_return(max_act_path, json.dumps(data, indent=2)) - - def save_model_info(self, model_info: dict[str, Any]) -> Path: - """Save model info to run directory.""" - return _write_text_to_path_and_return( - self.model_info_file, json.dumps(model_info, indent=2) - ) - - def save_model_info_to_dashboard(self, model_info: dict[str, Any]) -> Path: - """Save or copy model info to dashboard directory.""" - return _write_text_to_path_and_return( - self.dashboard_model_info_file, json.dumps(model_info, indent=2) - ) - - def load_model_info(self) -> dict[str, Any] | None: - """Load model info from run directory if it exists.""" - if self.model_info_file.exists(): - return json.loads(self.model_info_file.read_text()) - return None diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py index c2c3c6bd1..e7f02ad3b 100644 --- a/spd/clustering/plotting/activations.py +++ b/spd/clustering/plotting/activations.py @@ -2,7 +2,6 @@ from collections.abc import Sequence from pathlib import Path -from typing import Literal import matplotlib as mpl import matplotlib.pyplot as plt @@ -20,7 +19,7 @@ def plot_activations( processed_activations: ProcessedActivations, - save_dir: Path, + save_dir: Path | None, n_samples_max: int, figure_prefix: str = "activations", figsize_raw: tuple[int, int] = (12, 4), @@ -30,7 +29,6 @@ def plot_activations( hist_bins: int = 100, do_sorted_samples: bool = False, wandb_run: wandb.sdk.wandb_run.Run | None = None, - save_fmt: Literal["pdf", "png", "svg"] = "pdf", ) -> None: """Plot activation visualizations including raw, concatenated, sorted, and coactivations. 
@@ -39,15 +37,16 @@ def plot_activations( act_concat: Concatenated activations tensor coact: Coactivation matrix labels: Component labels - save_dir: The directory to save the plots to - figure_prefix: Prefix for figure filenames + save_dir: The directory to save the plots to (None to skip saving to disk) + figure_prefix: Prefix for PDF filenames figsize_raw: Figure size for raw activations figsize_concat: Figure size for concatenated activations figsize_coact: Figure size for coactivations hist_scales: Tuple of (x_scale, y_scale) where each is "lin" or "log" hist_bins: Number of bins for histograms """ - save_dir.mkdir(parents=True, exist_ok=True) + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=True) act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw act_concat: ActivationsTensor = processed_activations.activations @@ -82,8 +81,9 @@ def plot_activations( axs_act[i].set_ylabel(f"components\n{key}") axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") - fig1_fname = save_dir / f"{figure_prefix}_raw.{save_fmt}" - _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig1_fname = save_dir / f"{figure_prefix}_raw.pdf" + _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: @@ -105,8 +105,9 @@ def plot_activations( plt.colorbar(im2) - fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.{save_fmt}" - fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.pdf" + fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: @@ -174,8 +175,9 @@ def plot_activations( plt.colorbar(im3) - fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.{save_fmt}" - fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig3_fname: Path = 
save_dir / f"{figure_prefix}_concatenated_sorted.pdf" + fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: @@ -198,8 +200,9 @@ def plot_activations( plt.colorbar(im4) - fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.{save_fmt}" - fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.pdf" + fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: @@ -212,7 +215,7 @@ def plot_activations( fig4_log: plt.Figure ax4_log: plt.Axes fig4_log, ax4_log = plt.subplots(figsize=figsize_coact) - # assert np.all(coact_data >= 0) # TODO: why does this fail? + # assert np.all(coact_data >= 0) # TODO: why are coacts negative? :/ coact_log_data: np.ndarray = np.log10(coact_data + 1e-6 + coact_data.min()) im4_log = ax4_log.matshow( coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max() @@ -222,8 +225,9 @@ def plot_activations( add_component_labeling(ax4_log, labels, axis="x") add_component_labeling(ax4_log, labels, axis="y") plt.colorbar(im4_log) - fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.{save_fmt}" - fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.pdf" + fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: @@ -317,8 +321,9 @@ def plot_activations( plt.tight_layout() - fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.{save_fmt}" - fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) + if save_dir is not None: + fig5_fname: Path = save_dir / f"{figure_prefix}_histograms.pdf" + fig5.savefig(fig5_fname, bbox_inches="tight", dpi=300) # Log to WandB if available if wandb_run is not None: diff --git 
a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py new file mode 100644 index 000000000..6ee277759 --- /dev/null +++ b/spd/clustering/scripts/calc_distances.py @@ -0,0 +1,120 @@ +"""Calculate distances between clustering runs in an ensemble. + +Output structure: + SPD_CACHE_DIR/ensemble/{pipeline_run_id}/ + ├── pipeline_config.yaml # Created by run_pipeline.py + ├── ensemble_meta.json # Ensemble metadata + ├── ensemble_merge_array.npz # Normalized merge array + ├── distances_{method}.npz # Distance array for each method + └── plots/ + └── distances_{method}.png # Distance distribution plot +""" + +import argparse +import json + +import numpy as np +from matplotlib import pyplot as plt +from matplotlib.axes import Axes + +from spd.clustering.consts import DistancesArray, DistancesMethod +from spd.clustering.ensemble_registry import get_clustering_runs +from spd.clustering.math.merge_distances import compute_distances +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.plotting.merge import plot_dists_distribution +from spd.log import logger +from spd.settings import SPD_CACHE_DIR + + +def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: + """Calculate distances between clustering runs in an ensemble. 
+ + Args: + pipeline_run_id: Pipeline run ID to query from registry + distances_method: Method for calculating distances + """ + logger.info(f"Calculating distances for pipeline run: {pipeline_run_id}") + + # Query registry for clustering runs + clustering_runs = get_clustering_runs(pipeline_run_id) + if not clustering_runs: + raise ValueError(f"No clustering runs found for pipeline {pipeline_run_id}") + + logger.info(f"Found {len(clustering_runs)} clustering runs") + + # Load histories from individual clustering run directories + histories: list[MergeHistory] = [] + for idx, clustering_run_id in clustering_runs: + history_path = SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" + if not history_path.exists(): + raise FileNotFoundError( + f"History not found for run {clustering_run_id}: {history_path}" + ) + histories.append(MergeHistory.read(history_path)) + logger.info(f"Loaded history for run {idx}: {clustering_run_id}") + + # Compute normalized ensemble + ensemble: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories) + merge_array, merge_meta = ensemble.normalized() + + # Get pipeline output directory + pipeline_dir = SPD_CACHE_DIR / "ensemble" / pipeline_run_id + + # Save ensemble metadata and merge array + ensemble_meta_path = pipeline_dir / "ensemble_meta.json" + ensemble_meta_path.write_text(json.dumps(merge_meta, indent=2)) + logger.info(f"Saved ensemble metadata to {ensemble_meta_path}") + + ensemble_array_path = pipeline_dir / "ensemble_merge_array.npz" + np.savez_compressed(ensemble_array_path, merge_array=merge_array) + logger.info(f"Saved ensemble merge array to {ensemble_array_path}") + + # Compute distances + logger.info(f"Computing distances using method: {distances_method}") + distances: DistancesArray = compute_distances( + normalized_merge_array=merge_array, + method=distances_method, + ) + + distances_path = pipeline_dir / f"distances_{distances_method}.npz" + np.savez_compressed(distances_path, distances=distances) + 
logger.info(f"Distances computed and saved: shape={distances.shape}, path={distances_path}") + + # Create and save distances distribution plot + ax: Axes = plot_dists_distribution( + distances=distances, mode="points", label=f"{distances_method} distances" + ) + plt.title(f"Distance Distribution ({distances_method})") + + # Only add legend if there are labeled artists + handles, _labels = ax.get_legend_handles_labels() + if handles: + plt.legend() + + plots_dir = pipeline_dir / "plots" + plots_dir.mkdir(parents=True, exist_ok=True) + fig_path = plots_dir / f"distances_{distances_method}.png" + plt.savefig(fig_path) + plt.close() + logger.info(f"Saved distances distribution plot to {fig_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Calculate distances between clustering runs") + parser.add_argument( + "--pipeline-run-id", + type=str, + required=True, + help="Pipeline run ID to query from registry", + ) + parser.add_argument( + "--distances-method", + choices=["perm_invariant_hamming", "jaccard"], + default="perm_invariant_hamming", + help="Method for calculating distances", + ) + args = parser.parse_args() + main( + pipeline_run_id=args.pipeline_run_id, + distances_method=args.distances_method, + ) diff --git a/spd/clustering/scripts/main.py b/spd/clustering/scripts/main.py deleted file mode 100644 index 65e224f5e..000000000 --- a/spd/clustering/scripts/main.py +++ /dev/null @@ -1,92 +0,0 @@ -import argparse -from pathlib import Path - -from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.clustering_pipeline import main -from spd.log import logger -from spd.settings import REPO_ROOT - - -def cli() -> None: - """Command-line interface for clustering.""" - - logger.set_format("console", style="terse") - - parser: argparse.ArgumentParser = argparse.ArgumentParser( - description="Run clustering on a dataset using clean architecture" - ) - parser.add_argument( - "--config", - "-c", - 
type=Path, - required=True, - help="Path to the merge run config JSON/YAML/TOML file", - ) - parser.add_argument( - "--base-path", - "-p", - type=Path, - default=REPO_ROOT / ".data/clustering/", - help="Base path for saving clustering outputs", - ) - parser.add_argument( - "--devices", - "-d", - type=str, - default=None, - help="Comma-separated list of devices to use for clustering (e.g., 'cuda:0,cuda:1')", - ) - parser.add_argument( - "--workers-per-device", - "-x", - type=int, - default=1, - help="Maximum number of concurrent clustering processes per device (default: 1)", - ) - parser.add_argument( - "--dataset-streaming", - action="store_true", - help="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", - ) - args: argparse.Namespace = parser.parse_args() - - logger.info("Starting clustering pipeline") - - # Parse devices - devices: list[str] - if args.devices is None: - import torch - - devices = ["cuda" if torch.cuda.is_available() else "cpu"] - logger.info(f"No devices specified, auto-detected: {devices}") - else: - devices = args.devices.split(",") - logger.info(f"Using specified devices: {devices}") - - # Load and augment config - # Note that the defaults for args here always override the default values in `RunConfig` itself, - # but we must have those defaults to avoid type issues - logger.info(f"Loading config from {args.config}") - config: ClusteringRunConfig = ClusteringRunConfig.from_file(args.config) - # Use model_copy to update frozen fields - config = config.model_copy( - update={ - "base_path": args.base_path, - "devices": devices, - "workers_per_device": args.workers_per_device, - "dataset_streaming": args.dataset_streaming, - } - ) - - logger.info(f"Configuration loaded: {config.config_identifier}") - logger.info(f"Base path: {config.base_path}") - logger.info(f"{config.workers_per_device = }, {config.devices = }, {config.n_batches = }") - - # Run - main(config=config) - - 
logger.info("Clustering pipeline completed successfully") - - -if __name__ == "__main__": - cli() diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py new file mode 100644 index 000000000..7c614407a --- /dev/null +++ b/spd/clustering/scripts/run_clustering.py @@ -0,0 +1,435 @@ +"""Perform a single clustering run. + +This can be run as a standalone script, or called via `spd-cluster` (i.e. clustering/scripts/run_pipeline.py). +If called via spd-cluster, the ensemble-key is passed in to identify the run within the pipeline ensemble. + +Output structure: + / # from execution stamp (run_type="cluster") + ├── clustering_run_config.json + └── history.npz +""" + +import argparse +import gc +import os +import tempfile +from collections.abc import Callable +from functools import partial +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import torch +import wandb +from jaxtyping import Float, Int +from matplotlib.figure import Figure +from torch import Tensor +from wandb.sdk.wandb_run import Run + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.consts import ( + ActivationsTensor, + BatchTensor, + ClusterCoactivationShaped, + ComponentLabels, +) +from spd.clustering.dataset import load_dataset +from spd.clustering.ensemble_registry import _ENSEMBLE_REGISTRY_DB, register_clustering_run +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.math.semilog import semilog +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_history import MergeHistory +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration +from spd.clustering.storage import StorageBase +from spd.clustering.wandb_tensor_info import 
wandb_log_tensor +from spd.log import logger +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName +from spd.utils.distributed_utils import get_device +from spd.utils.general_utils import replace_pydantic_model +from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str + +os.environ["WANDB_QUIET"] = "true" + + +class ClusteringRunStorage(StorageBase): + """Storage paths for a single clustering run. + + All paths are relative to ExecutionStamp.out_dir. + """ + + # Relative path constants + _CONFIG = "clustering_run_config.json" + _HISTORY = "history.npz" + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + super().__init__(execution_stamp) + self.config_path: Path = self.base_dir / self._CONFIG + self.history_path: Path = self.base_dir / self._HISTORY + + +LogCallback = Callable[ + [ + ClusterCoactivationShaped, + ComponentLabels, + GroupMerge, + ClusterCoactivationShaped, + MergeHistory, + int, + int, + float, + float, + float, + Float[Tensor, " k_groups"], + ], + None, +] + + +def _log_merge_history_plots(run: Run, history: MergeHistory) -> None: + """Log merge history plots to WandB.""" + fig_cs: Figure = plot_merge_history_cluster_sizes(history=history) + run.log( + {"plots/merge_history_cluster_sizes": wandb.Image(fig_cs)}, + step=history.n_iters_current, + ) + plt.close(fig_cs) + + +def _save_merge_history_artifact( + run: Run, + history_path: Path, + history: MergeHistory, +) -> None: + """Save merge history as WandB artifact.""" + artifact: wandb.Artifact = wandb.Artifact( + name="merge_history", + type="merge_history", + description="Merge history", + metadata={"n_iters_current": history.n_iters_current, "filename": str(history_path)}, + ) + artifact.add_file(str(history_path)) + run.log_artifact(artifact) + + +def _log_callback( + run: Run, + run_config: ClusteringRunConfig, + current_coact: ClusterCoactivationShaped, + component_labels: ComponentLabels, + 
current_merge: GroupMerge, + costs: ClusterCoactivationShaped, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: Float[Tensor, " k_groups"], +) -> None: + """Callback for logging during merge iteration.""" + if iter_idx % run_config.logging_intervals.stat == 0: + run.log( + { + "k_groups": int(k_groups), + "merge_pair_cost": merge_pair_cost, + "merge_pair_cost_semilog[1e-3]": semilog(merge_pair_cost, epsilon=1e-3), + "mdl_loss": float(mdl_loss), + "mdl_loss_norm": float(mdl_loss_norm), + }, + step=iter_idx, + ) + + if iter_idx % run_config.logging_intervals.tensor == 0: + group_sizes: Int[Tensor, " k_groups"] = current_merge.components_per_group + + tensor_data: dict[str, Tensor] = { + "coactivation": current_coact, + "costs": costs, + "group_sizes": group_sizes, + "group_activations": diag_acts, + "group_activations_over_sizes": ( + diag_acts / group_sizes.to(device=diag_acts.device).float() + ), + } + + fraction_singleton_groups: float = (group_sizes == 1).float().mean().item() + if fraction_singleton_groups > 0: + tensor_data["group_sizes.log1p"] = torch.log1p(group_sizes.float()) + + fraction_zero_coacts: float = (current_coact == 0).float().mean().item() + if fraction_zero_coacts > 0: + tensor_data["coactivation.log1p"] = torch.log1p(current_coact.float()) + + wandb_log_tensor(run, tensor_data, name="iters", step=iter_idx) + + run.log( + { + "fraction_singleton_groups": float(fraction_singleton_groups), + "num_nonsingleton_groups": int((group_sizes > 1).sum().item()), + "fraction_zero_coacts": float(fraction_zero_coacts), + }, + step=iter_idx, + ) + + if iter_idx > 0 and iter_idx % run_config.logging_intervals.artifact == 0: + with tempfile.NamedTemporaryFile() as tmp_file: + file: Path = Path(tmp_file.name) + merge_history.save(file) + artifact: wandb.Artifact = wandb.Artifact( + name=f"merge_hist_iter.iter_{iter_idx}", + type="merge_hist_iter", + 
description=f"Group indices at iteration {iter_idx}", + metadata={ + "iteration": iter_idx, + "config": merge_history.merge_config.model_dump(mode="json"), + }, + ) + artifact.add_file(str(file)) + run.log_artifact(artifact) + + if iter_idx % run_config.logging_intervals.plot == 0: + fig: Figure = plot_merge_iteration( + current_merge=current_merge, + current_coact=current_coact, + costs=costs, + iteration=iter_idx, + component_labels=component_labels, + show=False, + ) + run.log({"plots/merges": wandb.Image(fig)}, step=iter_idx) + plt.close(fig) + + +def main(run_config: ClusteringRunConfig) -> Path: + """A single clustering run. + + Args: + run_config: Runtime parameters for this clustering run + + Returns: + Path to saved merge history file + """ + # Create ExecutionStamp and storage + # don't create git snapshot -- if we are part of an ensemble, the snapshot should be created by the pipeline + execution_stamp = ExecutionStamp.create( + run_type="cluster", + create_snapshot=False, + ) + storage = ClusteringRunStorage(execution_stamp) + clustering_run_id = execution_stamp.run_id + logger.info(f"Clustering run ID: {clustering_run_id}") + + # Register with ensemble if this is part of a pipeline + if run_config.ensemble_id: + assert run_config.idx_in_ensemble is not None, ( + "idx_in_ensemble must be set when ensemble_id is provided" + ) + register_clustering_run( + run_config.ensemble_id, + run_config.idx_in_ensemble, + clustering_run_id, + ) + logger.info( + f"Registered with pipeline {run_config.ensemble_id} at index {run_config.idx_in_ensemble} in {_ENSEMBLE_REGISTRY_DB}" + ) + + logger.info("Starting clustering run") + logger.info(f"Output directory: {storage.base_dir}") + device = get_device() + + spd_run = SPDRunInfo.from_path(run_config.model_path) + task_name: TaskName = spd_run.config.task_config.task_name + + # 1. 
Load dataset + logger.info(f"Loading dataset (seed={run_config.dataset_seed})") + load_dataset_kwargs: dict[str, Any] = dict() + if run_config.dataset_streaming: + logger.info("Using streaming dataset loading") + load_dataset_kwargs["config_kwargs"] = dict(streaming=True) + assert task_name == "lm", ( + f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task." + ) + batch: BatchTensor = load_dataset( + model_path=run_config.model_path, + task_name=task_name, + batch_size=run_config.batch_size, + seed=run_config.dataset_seed, + **load_dataset_kwargs, + ) + batch = batch.to(device) + + # 2. Setup WandB for this run + wandb_run: Run | None = None + if run_config.wandb_project is not None: + wandb_run = wandb.init( + entity=run_config.wandb_entity, + project=run_config.wandb_project, + group=run_config.ensemble_id, + config=run_config.model_dump(mode="json"), + tags=[ + "clustering", + f"task:{task_name}", + f"model:{run_config.wandb_decomp_model}", + f"ensemble_id:{run_config.ensemble_id}", + f"idx:{run_config.idx_in_ensemble}", + ], + ) + # logger.info(f"WandB run: {wandb_run.url}") + + # 3. Load model + logger.info("Loading model") + model = ComponentModel.from_run_info(spd_run).to(device) + + # 4. Compute activations + logger.info("Computing activations") + activations_dict: ( + dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]] + ) = component_activations( + model=model, + batch=batch, + device=device, + ) + + # 5. Process activations + logger.info("Processing activations") + processed_activations: ProcessedActivations = process_activations( + activations=activations_dict, + filter_dead_threshold=run_config.merge_config.filter_dead_threshold, + seq_mode="concat" if task_name == "lm" else None, + filter_modules=run_config.merge_config.filter_modules, + ) + + # 6. 
Log activations (if WandB enabled) + if wandb_run is not None: + logger.info("Plotting activations") + plot_activations( + processed_activations=processed_activations, + save_dir=None, # Don't save to disk, only WandB + n_samples_max=256, + wandb_run=wandb_run, + ) + wandb_log_tensor( + wandb_run, + processed_activations.activations, + "activations", + 0, + single=True, + ) + + # Clean up memory + activations: ActivationsTensor = processed_activations.activations + component_labels: ComponentLabels = ComponentLabels(processed_activations.labels.copy()) + del processed_activations + del activations_dict + del model + del batch + gc.collect() + + # 7. Run merge iteration + logger.info("Starting merging") + log_callback: LogCallback | None = ( + partial(_log_callback, run=wandb_run, run_config=run_config) + if wandb_run is not None + else None + ) + + history: MergeHistory = merge_iteration( + merge_config=run_config.merge_config, + activations=activations, + component_labels=component_labels, + log_callback=log_callback, + ) + + # 8. Save merge history and config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + + history.save(storage.history_path) + logger.info(f"History saved to {storage.history_path}") + + # 9. Log to WandB + if wandb_run is not None: + _log_merge_history_plots(wandb_run, history) + _save_merge_history_artifact(wandb_run, storage.history_path, history) + wandb_run.finish() + logger.info("WandB run finished") + + return storage.history_path + + +def cli() -> None: + """CLI for running a single clustering run.""" + parser = argparse.ArgumentParser(description="Run clustering on a single dataset") + parser.add_argument( + "--config", + type=Path, + required=True, + help="Path to ClusteringRunConfig file", + ) + parser.add_argument( + "--pipeline-run-id", + type=str, + default=None, + help="Pipeline run ID (ensemble identifier). 
If provided with --idx-in-ensemble, registers run.", + ) + parser.add_argument( + "--idx-in-ensemble", + type=int, + default=None, + help="Index of this run in the ensemble", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset)", + ) + + args: argparse.Namespace = parser.parse_args() + + # Load base config + run_config = ClusteringRunConfig.from_file(args.config) + + # Override config values from CLI + overrides: dict[str, Any] = { + "dataset_streaming": args.dataset_streaming, + } + + # Handle ensemble-related overrides + if args.idx_in_ensemble is not None: + overrides["dataset_seed"] = run_config.dataset_seed + args.idx_in_ensemble + overrides["idx_in_ensemble"] = args.idx_in_ensemble + if args.pipeline_run_id is not None: + overrides["ensemble_id"] = args.pipeline_run_id + + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + + run_config = replace_pydantic_model(run_config, overrides) + + # Run clustering + main(run_config) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py new file mode 100644 index 000000000..7b6af0e82 --- /dev/null +++ b/spd/clustering/scripts/run_pipeline.py @@ -0,0 +1,375 @@ +"""Submit clustering runs to SLURM as separate jobs in a SLURM array. + +This script submits independent clustering runs as a SLURM job array, +where each run gets its own dataset (seeded), WandB run, and merge history output. 
+ +Also submits a job to calculate distances between the clustering runs, which will run after +the clustering runs (the SLURM job depends on the previous array job). + +Output structure (only pipeline_config.yaml is saved directly in this script. The files under +each clustering run's own output directory are saved by run_clustering.py which is called in SLURM jobs deployed by this script.): + {pipeline_dir}/ # from execution stamp + |── pipeline_config.yaml # Saved in this script + |── clustering_run_config.json # make copy of the file pointed to by pipeline config + ├── ensemble_meta.json # (Saved by calc_distances.py) Ensemble metadata + ├── ensemble_merge_array.npz # (Saved by calc_distances.py) Normalized merge array + ├── distances_{method}.npz # (Saved by calc_distances.py) Distance array for each method + └── distances_{method}.png # (Saved by calc_distances.py) Distance distribution plot +""" + +import argparse +import os +import shlex +import subprocess +import tempfile +from pathlib import Path +from typing import Any + +import wandb_workspaces.workspaces as ws +from pydantic import Field, PositiveInt + +from spd.base_config import BaseConfig +from spd.clustering.consts import DistancesMethod +from spd.clustering.storage import StorageBase +from spd.log import logger +from spd.utils.command_utils import run_script_array_local +from spd.utils.general_utils import replace_pydantic_model +from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str +from spd.utils.slurm_utils import ( + create_slurm_array_script, + create_slurm_script, + submit_slurm_script, +) + +os.environ["WANDB_QUIET"] = "true" + + +class ClusteringPipelineStorage(StorageBase): + """Storage paths for clustering pipeline (ensemble). + + All paths are relative to ExecutionStamp.out_dir. 
+ """ + + # Relative path constants + _PIPELINE_CONFIG = "pipeline_config.yaml" + _RUN_IDS = "run_ids.json" + _ENSEMBLE_META = "ensemble_meta.json" + _ENSEMBLE_MERGE_ARRAY = "ensemble_merge_array.npz" + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + super().__init__(execution_stamp) + self.pipeline_config_path: Path = self.base_dir / self._PIPELINE_CONFIG + self.run_ids_path: Path = self.base_dir / self._RUN_IDS + self.ensemble_meta_path: Path = self.base_dir / self._ENSEMBLE_META + self.ensemble_merge_array_path: Path = self.base_dir / self._ENSEMBLE_MERGE_ARRAY + + def distances_path(self, method: DistancesMethod) -> Path: + return self.base_dir / f"distances_{method}.npz" + + +class ClusteringPipelineConfig(BaseConfig): + """Configuration for submitting an ensemble of clustering runs to SLURM.""" + + run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") + n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") + distances_method: DistancesMethod = Field(description="Method to use for calculating distances") + base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") + slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") + slurm_partition: str | None = Field(description="SLURM partition to use") + wandb_project: str | None = Field( + default=None, + description="Weights & Biases project name (set to None to disable WandB logging)", + ) + wandb_entity: str = Field(description="WandB entity (team/user) name") + create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + + +def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: + """Create WandB workspace view for clustering runs. + + TODO: Use a template workspace which actually shows some panels + TODO: since the run_id here is the same as the wandb id, can we take advantage of that? 
+ + Args: + ensemble_id: Unique identifier for this ensemble + project: WandB project name + entity: WandB entity (team/user) name + + Returns: + URL to workspace view + """ + workspace = ws.Workspace(entity=entity, project=project) + workspace.name = f"Clustering - {ensemble_id}" + + workspace.runset_settings.filters = [ + ws.Tags("tags").isin([f"ensemble_id:{ensemble_id}"]), + ] + + try: + workspace.save_as_new_view() + return workspace.url + except Exception as e: + logger.warning( + f"Failed to create WandB workspace view: {workspace=}, {workspace.name=}, {ensemble_id=}, {project=}, {entity=}, {e}" + ) + raise e + + +def generate_clustering_commands( + pipeline_config: ClusteringPipelineConfig, + pipeline_run_id: str, + dataset_streaming: bool = False, +) -> list[str]: + """Generate commands for each clustering run. + + Args: + pipeline_config: Pipeline configuration + pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + dataset_streaming: Whether to use dataset streaming + + Returns: + List of shell-safe command strings + """ + commands: list[str] = [] + + for idx in range(pipeline_config.n_runs): + cmd_parts = [ + "python", + "spd/clustering/scripts/run_clustering.py", + "--config", + pipeline_config.run_clustering_config_path.as_posix(), + "--pipeline-run-id", + pipeline_run_id, + "--idx-in-ensemble", + str(idx), + "--wandb-project", + str(pipeline_config.wandb_project), + "--wandb-entity", + pipeline_config.wandb_entity, + ] + if dataset_streaming: + cmd_parts.append("--dataset-streaming") + + commands.append(shlex.join(cmd_parts)) + + return commands + + +def generate_calc_distances_command(pipeline_run_id: str, distances_method: DistancesMethod) -> str: + """Generate command for calculating distances. 
+ + Args: + pipeline_run_id: Pipeline run ID (will query registry for clustering runs) + distances_method: Method for calculating distances + """ + return shlex.join( + [ + "python", + "spd/clustering/scripts/calc_distances.py", + "--pipeline-run-id", + pipeline_run_id, + "--distances-method", + distances_method, + ] + ) + + +def main( + pipeline_config: ClusteringPipelineConfig, + local: bool = False, + dataset_streaming: bool = False, +) -> None: + """Submit clustering runs to SLURM. + + Args: + pipeline_config_path: Path to ClusteringPipelineConfig file + n_runs: Number of clustering runs in the ensemble. Will override value in the config file. + """ + # setup + # ========================================================================================== + + logger.set_format("console", "terse") + + # Create ExecutionStamp for pipeline + execution_stamp: ExecutionStamp = ExecutionStamp.create( + run_type="ensemble", + create_snapshot=pipeline_config.create_git_snapshot, + ) + pipeline_run_id: str = execution_stamp.run_id + logger.info(f"Pipeline run ID: {pipeline_run_id}") + + # Initialize storage + storage = ClusteringPipelineStorage(execution_stamp) + logger.info(f"Pipeline output directory: {storage.base_dir}") + + # Save pipeline config + pipeline_config.to_file(storage.pipeline_config_path) + logger.info(f"Pipeline config saved to {storage.pipeline_config_path}") + + # Create WandB workspace if requested + if pipeline_config.wandb_project is not None: + workspace_url = create_clustering_workspace_view( + ensemble_id=pipeline_run_id, + project=pipeline_config.wandb_project, + entity=pipeline_config.wandb_entity, + ) + logger.info(f"WandB workspace: {workspace_url}") + + # Generate commands for clustering runs + clustering_commands = generate_clustering_commands( + pipeline_config=pipeline_config, + pipeline_run_id=pipeline_run_id, + dataset_streaming=dataset_streaming, + ) + + # Generate command for calculating distances + calc_distances_command = 
generate_calc_distances_command( + pipeline_run_id=pipeline_run_id, + distances_method=pipeline_config.distances_method, + ) + + # Submit to SLURM + if local: + # submit clustering array job + run_script_array_local( + commands=clustering_commands, + ) + + # submit calc_distances job + logger.info("Calculating distances...") + logger.info(f"Command: {calc_distances_command}") + subprocess.run(shlex.split(calc_distances_command), shell=False, check=True) + + logger.section("complete!") + distances_plot_path = ( + storage.plots_dir / f"distances_{pipeline_config.distances_method}.png" + ) + logger.values( + { + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + "Distances plot": str(distances_plot_path), + } + ) + + else: + assert pipeline_config.slurm_job_name_prefix is not None, ( + "must specify slurm_job_name_prefix if not running locally" + ) + assert pipeline_config.slurm_partition is not None, ( + "must specify slurm_partition if not running locally" + ) + with tempfile.TemporaryDirectory() as temp_dir: + # Submit clustering array job + clustering_script_path = Path(temp_dir) / f"clustering_{pipeline_run_id}.sh" + + create_slurm_array_script( + script_path=clustering_script_path, + job_name=f"{pipeline_config.slurm_job_name_prefix}_cluster", + commands=clustering_commands, + snapshot_branch=execution_stamp.snapshot_branch, + max_concurrent_tasks=pipeline_config.n_runs, # Run all concurrently + n_gpus_per_job=1, # Always 1 GPU per run + partition=pipeline_config.slurm_partition, + ) + array_job_id = submit_slurm_script(clustering_script_path) + + # Submit calc_distances job with dependency on array job + calc_distances_script_path = Path(temp_dir) / f"calc_distances_{pipeline_run_id}.sh" + + create_slurm_script( + script_path=calc_distances_script_path, + job_name=f"{pipeline_config.slurm_job_name_prefix}_distances", + command=calc_distances_command, + 
snapshot_branch=execution_stamp.snapshot_branch, + n_gpus=1, # Always 1 GPU for distances calculation + partition=pipeline_config.slurm_partition, + dependency_job_id=array_job_id, + ) + calc_distances_job_id = submit_slurm_script(calc_distances_script_path) + + logger.section("Jobs submitted successfully!") + distances_plot_path = ( + storage.plots_dir / f"distances_{pipeline_config.distances_method}.png" + ) + logger.values( + { + "Clustering Array Job ID": array_job_id, + "Calc Distances Job ID": calc_distances_job_id, + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + "Clustering logs": f"~/slurm_logs/slurm-{array_job_id}_*.out", + "Calc Distances log": f"~/slurm_logs/slurm-{calc_distances_job_id}.out", + "Distances plot will be saved to": str(distances_plot_path), + } + ) + + +def cli(): + """CLI for spd-cluster command.""" + parser = argparse.ArgumentParser( + prog="spd-cluster", + description="Submit clustering runs to SLURM. 
Arguments specified here will override the " + "corresponding value in the config file.", + ) + + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to pipeline config file", + ) + parser.add_argument( + "--n-runs", + type=int, + help="Number of clustering runs in the ensemble (overrides value in config file)", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--local", + action=argparse.BooleanOptionalAction, + default=False, + help="Run locally instead of submitting to SLURM (required if slurm_job_name_prefix and slurm_partition are None in config)", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) + + args = parser.parse_args() + + pipeline_config = ClusteringPipelineConfig.from_file(args.config) + overrides: dict[str, Any] = {} + + if args.n_runs is not None: + overrides["n_runs"] = args.n_runs + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + + pipeline_config = replace_pydantic_model(pipeline_config, overrides) + + main( + pipeline_config=pipeline_config, + local=args.local, + dataset_streaming=args.dataset_streaming, + ) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/storage.py b/spd/clustering/storage.py new file mode 100644 index 000000000..dc3d8765a --- /dev/null +++ b/spd/clustering/storage.py @@ -0,0 +1,19 @@ +"""Minimal storage base class for clustering - just path management.""" + +from pathlib import Path + +from spd.utils.run_utils import ExecutionStamp + + +class StorageBase: + """Base class for storage - provides ExecutionStamp and base directory. + + Subclasses define path constants (relative to base_dir) and set absolute paths in __init__. + Caller handles all actual saving and WandB uploading. 
+ """ + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + """Initialize storage with execution stamp.""" + self.execution_stamp: ExecutionStamp = execution_stamp + self.base_dir: Path = execution_stamp.out_dir + self.plots_dir: Path = self.base_dir / "plots" diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 399b08783..f93836ca1 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -43,7 +43,7 @@ def main( logger.info(f"Using device: {device}") if config.wandb_project: - tags = ["ih"] + tags = ["induction_head"] if evals_id: tags.append(evals_id) if sweep_id: diff --git a/spd/identity_insertion.py b/spd/identity_insertion.py index 2859b7700..dcad69e81 100644 --- a/spd/identity_insertion.py +++ b/spd/identity_insertion.py @@ -42,12 +42,8 @@ def insert_identity_operations_(target_model: nn.Module, identity_patterns: list identity_patterns: Patterns matching modules to prepend identity ops to """ - if is_main_process(): - logger.info(f"Inserting identity operations before {len(identity_patterns)} modules") - identity_module_paths = get_target_module_paths(target_model, identity_patterns) - # Add identity layers and hooks for module_path in identity_module_paths: module = target_model.get_submodule(module_path) @@ -61,7 +57,7 @@ def insert_identity_operations_(target_model: nn.Module, identity_patterns: list case _: raise ValueError(f"Module {module} not supported. 
type: {type(module)}") - module.pre_identity = Identity(d_in) # type: ignore + module.pre_identity = Identity(d_in) module.register_forward_pre_hook(pre_id_hook, with_kwargs=True) if is_main_process(): diff --git a/spd/scripts/run.py b/spd/scripts/run.py index 463ab44ad..2d281e4bb 100644 --- a/spd/scripts/run.py +++ b/spd/scripts/run.py @@ -11,9 +11,7 @@ import copy import json import shlex -import subprocess import tempfile -from datetime import datetime from hashlib import sha256 from pathlib import Path from typing import Any, Final @@ -24,17 +22,17 @@ from spd.log import LogFormat, logger from spd.registry import EXPERIMENT_REGISTRY, get_max_expected_runtime from spd.settings import REPO_ROOT -from spd.utils.git_utils import create_git_snapshot, repo_current_branch -from spd.utils.run_utils import apply_nested_updates, generate_grid_combinations, generate_run_name -from spd.utils.slurm_utils import create_slurm_array_script, submit_slurm_array +from spd.utils.command_utils import run_script_array_local +from spd.utils.run_utils import ( + ExecutionStamp, + apply_nested_updates, + generate_grid_combinations, + generate_run_name, +) +from spd.utils.slurm_utils import create_slurm_array_script, submit_slurm_script from spd.utils.wandb_utils import wandb_setup -def generate_run_id() -> str: - """Generate a unique run ID based on timestamp.""" - return f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - - def resolve_sweep_params_path(sweep_params_file: str) -> Path: """Resolve the full path to the sweep parameters file.""" if "/" not in sweep_params_file: @@ -114,12 +112,6 @@ def _choose_master_port(run_id_local: str, idx: int) -> int: return base + (h % span) -def _build_mpi_prefix(run_id: str, idx: int, dp: int) -> str: - """Build an MPI prefix for a command.""" - port: int = _choose_master_port(run_id, idx) - return f"MASTER_PORT={port} mpirun -x MASTER_PORT -np {dp} " - - def generate_commands( experiments_list: list[str], run_id: str, @@ -147,7 +139,6 @@ def 
generate_commands( for experiment in experiments_list: exp_config = EXPERIMENT_REGISTRY[experiment] - # Load base config base_config = Config.from_file(exp_config.config_path) if sweep_params_path is None: @@ -158,14 +149,24 @@ def generate_commands( config_json = f"json:{json.dumps(config_with_overrides.model_dump(mode='json'))}" - mpi_prefix = _build_mpi_prefix(run_id, cmd_idx, dp) if dp > 1 else "" - - command = ( - f"{mpi_prefix}python {exp_config.decomp_script} --config_json '{config_json}' " - f"--sweep_id {run_id} --evals_id {experiment}" - ) + cmd_parts = [ + "python", + str(exp_config.decomp_script), + "--config_json", + config_json, + "--sweep_id", + run_id, + "--evals_id", + experiment, + ] + + if dp > 1: + port = _choose_master_port(run_id, cmd_idx) + cmd = f"MASTER_PORT={port} {shlex.join(['mpirun', '-x', 'MASTER_PORT', '-np', str(dp)] + cmd_parts)}" + else: + cmd = shlex.join(cmd_parts) - commands.append(command) + commands.append(cmd) task_breakdown[experiment] = "1 task" cmd_idx += 1 @@ -186,15 +187,26 @@ def generate_commands( config_json = f"json:{json.dumps(config_with_overrides.model_dump(mode='json'))}" sweep_params_json = f"json:{json.dumps(sweep_params)}" - mpi_prefix = _build_mpi_prefix(run_id, cmd_idx, dp) if dp > 1 else "" - command = ( - f"{mpi_prefix}python {exp_config.decomp_script} --config_json '{config_json}' " - f"--sweep_id {run_id} " - f"--evals_id {experiment} " - f"--sweep_params_json '{sweep_params_json}'" - ) - - commands.append(command) + cmd_parts = [ + "python", + str(exp_config.decomp_script), + "--config_json", + config_json, + "--sweep_id", + run_id, + "--evals_id", + experiment, + "--sweep_params_json", + sweep_params_json, + ] + + if dp > 1: + port = _choose_master_port(run_id, cmd_idx) + cmd = f'MASTER_PORT={port} mpirun -x "MASTER_PORT" -np {dp} {shlex.join(cmd_parts)}' + else: + cmd = shlex.join(cmd_parts) + + commands.append(cmd) cmd_idx += 1 # Print first combination as example @@ -208,35 +220,6 @@ def 
generate_commands( return commands -def run_commands_locally(commands: list[str]) -> None: - """Execute commands locally in sequence. - - Args: - commands: List of shell commands to execute - """ - - logger.section(f"LOCAL EXECUTION: Running {len(commands)} tasks") - - for i, command in enumerate(commands, 1): - # Parse command into arguments - args = shlex.split(command) - - # Extract experiment name from script path for cleaner output - script_name = args[1].split("/")[-1] - logger.section(f"[{i}/{len(commands)}] Executing: {script_name}...") - - result = subprocess.run(args) - - if result.returncode != 0: - logger.warning( - f"[{i}/{len(commands)}] ⚠️ Warning: Command failed with exit code {result.returncode}" - ) - else: - logger.info(f"[{i}/{len(commands)}] ✓ Completed successfully") - - logger.section("LOCAL EXECUTION COMPLETE") - - def get_experiments( experiments: str | None = None, ) -> list[str]: @@ -338,7 +321,11 @@ def main( logger.set_format("console", log_format) # Determine run id - run_id: str = generate_run_id() + execution_stamp: ExecutionStamp = ExecutionStamp.create( + run_type="spd", + create_snapshot=create_snapshot, + ) + run_id: str = execution_stamp.run_id logger.info(f"Run ID: {run_id}") # Determine the sweep parameters file @@ -365,18 +352,6 @@ def main( # ========================================================================================== if not local or use_wandb: - # set up snapshot branch and commit hash - snapshot_branch: str - commit_hash: str - - if create_snapshot: - snapshot_branch, commit_hash = create_git_snapshot(branch_name_prefix="run") - logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") - else: - snapshot_branch = repo_current_branch() - commit_hash = "none" - logger.info(f"Using current branch: {snapshot_branch}") - # set up wandb if use_wandb: wandb_setup( @@ -386,8 +361,8 @@ def main( create_report=create_report, # if `create_report == False`, the rest of the arguments don't matter 
report_title=report_title, - snapshot_branch=snapshot_branch, - commit_hash=commit_hash, + snapshot_branch=execution_stamp.snapshot_branch, + commit_hash=execution_stamp.commit_hash, include_run_comparer=sweep_params_file is not None, ) else: @@ -410,7 +385,7 @@ def main( ) if local: - run_commands_locally(commands) + run_script_array_local(commands) else: # Submit to SLURM with tempfile.TemporaryDirectory() as temp_dir: @@ -427,14 +402,13 @@ def main( script_path=array_script, job_name=job_name, commands=commands, - # again -- local is false, so snapshot_branch will exist - snapshot_branch=snapshot_branch, # pyright: ignore[reportPossiblyUnboundVariable] + snapshot_branch=execution_stamp.snapshot_branch, max_concurrent_tasks=n_agents, n_gpus_per_job=n_gpus_per_job, partition=partition, ) - array_job_id = submit_slurm_array(array_script) + array_job_id = submit_slurm_script(array_script) logger.section("Job submitted successfully!") logger.values( diff --git a/spd/spd_types.py b/spd/spd_types.py index 07249f7f7..d348f5e8f 100644 --- a/spd/spd_types.py +++ b/spd/spd_types.py @@ -1,7 +1,8 @@ from pathlib import Path from typing import Annotated, Literal -from pydantic import BeforeValidator, Field, PlainSerializer +from annotated_types import Ge, Le +from pydantic import BeforeValidator, PlainSerializer from spd.settings import REPO_ROOT @@ -45,6 +46,6 @@ def validate_path(v: str | Path) -> str | Path: ] -Probability = Annotated[float, Field(strict=True, ge=0, le=1)] +Probability = Annotated[float, Ge(0), Le(1)] -TaskName = Literal["tms", "resid_mlp", "lm", "ih"] +TaskName = Literal["tms", "resid_mlp", "lm", "induction_head"] diff --git a/spd/utils/command_utils.py b/spd/utils/command_utils.py new file mode 100644 index 000000000..6b79ad6ed --- /dev/null +++ b/spd/utils/command_utils.py @@ -0,0 +1,37 @@ +"""Minimal utilities for running shell-safe commands locally.""" + +import shlex +import subprocess + +from spd.log import logger + + +def 
run_script_array_local(commands: list[str], parallel: bool = False) -> None: + """Run multiple shell-safe command strings locally. + + Args: + commands: List of shell-safe command strings (built with shlex.join()) + parallel: If True, run all commands in parallel. If False, run sequentially. + """ + n_commands = len(commands) + + if not parallel: + logger.section(f"LOCAL EXECUTION: Running {n_commands} tasks serially") + for i, cmd in enumerate(commands, 1): + logger.info(f"[{i}/{n_commands}] Running: {cmd}") + subprocess.run(shlex.split(cmd), shell=False, check=True) + logger.section("LOCAL EXECUTION COMPLETE") + else: + logger.section(f"LOCAL EXECUTION: Starting {n_commands} tasks in parallel") + procs: list[subprocess.Popen[bytes]] = [] + for i, cmd in enumerate(commands, 1): + logger.info(f"[{i}/{n_commands}] Starting: {cmd}") + proc = subprocess.Popen(shlex.split(cmd), shell=False) + procs.append(proc) + + logger.section("WAITING FOR ALL TASKS TO COMPLETE") + for proc in procs: + proc.wait() + if proc.returncode != 0: + logger.error(f"Process {proc.pid} failed with exit code {proc.returncode}") + logger.section("LOCAL EXECUTION COMPLETE") diff --git a/spd/utils/git_utils.py b/spd/utils/git_utils.py index d21bf240e..b9c0cf370 100644 --- a/spd/utils/git_utils.py +++ b/spd/utils/git_utils.py @@ -1,6 +1,5 @@ """Git utilities for creating code snapshots.""" -import datetime import subprocess import tempfile from pathlib import Path @@ -30,7 +29,32 @@ def repo_current_branch() -> str: return result.stdout.strip() -def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: +def repo_is_clean(catch_except_as_false: bool = False) -> bool: + """Return True if the current git repository has no uncommitted or untracked changes. 
+ + # TODO: this may error in CI environments: https://github.com/goodfire-ai/spd/actions/runs/18560369066/job/52907611203 + `fatal: detected dubious ownership in repository at '/__w/spd/spd'` + + for now, if `catch_except_as_false` is True, we catch any exceptions and return False. + + """ + try: + status: str = subprocess.check_output(["git", "status", "--porcelain"], text=True).strip() + return status == "" + except Exception as e: + if catch_except_as_false: + return False + else: + raise e + + +def repo_current_commit_hash() -> str: + """Return the current commit hash of the active HEAD.""" + commit_hash: str = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + return commit_hash + + +def create_git_snapshot(run_id: str) -> tuple[str, str]: """Create a git snapshot branch with current changes. Creates a timestamped branch containing all current changes (staged and unstaged). Uses a @@ -44,13 +68,12 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: Raises: subprocess.CalledProcessError: If git commands fail (except for push) """ - # Generate timestamped branch name - timestamp_utc = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") - snapshot_branch = f"{branch_name_prefix}-{timestamp_utc}" + # prefix branch name + snapshot_branch: str = f"snapshot/{run_id}" # Create temporary worktree path with tempfile.TemporaryDirectory() as temp_dir: - worktree_path = Path(temp_dir) / f"spd-snapshot-{timestamp_utc}" + worktree_path = Path(temp_dir) / f"spd-snapshot-{run_id}" try: # Create worktree with new branch @@ -87,7 +110,7 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: # Commit changes if any exist if diff_result.returncode != 0: # Non-zero means there are changes subprocess.run( - ["git", "commit", "-m", f"Sweep snapshot {timestamp_utc}", "--no-verify"], + ["git", "commit", "-m", f"run id {run_id}", "--no-verify"], cwd=worktree_path, check=True, capture_output=True, diff --git 
a/spd/utils/run_utils.py b/spd/utils/run_utils.py index b2a16827d..4465d6b28 100644 --- a/spd/utils/run_utils.py +++ b/spd/utils/run_utils.py @@ -6,13 +6,20 @@ import secrets import string from pathlib import Path -from typing import Any +from typing import Any, Final, Literal, NamedTuple import torch import wandb import yaml +from spd.log import logger from spd.settings import SPD_CACHE_DIR +from spd.utils.git_utils import ( + create_git_snapshot, + repo_current_branch, + repo_current_commit_hash, + repo_is_clean, +) # Fields that use discriminated union merging: field_name -> discriminator_field _DISCRIMINATED_LIST_FIELDS: dict[str, str] = { @@ -37,6 +44,7 @@ def get_local_run_id() -> str: return f"local-{random_suffix}" +# TODO: avoid using this function? def get_output_dir(use_wandb_id: bool = True) -> Path: """Get the output directory for a run. @@ -462,3 +470,89 @@ def generate_run_name(params: dict[str, Any]) -> str: parts.append(f"{param}-{value}") return "-".join(parts) + + +RunType = Literal["spd", "cluster", "ensemble"] + +RUN_TYPE_ABBREVIATIONS: Final[dict[RunType, str]] = { + "spd": "s", + "cluster": "c", + "ensemble": "e", +} + + +# TODO: This doesnt work in pytest but would in general be nice to enforce. hmm. +# _CREATED_RUN_ID: bool = False + + +class ExecutionStamp(NamedTuple): + run_id: str + snapshot_branch: str + commit_hash: str + run_type: RunType + + @staticmethod + def _generate_run_id(run_type: RunType) -> str: + """Generate a unique run identifier, + + Format: `{type_abbr}-{random_hex}` + """ + # global _CREATED_RUN_ID + # if _CREATED_RUN_ID: + # raise RuntimeError( + # "Run ID has already been generated for this process! You can only call this once." 
+ # ) + type_abbr: str = RUN_TYPE_ABBREVIATIONS[run_type] + random_hex: str = secrets.token_hex(4) + # _CREATED_RUN_ID = True + return f"{type_abbr}-{random_hex}" + + @classmethod + def create( + cls, + run_type: RunType, + create_snapshot: bool, + ) -> "ExecutionStamp": + """create an execution stamp, possibly including a git snapshot branch""" + + run_id: str = ExecutionStamp._generate_run_id(run_type) + snapshot_branch: str + commit_hash: str + + if create_snapshot: + snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) + logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") + else: + snapshot_branch = repo_current_branch() + if repo_is_clean(catch_except_as_false=True): + commit_hash = repo_current_commit_hash() + logger.info(f"Using current branch: {snapshot_branch} ({commit_hash[:8]})") + else: + commit_hash = "none" + logger.info( + f"Using current branch: {snapshot_branch} (unpushed changes, no commit hash)" + ) + + return ExecutionStamp( + run_id=run_id, + snapshot_branch=snapshot_branch, + commit_hash=commit_hash, + run_type=run_type, + ) + + @property + def out_dir(self) -> Path: + """Get the output directory for this execution stamp.""" + run_dir = SPD_CACHE_DIR / self.run_type / self.run_id + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +_NO_ARG_PARSSED_SENTINEL = object() + + +def read_noneable_str(value: str) -> str | None: + """Read a string that may be 'None' and convert to None.""" + if value == "None": + return None + return value diff --git a/spd/utils/slurm_utils.py b/spd/utils/slurm_utils.py index a5a9426ef..b9290061f 100644 --- a/spd/utils/slurm_utils.py +++ b/spd/utils/slurm_utils.py @@ -22,57 +22,42 @@ def format_runtime_str(runtime_minutes: int) -> str: return f"{hours}h{minutes}m" if hours > 0 else f"{minutes}m" -def create_slurm_array_script( +def _create_slurm_script_base( script_path: Path, job_name: str, - commands: list[str], snapshot_branch: str, - n_gpus_per_job: int, + 
n_gpus: int, partition: str, - time_limit: str = "72:00:00", - max_concurrent_tasks: int | None = None, + time_limit: str, + sbatch_directives: str, + work_dir_suffix: str, + command_block: str, ) -> None: - """Create a SLURM job array script with git snapshot for consistent code. + """Create a SLURM script with git snapshot for consistent code. Args: script_path: Path where the script should be written - job_name: Name for the SLURM job array - commands: List of commands to execute in each array job - snapshot_branch: Git branch to checkout. - n_gpus_per_job: Number of GPUs per job. If 0, use CPU jobs. - time_limit: Time limit for each job (default: 72:00:00) - max_concurrent_tasks: Maximum number of array tasks to run concurrently. If None, no limit. + job_name: Name for the SLURM job + snapshot_branch: Git branch to checkout + n_gpus: Number of GPUs. If 0, use CPU jobs. + partition: SLURM partition to use + time_limit: Time limit for the job + sbatch_directives: Additional SBATCH directives (e.g. --array, --dependency, --output) + work_dir_suffix: Suffix for the working directory (e.g. 
"${SLURM_JOB_ID}" or "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}") + command_block: The command(s) to execute """ - - slurm_logs_dir = Path.home() / "slurm_logs" - slurm_logs_dir.mkdir(exist_ok=True) - - # Create array range (SLURM arrays are 1-indexed) - if max_concurrent_tasks is not None: - array_range = f"1-{len(commands)}%{max_concurrent_tasks}" - else: - array_range = f"1-{len(commands)}" - - # Create case statement for commands - case_statements = [] - for i, command in enumerate(commands, 1): - case_statements.append(f"{i}) {command} ;;") - - case_block = "\n ".join(case_statements) - script_content = textwrap.dedent(f""" #!/bin/bash #SBATCH --nodes=1 - #SBATCH --gres=gpu:{n_gpus_per_job} + #SBATCH --gres=gpu:{n_gpus} #SBATCH --partition={partition} #SBATCH --time={time_limit} #SBATCH --job-name={job_name} - #SBATCH --array={array_range} + {sbatch_directives} #SBATCH --distribution=pack - #SBATCH --output={slurm_logs_dir}/slurm-%A_%a.out # Create job-specific working directory - WORK_DIR="/tmp/spd-gf-copy-${{SLURM_ARRAY_JOB_ID}}_${{SLURM_ARRAY_TASK_ID}}" + WORK_DIR="/tmp/spd-gf-copy-{work_dir_suffix}" # Clone the repository to the job-specific directory git clone {REPO_ROOT} $WORK_DIR @@ -93,38 +78,123 @@ def create_slurm_array_script( uv sync --no-dev --link-mode copy -q source .venv/bin/activate - # Execute the appropriate command based on array task ID - case $SLURM_ARRAY_TASK_ID in - {case_block} - esac + {command_block} """).strip() with open(script_path, "w") as f: f.write(script_content) - # Make script executable script_path.chmod(0o755) -def submit_slurm_array(script_path: Path) -> str: - """Submit a SLURM job array and return the array job ID. 
+def create_slurm_array_script( + script_path: Path, + job_name: str, + commands: list[str], + snapshot_branch: str, + n_gpus_per_job: int, + partition: str, + time_limit: str = "72:00:00", + max_concurrent_tasks: int | None = None, +) -> None: + """Create a SLURM job array script with git snapshot for consistent code. Args: - script_path: Path to SLURM batch script + script_path: Path where the script should be written + job_name: Name for the SLURM job array + commands: List of shell-safe command strings (built with shlex.join()) + snapshot_branch: Git branch to checkout. + n_gpus_per_job: Number of GPUs per job. If 0, use CPU jobs. + partition: SLURM partition to use + time_limit: Time limit for each job (default: 72:00:00) + max_concurrent_tasks: Maximum number of array tasks to run concurrently. If None, no limit. + """ + slurm_logs_dir = Path.home() / "slurm_logs" + slurm_logs_dir.mkdir(exist_ok=True) - Returns: - Array job ID from submitted job array + # Create array range (SLURM arrays are 1-indexed) + if max_concurrent_tasks is not None: + array_range = f"1-{len(commands)}%{max_concurrent_tasks}" + else: + array_range = f"1-{len(commands)}" + + # Create case statement for commands + case_statements = [] + for i, cmd in enumerate(commands, 1): + case_statements.append(f"{i}) {cmd} ;;") + + case_block = "\n ".join(case_statements) + + sbatch_directives = f"""#SBATCH --array={array_range} + #SBATCH --output={slurm_logs_dir}/slurm-%A_%a.out""" + + command_block = f"""# Execute the appropriate command based on array task ID + case $SLURM_ARRAY_TASK_ID in + {case_block} + esac""" + + _create_slurm_script_base( + script_path=script_path, + job_name=job_name, + snapshot_branch=snapshot_branch, + n_gpus=n_gpus_per_job, + partition=partition, + time_limit=time_limit, + sbatch_directives=sbatch_directives, + work_dir_suffix="${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}", + command_block=command_block, + ) + + +def create_slurm_script( + script_path: Path, + job_name: 
str, + command: str, + snapshot_branch: str, + n_gpus: int, + partition: str, + time_limit: str = "72:00:00", + dependency_job_id: str | None = None, +) -> None: + """Create a SLURM job script with git snapshot for consistent code. + + Args: + script_path: Path where the script should be written + job_name: Name for the SLURM job + command: Shell-safe command string (built with shlex.join()) + snapshot_branch: Git branch to checkout + n_gpus: Number of GPUs. If 0, use CPU job. + partition: SLURM partition to use + time_limit: Time limit for the job (default: 72:00:00) + dependency_job_id: Optional job ID to depend on (uses afterok) """ - result = subprocess.run( - ["sbatch", str(script_path)], capture_output=True, text=True, check=True + slurm_logs_dir = Path.home() / "slurm_logs" + slurm_logs_dir.mkdir(exist_ok=True) + + # Build SBATCH directives + directives = [f"#SBATCH --output={slurm_logs_dir}/slurm-%j.out"] + if dependency_job_id is not None: + directives.append(f"#SBATCH --dependency=afterok:{dependency_job_id}") + + sbatch_directives = "\n ".join(directives) + + command_block = f"# Execute the command\n {command}" + + _create_slurm_script_base( + script_path=script_path, + job_name=job_name, + snapshot_branch=snapshot_branch, + n_gpus=n_gpus, + partition=partition, + time_limit=time_limit, + sbatch_directives=sbatch_directives, + work_dir_suffix="${SLURM_JOB_ID}", + command_block=command_block, ) - # Extract job ID from sbatch output (format: "Submitted batch job 12345") - job_id = result.stdout.strip().split()[-1] - return job_id -def submit_slurm_job(script_path: Path) -> str: - """Submit a SLURM job and return the job ID. +def submit_slurm_script(script_path: Path) -> str: + """Submit a SLURM job (array or single) and return the job ID. 
Args: script_path: Path to SLURM batch script diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 010e5ecfc..1d2a69c93 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -151,7 +151,6 @@ def _plot_func( MERGE_HIST: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - batch_id="batch_0", activations=PROCESSED_ACTIVATIONS.activations, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=_plot_func, @@ -173,10 +172,9 @@ def _plot_func( # Modern approach: run merge_iteration multiple times to create ensemble ENSEMBLE_SIZE: int = 4 HISTORIES: list[MergeHistory] = [] -for i in range(ENSEMBLE_SIZE): +for _i in range(ENSEMBLE_SIZE): HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - batch_id=f"batch_{i}", activations=PROCESSED_ACTIVATIONS.activations, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 12dfdae4d..acb6f394e 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,11 +16,11 @@ component_activations, process_activations, ) +from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.s1_split_dataset import split_dataset from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution from spd.models.component_model import ComponentModel, SPDRunInfo @@ -45,25 +45,25 @@ MODEL.to(DEVICE) SPD_CONFIG = SPD_RUN.config -# Use split_dataset with RunConfig to get real data +# Use load_dataset with RunConfig to get real data CONFIG: 
ClusteringRunConfig = ClusteringRunConfig( merge_config=MergeConfig(), model_path=MODEL_PATH, - task_name="lm", - n_batches=1, batch_size=2, + dataset_seed=42, + idx_in_ensemble=0, dataset_streaming=True, # no effect since we do this manually ) -BATCHES, _ = split_dataset( - config=CONFIG, +DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = load_dataset( + model_path=MODEL_PATH, + task_name="lm", + batch_size=CONFIG.batch_size, + seed=CONFIG.dataset_seed, + # config=CONFIG, config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 ) -# %% -# Load data batch -# ============================================================ -DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = next(BATCHES) # %% # Get component activations @@ -93,7 +93,6 @@ save_dir=TEMP_DIR, n_samples_max=256, wandb_run=None, - save_fmt="svg", ) # %% @@ -113,10 +112,9 @@ # Modern approach: run merge_iteration multiple times to create ensemble ENSEMBLE_SIZE: int = 2 HISTORIES: list[MergeHistory] = [] -for i in range(ENSEMBLE_SIZE): +for _i in range(ENSEMBLE_SIZE): HISTORY: MergeHistory = merge_iteration( merge_config=MERGE_CFG, - batch_id=f"batch_{i}", activations=PROCESSED_ACTIVATIONS.activations, component_labels=PROCESSED_ACTIVATIONS.labels, log_callback=None, diff --git a/tests/clustering/test_calc_distances.py b/tests/clustering/test_calc_distances.py new file mode 100644 index 000000000..d8971df05 --- /dev/null +++ b/tests/clustering/test_calc_distances.py @@ -0,0 +1,32 @@ +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble + + +def test_merge_history_normalization_happy_path(): + """Test that the normalization part of calc_distances.py works without errors""" + + # Create test merge histories + config = MergeConfig( + iters=3, + alpha=1.0, + activation_threshold=None, + pop_component_prob=0.0, + ) + + histories = [] + for _idx in range(2): + history = 
MergeHistory.from_config( + merge_config=config, + labels=ComponentLabels([f"comp{j}" for j in range(4)]), + ) + histories.append(history) + + # Test ensemble creation + ensemble = MergeHistoryEnsemble(data=histories) + assert len(ensemble.data) == 2 + + # Test normalization + normalized_array, metadata = ensemble.normalized() + assert normalized_array is not None + assert metadata is not None diff --git a/tests/clustering/test_clustering_experiments.py b/tests/clustering/test_clustering_experiments.py index 19ff937f6..fc27b6831 100644 --- a/tests/clustering/test_clustering_experiments.py +++ b/tests/clustering/test_clustering_experiments.py @@ -34,7 +34,7 @@ def test_cluster_resid_mlp_notebook(): @pytest.mark.slow def test_clustering_with_resid_mlp1_config(): """Test running clustering with test-resid_mlp1.json config.""" - config_path = CONFIG_DIR / "test-resid_mlp1.json" + config_path = CONFIG_DIR / "pipeline-test-resid_mlp1.yaml" assert config_path.exists(), f"Config not found: {config_path}" # Run the clustering main script with the test config @@ -43,6 +43,7 @@ def test_clustering_with_resid_mlp1_config(): "spd-cluster", "--config", str(config_path), + "--local", # don't assume we have slurm in the test env ], capture_output=True, text=True, @@ -78,7 +79,7 @@ def test_cluster_ss_notebook(): @pytest.mark.slow def test_clustering_with_simplestories_config(): """Test running clustering with test-simplestories.json config.""" - config_path = CONFIG_DIR / "test-simplestories.json" + config_path = CONFIG_DIR / "pipeline-test-simplestories.yaml" assert config_path.exists(), f"Config not found: {config_path}" # Run the clustering main script with the test config @@ -88,6 +89,7 @@ def test_clustering_with_simplestories_config(): "--config", str(config_path), "--dataset-streaming", # see https://github.com/goodfire-ai/spd/pull/199 + "--local", # don't assume we have slurm in the test env ], capture_output=True, text=True, diff --git 
a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 6463ad07b..14811b7c5 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -31,10 +31,7 @@ def test_merge_with_range_sampler(self): # Run merge iteration history = merge_iteration( - activations=activations, - batch_id="test_merge_with_range_sampler", - merge_config=config, - component_labels=component_labels, + activations=activations, merge_config=config, component_labels=component_labels ) # Check results @@ -68,10 +65,7 @@ def test_merge_with_mcmc_sampler(self): # Run merge iteration history = merge_iteration( - activations=activations, - batch_id="test_merge_with_mcmc_sampler", - merge_config=config, - component_labels=component_labels, + activations=activations, merge_config=config, component_labels=component_labels ) # Check results @@ -104,10 +98,7 @@ def test_merge_with_popping(self): # Run merge iteration history = merge_iteration( - activations=activations, - batch_id="test_merge_with_popping", - merge_config=config, - component_labels=component_labels, + activations=activations, merge_config=config, component_labels=component_labels ) # Check results @@ -142,7 +133,6 @@ def test_merge_comparison_samplers(self): history_range = merge_iteration( activations=activations.clone(), - batch_id="test_merge_comparison_samplers_range", merge_config=config_range, component_labels=ComponentLabels(component_labels.copy()), ) @@ -159,7 +149,6 @@ def test_merge_comparison_samplers(self): history_mcmc = merge_iteration( activations=activations.clone(), - batch_id="test_merge_comparison_samplers_mcmc", merge_config=config_mcmc, component_labels=ComponentLabels(component_labels.copy()), ) @@ -188,10 +177,7 @@ def test_merge_with_small_components(self): ) history = merge_iteration( - activations=activations, - batch_id="test_merge_with_small_components", - merge_config=config, - component_labels=component_labels, + 
activations=activations, merge_config=config, component_labels=component_labels ) # First entry is after first merge, so should be 3 - 1 = 2 diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py new file mode 100644 index 000000000..91a7cf2ad --- /dev/null +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -0,0 +1,40 @@ +import tempfile +from pathlib import Path + +import pytest + +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_run_config import ClusteringRunConfig, LoggingIntervals +from spd.clustering.scripts.run_clustering import main + + +@pytest.mark.slow +def test_run_clustering_happy_path(): + """Test that run_clustering.py runs without errors.""" + with tempfile.TemporaryDirectory() as temp_dir: + config = ClusteringRunConfig( + model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run + batch_size=4, + dataset_seed=0, + idx_in_ensemble=0, + base_output_dir=Path(temp_dir), + ensemble_id=None, + merge_config=MergeConfig( + activation_threshold=0.01, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.05}, + pop_component_prob=0, + ), + wandb_project=None, + wandb_entity="goodfire", + logging_intervals=LoggingIntervals( + stat=1, + tensor=100, + plot=100, + artifact=100, + ), + dataset_streaming=True, # tests in CI very slow without this, see https://github.com/goodfire-ai/spd/pull/199 + ) + main(config) diff --git a/tests/clustering/test_storage.py b/tests/clustering/test_storage.py deleted file mode 100644 index 389940e54..000000000 --- a/tests/clustering/test_storage.py +++ /dev/null @@ -1,351 +0,0 @@ -"""Comprehensive tests for ClusteringStorage.""" - -import tempfile -from collections.abc import Iterator -from pathlib import Path - -import numpy as np -import pytest -import torch - -from spd.clustering.consts import ComponentLabels, DistancesMethod -from spd.clustering.merge_config import 
MergeConfig -from spd.clustering.merge_history import MergeHistory -from spd.clustering.merge_run_config import ClusteringRunConfig -from spd.clustering.pipeline.storage import ClusteringStorage, NormalizedEnsemble - - -@pytest.fixture -def temp_storage() -> Iterator[ClusteringStorage]: - """Create a temporary ClusteringStorage instance.""" - with tempfile.TemporaryDirectory() as tmp_dir: - storage = ClusteringStorage(base_path=Path(tmp_dir), run_identifier="test_run") - yield storage - - -@pytest.fixture -def sample_config() -> MergeConfig: - """Create a sample MergeConfig for testing.""" - return MergeConfig( - iters=5, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ) - - -class TestStorageInitialization: - """Test storage initialization and directory structure.""" - - def test_storage_creates_run_directory(self): - """Test that storage creates the run directory on initialization.""" - with tempfile.TemporaryDirectory() as tmp_dir: - base_path = Path(tmp_dir) - storage = ClusteringStorage(base_path=base_path, run_identifier="test_run") - - assert storage.run_path.exists() - assert storage.run_path == base_path / "test_run" - - def test_storage_without_run_identifier(self): - """Test that storage works without a run identifier.""" - with tempfile.TemporaryDirectory() as tmp_dir: - base_path = Path(tmp_dir) - storage = ClusteringStorage(base_path=base_path, run_identifier=None) - - assert storage.run_path == base_path - - def test_storage_paths_are_consistent(self, temp_storage: ClusteringStorage): - """Test that all storage paths are under the run path.""" - assert str(temp_storage._dataset_dir).startswith(str(temp_storage.run_path)) - assert str(temp_storage._batches_dir).startswith(str(temp_storage.run_path)) - assert str(temp_storage._histories_dir).startswith(str(temp_storage.run_path)) - assert str(temp_storage._ensemble_dir).startswith(str(temp_storage.run_path)) - assert 
str(temp_storage._distances_dir).startswith(str(temp_storage.run_path)) - - -class TestRunConfigStorage: - """Test run configuration storage.""" - - def test_save_and_load_run_config(self, temp_storage: ClusteringStorage): - """Test saving and loading RunConfig.""" - # Create a minimal RunConfig - config = ClusteringRunConfig( - merge_config=MergeConfig( - iters=10, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ), - model_path="wandb:entity/project/run_id", - task_name="lm", - n_batches=5, - batch_size=32, - base_path=temp_storage.base_path, - workers_per_device=1, - devices=["cuda"], - ) - - # Save config - saved_path = temp_storage.save_run_config(config) - assert saved_path.exists() - assert saved_path == temp_storage.run_config_file - - # Load and verify - loaded_config = ClusteringRunConfig.from_file(saved_path) - assert loaded_config.n_batches == 5 - assert loaded_config.batch_size == 32 - assert loaded_config.task_name == "lm" - - -class TestBatchStorage: - """Test batch data storage.""" - - def test_save_single_batch(self, temp_storage: ClusteringStorage): - """Test saving a single batch.""" - batch = torch.randint(0, 100, (8, 16)) # batch_size=8, seq_len=16 - batch_idx = 0 - - saved_path = temp_storage.save_batch(batch, batch_idx) - assert saved_path.exists() - assert saved_path.name == "batch_00.npz" - - def test_save_and_load_batch(self, temp_storage: ClusteringStorage): - """Test saving and loading a batch.""" - original_batch = torch.randint(0, 100, (8, 16)) - batch_idx = 0 - - # Save - temp_storage.save_batch(original_batch, batch_idx) - - # Load - loaded_batch = temp_storage.load_batch(temp_storage.batch_path(batch_idx)) - - # Verify - assert torch.equal(loaded_batch, original_batch) - - def test_save_multiple_batches(self, temp_storage: ClusteringStorage): - """Test saving multiple batches using save_batches.""" - batches = [torch.randint(0, 100, (8, 16)) for _ in range(3)] - config = {"test": "config"} - - saved_paths = 
temp_storage.save_batches(iter(batches), config) - - assert len(saved_paths) == 3 - assert all(p.exists() for p in saved_paths) - assert temp_storage.dataset_config_file.exists() - - def test_get_batch_paths(self, temp_storage: ClusteringStorage): - """Test retrieving all batch paths.""" - # Save some batches - for i in range(3): - temp_storage.save_batch(torch.randint(0, 100, (8, 16)), i) - - # Get paths - batch_paths = temp_storage.get_batch_paths() - - assert len(batch_paths) == 3 - assert all(p.exists() for p in batch_paths) - # Should be sorted - assert batch_paths == sorted(batch_paths) - - -class TestHistoryStorage: - """Test merge history storage.""" - - def test_save_and_load_history( - self, temp_storage: ClusteringStorage, sample_config: MergeConfig - ): - """Test saving and loading merge history.""" - # Create history - history = MergeHistory.from_config( - merge_config=sample_config, - labels=ComponentLabels(["comp0", "comp1", "comp2"]), - ) - - batch_id = "batch_00" - - # Save - saved_path = temp_storage.save_history(history, batch_id) - assert saved_path.exists() - assert "batch_00" in str(saved_path) - - # Load - loaded_history = temp_storage.load_history(batch_id) - assert loaded_history is not None - assert len(loaded_history.labels) == 3 - - def test_load_multiple_histories( - self, temp_storage: ClusteringStorage, sample_config: MergeConfig - ): - """Test loading all histories.""" - # Save multiple histories - for i in range(3): - history = MergeHistory.from_config( - merge_config=sample_config, - labels=ComponentLabels([f"comp{j}" for j in range(4)]), - ) - temp_storage.save_history(history, batch_id=f"batch_{i:02d}") - - # Load all - histories = temp_storage.load_histories() - assert len(histories) == 3 - - def test_get_history_paths(self, temp_storage: ClusteringStorage, sample_config: MergeConfig): - """Test getting all history paths.""" - # Save histories - for i in range(2): - history = MergeHistory.from_config( - 
merge_config=sample_config, - labels=ComponentLabels(["comp0", "comp1"]), - ) - temp_storage.save_history(history, batch_id=f"batch_{i:02d}") - - # Get paths - history_paths = temp_storage.get_history_paths() - assert len(history_paths) == 2 - assert all(p.exists() for p in history_paths) - - -class TestEnsembleStorage: - """Test ensemble data storage.""" - - def test_save_ensemble(self, temp_storage: ClusteringStorage): - """Test saving ensemble data.""" - # Create dummy ensemble data - merge_array = np.random.randint(0, 10, size=(2, 5, 8)) # n_ens, n_iters, c_components - metadata = {"n_ensemble": 2, "n_iters": 5} - - ensemble = NormalizedEnsemble(merge_array=merge_array, metadata=metadata) - - # Save - meta_path, array_path = temp_storage.save_ensemble(ensemble) - - assert meta_path.exists() - assert array_path.exists() - assert meta_path == temp_storage.ensemble_meta_file - assert array_path == temp_storage.ensemble_array_file - - def test_ensemble_data_integrity(self, temp_storage: ClusteringStorage): - """Test that ensemble data can be saved and loaded correctly.""" - # Create ensemble data - original_array = np.random.randint(0, 10, size=(2, 5, 8)) - metadata = {"test": "value", "n_ensemble": 2} - - ensemble = NormalizedEnsemble(merge_array=original_array, metadata=metadata) - - # Save - _, array_path = temp_storage.save_ensemble(ensemble) - - # Load and verify - loaded_data = np.load(array_path) - loaded_array = loaded_data["merges"] - - assert np.array_equal(loaded_array, original_array) - - -class TestDistancesStorage: - """Test distance matrix storage.""" - - def test_save_distances(self, temp_storage: ClusteringStorage): - """Test saving distance matrix.""" - distances = np.random.rand(5, 3, 3) # n_iters, n_ens, n_ens - method: DistancesMethod = "perm_invariant_hamming" - - saved_path = temp_storage.save_distances(distances, method) - - assert saved_path.exists() - assert method in saved_path.name - - def test_save_and_load_distances(self, temp_storage: 
ClusteringStorage): - """Test saving and loading distances.""" - original_distances = np.random.rand(5, 3, 3) - method: DistancesMethod = "perm_invariant_hamming" - - # Save - temp_storage.save_distances(original_distances, method) - - # Load - loaded_distances = temp_storage.load_distances(method) - - assert np.array_equal(loaded_distances, original_distances) - - -class TestStorageIntegration: - """Test integration scenarios.""" - - def test_full_pipeline_storage_flow( - self, temp_storage: ClusteringStorage, sample_config: MergeConfig - ): - """Test a complete storage workflow.""" - # 1. Save run config - run_config = ClusteringRunConfig( - merge_config=sample_config, - model_path="wandb:entity/project/run_id", - task_name="lm", - n_batches=2, - batch_size=8, - base_path=temp_storage.base_path, - workers_per_device=1, - devices=["cpu"], - ) - temp_storage.save_run_config(run_config) - - # 2. Save batches - batches = [torch.randint(0, 100, (8, 16)) for _ in range(2)] - temp_storage.save_batches(iter(batches), {"dataset": "test"}) - - # 3. Save histories - for i in range(2): - history = MergeHistory.from_config( - merge_config=sample_config, - labels=ComponentLabels(["comp0", "comp1", "comp2"]), - ) - temp_storage.save_history(history, batch_id=f"batch_{i:02d}") - - # 4. Save ensemble - merge_array = np.random.randint(0, 3, size=(2, 5, 3)) - ensemble = NormalizedEnsemble( - merge_array=merge_array, - metadata={"n_ensemble": 2, "n_iters": 5}, - ) - temp_storage.save_ensemble(ensemble) - - # 5. 
Save distances - distances = np.random.rand(5, 2, 2) - temp_storage.save_distances(distances, "perm_invariant_hamming") - - # Verify all files exist - assert temp_storage.run_config_file.exists() - assert temp_storage.dataset_config_file.exists() - assert len(temp_storage.get_batch_paths()) == 2 - assert len(temp_storage.get_history_paths()) == 2 - assert temp_storage.ensemble_meta_file.exists() - assert temp_storage.ensemble_array_file.exists() - - def test_storage_filesystem_structure(self, temp_storage: ClusteringStorage): - """Test that the filesystem structure matches documentation.""" - # Create minimal data to generate structure - temp_storage.save_run_config( - ClusteringRunConfig( - merge_config=MergeConfig( - iters=1, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ), - model_path="wandb:e/p/r", - task_name="lm", - n_batches=1, - batch_size=1, - base_path=temp_storage.base_path, - workers_per_device=1, - devices=["cpu"], - ) - ) - - # Verify structure - assert (temp_storage.run_path / "run_config.json").exists() - - # The directories are created lazily, so trigger their creation - temp_storage.save_batch(torch.tensor([[1, 2, 3]]), 0) - assert (temp_storage.run_path / "dataset" / "batches").exists() diff --git a/tests/clustering/test_wandb_integration.py b/tests/clustering/test_wandb_integration.py deleted file mode 100644 index cf400ca2b..000000000 --- a/tests/clustering/test_wandb_integration.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Quick sanity tests for WandB integration features.""" - -import tempfile -from pathlib import Path -from unittest.mock import Mock, patch - -from spd.clustering.consts import ComponentLabels -from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.pipeline.s2_clustering import _save_merge_history_to_wandb -from spd.clustering.pipeline.s3_normalize_histories import normalize_and_save - - -def 
test_wandb_url_parsing_short_format(): - """Test that normalize_and_save can process merge histories using storage.""" - from spd.clustering.pipeline.storage import ClusteringStorage - - # Create temporary directory for storage - with tempfile.TemporaryDirectory() as tmp_dir: - tmp_path = Path(tmp_dir) - - # Create ClusteringStorage instance - storage = ClusteringStorage(base_path=tmp_path, run_identifier="test_run") - - # Create mock merge histories - config = MergeConfig( - iters=5, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ) - - # Save histories using storage - for idx in range(2): - history = MergeHistory.from_config( - merge_config=config, - labels=ComponentLabels([f"comp{j}" for j in range(5)]), - ) - storage.save_history(history, batch_id=f"batch_{idx:02d}") - - # Test normalize_and_save with storage - result = normalize_and_save(storage=storage) - - # Basic checks - assert result is not None - assert storage.ensemble_meta_file.exists() - assert storage.ensemble_array_file.exists() - - # Verify we can load the histories back - loaded_histories = storage.load_histories() - assert len(loaded_histories) == 2 - - -def test_merge_history_ensemble(): - """Test that MergeHistoryEnsemble can handle multiple histories.""" - - # Create test merge histories - config = MergeConfig( - iters=3, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ) - - histories = [] - for _idx in range(2): - history = MergeHistory.from_config( - merge_config=config, - labels=ComponentLabels([f"comp{j}" for j in range(4)]), - ) - histories.append(history) - - # Test ensemble creation - ensemble = MergeHistoryEnsemble(data=histories) - assert len(ensemble.data) == 2 - - # Test normalization - normalized_array, metadata = ensemble.normalized() - assert normalized_array is not None - assert metadata is not None - - -def test_save_merge_history_to_wandb(): - """Test that _save_merge_history_to_wandb creates the expected artifact.""" - - # Create a 
real MergeHistory - config = MergeConfig( - iters=5, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ) - - history = MergeHistory.from_config( - merge_config=config, - labels=ComponentLabels(["comp0", "comp1", "comp2"]), - ) - - # Mock wandb run and artifact - mock_wandb_run = Mock() - mock_artifact = Mock() - - with tempfile.TemporaryDirectory() as tmp_dir: - history_path = Path(tmp_dir) / "test_history.zip" - history.save(history_path) - - with patch("spd.clustering.pipeline.s2_clustering.wandb.Artifact") as mock_artifact_class: - mock_artifact_class.return_value = mock_artifact - - # Call the function - _save_merge_history_to_wandb( - run=mock_wandb_run, - history_path=history_path, - batch_id="batch_01", - config_identifier="test_config", - history=history, - ) - - # Check that artifact was created and logged - mock_artifact_class.assert_called_once() - mock_wandb_run.log_artifact.assert_called_once_with(mock_artifact) - - # Check artifact creation parameters - call_args = mock_artifact_class.call_args - assert call_args.kwargs["name"] == "merge_history_batch_01" - assert call_args.kwargs["type"] == "merge_history" - assert "batch_01" in call_args.kwargs["description"] - - -def test_wandb_url_field_in_merge_history(): - """Test that MergeHistory can store and serialize wandb_url.""" - - # Create a simple config - config = MergeConfig( - iters=10, - alpha=1.0, - activation_threshold=None, - pop_component_prob=0.0, - ) - - # Create MergeHistory with wandb_url - history = MergeHistory.from_config( - merge_config=config, - labels=ComponentLabels(["comp0", "comp1", "comp2", "comp3", "comp4"]), - ) - # Check that it can be serialized and deserialized - with tempfile.TemporaryDirectory() as tmp_dir: - save_path = Path(tmp_dir) / "test_history.zip" - history.save(save_path) - loaded_history = MergeHistory.read(save_path) - - assert loaded_history is not None - assert loaded_history.merges.group_idxs.shape == (10, 5) # (iters, n_components) diff 
--git a/tests/scripts_run/test_main.py b/tests/scripts_run/test_main.py index 8f8ce65e9..00a6ee044 100644 --- a/tests/scripts_run/test_main.py +++ b/tests/scripts_run/test_main.py @@ -36,7 +36,7 @@ class TestSPDRun: ("tms_5-2", True, 4, None), # Command count depends on sweep params ], ) - @patch("spd.scripts.run.submit_slurm_array") + @patch("spd.scripts.run.submit_slurm_script") @patch("spd.scripts.run.create_slurm_array_script") @patch("spd.scripts.run.load_sweep_params") def test_spd_run_not_local_no_sweep( @@ -88,6 +88,7 @@ def test_spd_run_not_local_no_sweep( # Verify command structure for cmd in commands: + assert isinstance(cmd, str) assert "python" in cmd assert "_decomposition.py" in cmd assert "json:" in cmd @@ -109,7 +110,7 @@ def test_spd_run_not_local_no_sweep( ("tms_5-2", True), ], ) - @patch("spd.scripts.run.subprocess.run") + @patch("spd.scripts.run.run_script_array_local") @patch("spd.scripts.run.load_sweep_params") def test_spd_run_local_no_sweep( self, @@ -133,30 +134,31 @@ def test_spd_run_local_no_sweep( **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) - # Calculate expected number of subprocess calls + # Calculate expected number of commands num_experiments = len(experiments.split(",")) - expected_calls = num_experiments * 2 if sweep else num_experiments + expected_num_commands = num_experiments * 2 if sweep else num_experiments - # Assert subprocess.run was called the expected number of times - assert mock_subprocess.call_count == expected_calls + # Assert run_script_array_local was called exactly once + assert mock_subprocess.call_count == 1 - # Verify each subprocess call - for call in mock_subprocess.call_args_list: - args = call[0][0] # Get the command list + # Get the commands list from the call + commands = mock_subprocess.call_args[0][0] + assert len(commands) == expected_num_commands - # Should be a list of arguments - assert isinstance(args, list) - assert args[0] == "python" - assert "_decomposition.py" in 
args[1] + # Verify each command + for cmd in commands: + # Should be a string + assert isinstance(cmd, str) + assert "python" in cmd + assert "_decomposition.py" in cmd # Check for required arguments in the command - cmd_str = " ".join(args) - assert "json:" in cmd_str - assert "--sweep_id" in cmd_str - assert "--evals_id" in cmd_str + assert "json:" in cmd + assert "--sweep_id" in cmd + assert "--evals_id" in cmd if sweep: - assert "--sweep_params_json" in cmd_str + assert "--sweep_params_json" in cmd # No wandb functions should be called since use_wandb=False @@ -178,7 +180,7 @@ def test_invalid_experiment_name(self): **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) - @patch("spd.scripts.run.subprocess.run") + @patch("spd.scripts.run.run_script_array_local") def test_sweep_params_integration(self, mock_subprocess): """Test that sweep parameters are correctly integrated into commands. @@ -196,12 +198,17 @@ def test_sweep_params_integration(self, mock_subprocess): **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) + # Assert run_script_array_local was called exactly once + assert mock_subprocess.call_count == 1 + + # Get the commands list + commands = mock_subprocess.call_args[0][0] + # Verify multiple commands were generated (sweep should create multiple runs) - assert mock_subprocess.call_count > 1 + assert len(commands) > 1 # Check that sweep parameters are in the commands - for call in mock_subprocess.call_args_list: - args = call[0][0] - cmd_str = " ".join(args) - assert "--sweep_params_json" in cmd_str - assert "json:" in cmd_str + for cmd in commands: + assert isinstance(cmd, str) + assert "--sweep_params_json" in cmd + assert "json:" in cmd diff --git a/uv.lock b/uv.lock index 9ae8b57da..560ef3f01 100644 --- a/uv.lock +++ b/uv.lock @@ -774,15 +774,15 @@ wheels = [ [[package]] name = "jupyter-core" -version = "5.9.0" +version = "5.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = 
"platformdirs" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/0c/7b01e93e054555cbadf614f2ff10ea77aecbc8867831914d8a2c5868481a/jupyter_core-5.9.0.tar.gz", hash = "sha256:5f8fba10cfc946fe1b4037e986458fc89430397207b21d741dc399d3d42951d4", size = 89804, upload-time = "2025-10-16T12:12:23.851Z" } +sdist = { url = "https://files.pythonhosted.org/packages/02/49/9d1284d0dc65e2c757b74c6687b6d319b02f822ad039e5c512df9194d9dd/jupyter_core-5.9.1.tar.gz", hash = "sha256:4d09aaff303b9566c3ce657f580bd089ff5c91f5f89cf7d8846c3cdf465b5508", size = 89814, upload-time = "2025-10-16T19:19:18.444Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/f2/5efda2a70d98288f4d94baba8489cd782d53772233c77351864bc754a146/jupyter_core-5.9.0-py3-none-any.whl", hash = "sha256:bf13431d292ce34a25568586729a3b9deb07d112289b77350dc4c2340c2f34c1", size = 29024, upload-time = "2025-10-16T12:12:22.19Z" }, + { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] [[package]] @@ -1738,28 +1738,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/b9/9bd84453ed6dd04688de9b3f3a4146a1698e8faae2ceeccce4e14c67ae17/ruff-0.14.0.tar.gz", hash = "sha256:62ec8969b7510f77945df916de15da55311fade8d6050995ff7f680afe582c57", size = 5452071, upload-time = "2025-10-07T18:21:55.763Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/4e/79d463a5f80654e93fa653ebfb98e0becc3f0e7cf6219c9ddedf1e197072/ruff-0.14.0-py3-none-linux_armv6l.whl", hash = "sha256:58e15bffa7054299becf4bab8a1187062c6f8cafbe9f6e39e0d5aface455d6b3", size = 12494532, upload-time = "2025-10-07T18:21:00.373Z" }, - { url = 
"https://files.pythonhosted.org/packages/ee/40/e2392f445ed8e02aa6105d49db4bfff01957379064c30f4811c3bf38aece/ruff-0.14.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:838d1b065f4df676b7c9957992f2304e41ead7a50a568185efd404297d5701e8", size = 13160768, upload-time = "2025-10-07T18:21:04.73Z" }, - { url = "https://files.pythonhosted.org/packages/75/da/2a656ea7c6b9bd14c7209918268dd40e1e6cea65f4bb9880eaaa43b055cd/ruff-0.14.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:703799d059ba50f745605b04638fa7e9682cc3da084b2092feee63500ff3d9b8", size = 12363376, upload-time = "2025-10-07T18:21:07.833Z" }, - { url = "https://files.pythonhosted.org/packages/42/e2/1ffef5a1875add82416ff388fcb7ea8b22a53be67a638487937aea81af27/ruff-0.14.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3ba9a8925e90f861502f7d974cc60e18ca29c72bb0ee8bfeabb6ade35a3abde7", size = 12608055, upload-time = "2025-10-07T18:21:10.72Z" }, - { url = "https://files.pythonhosted.org/packages/4a/32/986725199d7cee510d9f1dfdf95bf1efc5fa9dd714d0d85c1fb1f6be3bc3/ruff-0.14.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e41f785498bd200ffc276eb9e1570c019c1d907b07cfb081092c8ad51975bbe7", size = 12318544, upload-time = "2025-10-07T18:21:13.741Z" }, - { url = "https://files.pythonhosted.org/packages/9a/ed/4969cefd53315164c94eaf4da7cfba1f267dc275b0abdd593d11c90829a3/ruff-0.14.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30a58c087aef4584c193aebf2700f0fbcfc1e77b89c7385e3139956fa90434e2", size = 14001280, upload-time = "2025-10-07T18:21:16.411Z" }, - { url = "https://files.pythonhosted.org/packages/ab/ad/96c1fc9f8854c37681c9613d825925c7f24ca1acfc62a4eb3896b50bacd2/ruff-0.14.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f8d07350bc7af0a5ce8812b7d5c1a7293cf02476752f23fdfc500d24b79b783c", size = 15027286, upload-time = "2025-10-07T18:21:19.577Z" }, - { url = 
"https://files.pythonhosted.org/packages/b3/00/1426978f97df4fe331074baf69615f579dc4e7c37bb4c6f57c2aad80c87f/ruff-0.14.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eec3bbbf3a7d5482b5c1f42d5fc972774d71d107d447919fca620b0be3e3b75e", size = 14451506, upload-time = "2025-10-07T18:21:22.779Z" }, - { url = "https://files.pythonhosted.org/packages/58/d5/9c1cea6e493c0cf0647674cca26b579ea9d2a213b74b5c195fbeb9678e15/ruff-0.14.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16b68e183a0e28e5c176d51004aaa40559e8f90065a10a559176713fcf435206", size = 13437384, upload-time = "2025-10-07T18:21:25.758Z" }, - { url = "https://files.pythonhosted.org/packages/29/b4/4cd6a4331e999fc05d9d77729c95503f99eae3ba1160469f2b64866964e3/ruff-0.14.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb732d17db2e945cfcbbc52af0143eda1da36ca8ae25083dd4f66f1542fdf82e", size = 13447976, upload-time = "2025-10-07T18:21:28.83Z" }, - { url = "https://files.pythonhosted.org/packages/3b/c0/ac42f546d07e4f49f62332576cb845d45c67cf5610d1851254e341d563b6/ruff-0.14.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:c958f66ab884b7873e72df38dcabee03d556a8f2ee1b8538ee1c2bbd619883dd", size = 13682850, upload-time = "2025-10-07T18:21:31.842Z" }, - { url = "https://files.pythonhosted.org/packages/5f/c4/4b0c9bcadd45b4c29fe1af9c5d1dc0ca87b4021665dfbe1c4688d407aa20/ruff-0.14.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7eb0499a2e01f6e0c285afc5bac43ab380cbfc17cd43a2e1dd10ec97d6f2c42d", size = 12449825, upload-time = "2025-10-07T18:21:35.074Z" }, - { url = "https://files.pythonhosted.org/packages/4b/a8/e2e76288e6c16540fa820d148d83e55f15e994d852485f221b9524514730/ruff-0.14.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4c63b2d99fafa05efca0ab198fd48fa6030d57e4423df3f18e03aa62518c565f", size = 12272599, upload-time = "2025-10-07T18:21:38.08Z" }, - { url = 
"https://files.pythonhosted.org/packages/18/14/e2815d8eff847391af632b22422b8207704222ff575dec8d044f9ab779b2/ruff-0.14.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:668fce701b7a222f3f5327f86909db2bbe99c30877c8001ff934c5413812ac02", size = 13193828, upload-time = "2025-10-07T18:21:41.216Z" }, - { url = "https://files.pythonhosted.org/packages/44/c6/61ccc2987cf0aecc588ff8f3212dea64840770e60d78f5606cd7dc34de32/ruff-0.14.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a86bf575e05cb68dcb34e4c7dfe1064d44d3f0c04bbc0491949092192b515296", size = 13628617, upload-time = "2025-10-07T18:21:44.04Z" }, - { url = "https://files.pythonhosted.org/packages/73/e6/03b882225a1b0627e75339b420883dc3c90707a8917d2284abef7a58d317/ruff-0.14.0-py3-none-win32.whl", hash = "sha256:7450a243d7125d1c032cb4b93d9625dea46c8c42b4f06c6b709baac168e10543", size = 12367872, upload-time = "2025-10-07T18:21:46.67Z" }, - { url = "https://files.pythonhosted.org/packages/41/77/56cf9cf01ea0bfcc662de72540812e5ba8e9563f33ef3d37ab2174892c47/ruff-0.14.0-py3-none-win_amd64.whl", hash = "sha256:ea95da28cd874c4d9c922b39381cbd69cb7e7b49c21b8152b014bd4f52acddc2", size = 13464628, upload-time = "2025-10-07T18:21:50.318Z" }, - { url = "https://files.pythonhosted.org/packages/c6/2a/65880dfd0e13f7f13a775998f34703674a4554906167dce02daf7865b954/ruff-0.14.0-py3-none-win_arm64.whl", hash = "sha256:f42c9495f5c13ff841b1da4cb3c2a42075409592825dada7c5885c2c844ac730", size = 12565142, upload-time = "2025-10-07T18:21:53.577Z" }, +version = "0.14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/58/6ca66896635352812de66f71cdf9ff86b3a4f79071ca5730088c0cd0fc8d/ruff-0.14.1.tar.gz", hash = "sha256:1dd86253060c4772867c61791588627320abcb6ed1577a90ef432ee319729b69", size = 5513429, upload-time = "2025-10-16T18:05:41.766Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/8d/39/9cc5ab181478d7a18adc1c1e051a84ee02bec94eb9bdfd35643d7c74ca31/ruff-0.14.1-py3-none-linux_armv6l.whl", hash = "sha256:083bfc1f30f4a391ae09c6f4f99d83074416b471775b59288956f5bc18e82f8b", size = 12445415, upload-time = "2025-10-16T18:04:48.227Z" }, + { url = "https://files.pythonhosted.org/packages/ef/2e/1226961855ccd697255988f5a2474890ac7c5863b080b15bd038df820818/ruff-0.14.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f6fa757cd717f791009f7669fefb09121cc5f7d9bd0ef211371fad68c2b8b224", size = 12784267, upload-time = "2025-10-16T18:04:52.515Z" }, + { url = "https://files.pythonhosted.org/packages/c1/ea/fd9e95863124ed159cd0667ec98449ae461de94acda7101f1acb6066da00/ruff-0.14.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d6191903d39ac156921398e9c86b7354d15e3c93772e7dbf26c9fcae59ceccd5", size = 11781872, upload-time = "2025-10-16T18:04:55.396Z" }, + { url = "https://files.pythonhosted.org/packages/1e/5a/e890f7338ff537dba4589a5e02c51baa63020acfb7c8cbbaea4831562c96/ruff-0.14.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed04f0e04f7a4587244e5c9d7df50e6b5bf2705d75059f409a6421c593a35896", size = 12226558, upload-time = "2025-10-16T18:04:58.166Z" }, + { url = "https://files.pythonhosted.org/packages/a6/7a/8ab5c3377f5bf31e167b73651841217542bcc7aa1c19e83030835cc25204/ruff-0.14.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c9e6cf6cd4acae0febbce29497accd3632fe2025c0c583c8b87e8dbdeae5f61", size = 12187898, upload-time = "2025-10-16T18:05:01.455Z" }, + { url = "https://files.pythonhosted.org/packages/48/8d/ba7c33aa55406955fc124e62c8259791c3d42e3075a71710fdff9375134f/ruff-0.14.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fa2458527794ecdfbe45f654e42c61f2503a230545a91af839653a0a93dbc6", size = 12939168, upload-time = "2025-10-16T18:05:04.397Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/c2/70783f612b50f66d083380e68cbd1696739d88e9b4f6164230375532c637/ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345", size = 14386942, upload-time = "2025-10-16T18:05:07.102Z" }, + { url = "https://files.pythonhosted.org/packages/48/44/cd7abb9c776b66d332119d67f96acf15830d120f5b884598a36d9d3f4d83/ruff-0.14.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7382fa12a26cce1f95070ce450946bec357727aaa428983036362579eadcc5cf", size = 13990622, upload-time = "2025-10-16T18:05:09.882Z" }, + { url = "https://files.pythonhosted.org/packages/eb/56/4259b696db12ac152fe472764b4f78bbdd9b477afd9bc3a6d53c01300b37/ruff-0.14.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd0bf2be3ae8521e1093a487c4aa3b455882f139787770698530d28ed3fbb37c", size = 13431143, upload-time = "2025-10-16T18:05:13.46Z" }, + { url = "https://files.pythonhosted.org/packages/e0/35/266a80d0eb97bd224b3265b9437bd89dde0dcf4faf299db1212e81824e7e/ruff-0.14.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabcaa9ccf8089fb4fdb78d17cc0e28241520f50f4c2e88cb6261ed083d85151", size = 13132844, upload-time = "2025-10-16T18:05:16.1Z" }, + { url = "https://files.pythonhosted.org/packages/65/6e/d31ce218acc11a8d91ef208e002a31acf315061a85132f94f3df7a252b18/ruff-0.14.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:747d583400f6125ec11a4c14d1c8474bf75d8b419ad22a111a537ec1a952d192", size = 13401241, upload-time = "2025-10-16T18:05:19.395Z" }, + { url = "https://files.pythonhosted.org/packages/9f/b5/dbc4221bf0b03774b3b2f0d47f39e848d30664157c15b965a14d890637d2/ruff-0.14.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5a6e74c0efd78515a1d13acbfe6c90f0f5bd822aa56b4a6d43a9ffb2ae6e56cd", size = 12132476, upload-time = "2025-10-16T18:05:22.163Z" }, + { url = 
"https://files.pythonhosted.org/packages/98/4b/ac99194e790ccd092d6a8b5f341f34b6e597d698e3077c032c502d75ea84/ruff-0.14.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0ea6a864d2fb41a4b6d5b456ed164302a0d96f4daac630aeba829abfb059d020", size = 12139749, upload-time = "2025-10-16T18:05:25.162Z" }, + { url = "https://files.pythonhosted.org/packages/47/26/7df917462c3bb5004e6fdfcc505a49e90bcd8a34c54a051953118c00b53a/ruff-0.14.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0826b8764f94229604fa255918d1cc45e583e38c21c203248b0bfc9a0e930be5", size = 12544758, upload-time = "2025-10-16T18:05:28.018Z" }, + { url = "https://files.pythonhosted.org/packages/64/d0/81e7f0648e9764ad9b51dd4be5e5dac3fcfff9602428ccbae288a39c2c22/ruff-0.14.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cbc52160465913a1a3f424c81c62ac8096b6a491468e7d872cb9444a860bc33d", size = 13221811, upload-time = "2025-10-16T18:05:30.707Z" }, + { url = "https://files.pythonhosted.org/packages/c3/07/3c45562c67933cc35f6d5df4ca77dabbcd88fddaca0d6b8371693d29fd56/ruff-0.14.1-py3-none-win32.whl", hash = "sha256:e037ea374aaaff4103240ae79168c0945ae3d5ae8db190603de3b4012bd1def6", size = 12319467, upload-time = "2025-10-16T18:05:33.261Z" }, + { url = "https://files.pythonhosted.org/packages/02/88/0ee4ca507d4aa05f67e292d2e5eb0b3e358fbcfe527554a2eda9ac422d6b/ruff-0.14.1-py3-none-win_amd64.whl", hash = "sha256:59d599cdff9c7f925a017f6f2c256c908b094e55967f93f2821b1439928746a1", size = 13401123, upload-time = "2025-10-16T18:05:35.984Z" }, + { url = "https://files.pythonhosted.org/packages/b8/81/4b6387be7014858d924b843530e1b2a8e531846807516e9bea2ee0936bf7/ruff-0.14.1-py3-none-win_arm64.whl", hash = "sha256:e3b443c4c9f16ae850906b8d0a707b2a4c16f8d2f0a7fe65c475c5886665ce44", size = 12436636, upload-time = "2025-10-16T18:05:38.995Z" }, ] [[package]] From 8f6462c427349611da527df3783049459725f6aa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 12:01:39 +0100 Subject: [PATCH 16/77] split up ci_dt --- 
spd/clustering/ci_dt/__init__.py | 40 ++++++++ spd/clustering/ci_dt/config.py | 14 +++ spd/clustering/ci_dt/core.py | 157 ++++++++++++++++++++++++++++ spd/clustering/ci_dt/plot.py | 102 +++++++++++++++++++ spd/clustering/ci_dt/run.py | 170 +++++++++++++++++++++++++++++++ 5 files changed, 483 insertions(+) create mode 100644 spd/clustering/ci_dt/__init__.py create mode 100644 spd/clustering/ci_dt/config.py create mode 100644 spd/clustering/ci_dt/core.py create mode 100644 spd/clustering/ci_dt/plot.py create mode 100644 spd/clustering/ci_dt/run.py diff --git a/spd/clustering/ci_dt/__init__.py b/spd/clustering/ci_dt/__init__.py new file mode 100644 index 000000000..1451a7f59 --- /dev/null +++ b/spd/clustering/ci_dt/__init__.py @@ -0,0 +1,40 @@ +"""Causal importance decision tree package.""" + +from spd.clustering.ci_dt.config import CIDTConfig +from spd.clustering.ci_dt.core import ( + LayerModel, + build_xy, + concat_cols, + get_estimator_for, + layer_metrics, + predict_all, + predict_k, + proba_for_layer, + train_trees, +) +from spd.clustering.ci_dt.plot import ( + plot_activations, + plot_covariance, + plot_layer_metrics, + plot_selected_trees, +) + +__all__ = [ + # Config + "CIDTConfig", + # Core + "LayerModel", + "concat_cols", + "build_xy", + "train_trees", + "predict_k", + "predict_all", + "layer_metrics", + "proba_for_layer", + "get_estimator_for", + # Plot + "plot_activations", + "plot_covariance", + "plot_layer_metrics", + "plot_selected_trees", +] diff --git a/spd/clustering/ci_dt/config.py b/spd/clustering/ci_dt/config.py new file mode 100644 index 000000000..a83c6adc4 --- /dev/null +++ b/spd/clustering/ci_dt/config.py @@ -0,0 +1,14 @@ +"""Configuration for causal importance decision tree training.""" + +from dataclasses import dataclass + + +@dataclass +class CIDTConfig: + """Configuration for causal importance decision tree training.""" + + n_samples: int = 250 + activation_threshold: float = 0.01 # Threshold for boolean conversion + 
filter_dead_threshold: float = 0.001 # Threshold for filtering dead components + max_depth: int = 8 # Maximum depth for decision trees + random_state: int = 7 # Random state for reproducibility diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py new file mode 100644 index 000000000..bfdaa27fd --- /dev/null +++ b/spd/clustering/ci_dt/core.py @@ -0,0 +1,157 @@ +"""Core library functions for causal importance decision trees.""" + +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal + +import numpy as np +from jaxtyping import Bool, Float +from sklearn.base import ClassifierMixin +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + balanced_accuracy_score, +) +from sklearn.multioutput import MultiOutputClassifier +from sklearn.tree import DecisionTreeClassifier + + +@dataclass +class LayerModel: + """Holds a trained per-layer model.""" + + layer_index: int + model: ClassifierMixin + feature_dim: int + target_dim: int + + +def concat_cols( + Xs: Sequence[Bool[np.ndarray, "n_samples n_features"]], +) -> Bool[np.ndarray, "n_samples n_concat"]: + """Column-concat a sequence or return empty (n,0).""" + n_samples: int = Xs[0].shape[0] if len(Xs) else 0 + return np.concatenate(Xs, axis=1) if len(Xs) else np.zeros((n_samples, 0), bool) + + +def build_xy( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], +) -> list[ + tuple[ + Bool[np.ndarray, "n_samples n_features"], + Bool[np.ndarray, "n_samples n_targets"], + ] +]: + """Return (X_k,Y_k) for k=1..L-1 with X_k=concat(layers[:k]).""" + XYs: list[tuple[np.ndarray, np.ndarray]] = [] + for k in range(1, len(layers)): + X_k: np.ndarray = concat_cols(layers[:k]) + Y_k: np.ndarray = layers[k] + XYs.append((X_k, Y_k)) + return XYs + + +def train_trees( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + *, + strategy: Literal["one_vs_all", "single_tree"] = "one_vs_all", + max_depth: int | None = None, + 
min_samples_leaf: int = 1, + random_state: int | None = 0, +) -> list[LayerModel]: + """Train one model per target layer using previous layers as features.""" + XYs = build_xy(layers) + models: list[LayerModel] = [] + for k, (X_k, Y_k) in enumerate(XYs, start=1): + base = DecisionTreeClassifier( + max_depth=max_depth, + min_samples_leaf=min_samples_leaf, + random_state=random_state, + ) + model: ClassifierMixin = MultiOutputClassifier(base) if strategy == "one_vs_all" else base + _ = model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) + models.append(LayerModel(k, model, int(X_k.shape[1]), int(Y_k.shape[1]))) + return models + + +def predict_k( + models: Sequence[LayerModel], + prefix_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + k: int, + *, + threshold: float = 0.5, +) -> Bool[np.ndarray, "n_samples n_components_k"]: + """Predict layer k activations from layers[:k].""" + lm: LayerModel = next(m for m in models if m.layer_index == k) + X: np.ndarray = concat_cols(prefix_layers) + proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + if isinstance(proba, list): + P: np.ndarray = np.stack([p[:, 1] for p in proba], axis=1) + else: + P = proba[..., 1] # type: ignore + Y_hat: np.ndarray = (float(threshold) <= P).astype(bool) + return Y_hat + + +def predict_all( + models: Sequence[LayerModel], + seed_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + *, + thresholds: Sequence[float] | None = None, +) -> list[Bool[np.ndarray, "n_samples n_components"]]: + """Sequentially predict layers 1.. 
using layer 0 as seed.""" + out: list[np.ndarray] = [seed_layers[0].copy()] + ths: list[float] = list(thresholds) if thresholds is not None else [] + for i, lm in enumerate(sorted(models, key=lambda m: m.layer_index)): + thr: float = ths[i] if i < len(ths) else 0.5 + out.append(predict_k(models, out, lm.layer_index, threshold=thr)) + return out + + +def layer_metrics( + Y_true: Bool[np.ndarray, "n t"], + Y_prob: Float[np.ndarray, "n t"], + Y_pred: Bool[np.ndarray, "n t"], +) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Return per-target AP, acc, bacc, prevalence.""" + T: int = Y_true.shape[1] + ap: np.ndarray = np.zeros(T) + acc: np.ndarray = np.zeros(T) + bacc: np.ndarray = np.zeros(T) + prev: np.ndarray = np.zeros(T) + for j in range(T): + y: np.ndarray = Y_true[:, j].astype(int) + p: np.ndarray = Y_prob[:, j] + yhat: np.ndarray = Y_pred[:, j].astype(int) + prev[j] = float(y.mean()) + try: + ap[j] = average_precision_score(y, p) + except Exception: + ap[j] = np.nan + try: + acc[j] = accuracy_score(y, yhat) + except Exception: + acc[j] = np.nan + try: + bacc[j] = balanced_accuracy_score(y, yhat) + except Exception: + bacc[j] = np.nan + return ap, acc, bacc, prev + + +def proba_for_layer(lm: LayerModel, X: np.ndarray) -> np.ndarray: + """Return P(y=1) per target column.""" + pr = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + if isinstance(pr, list): + return np.stack([p[:, 1] for p in pr], axis=1) + return pr[..., 1] # type: ignore + + +def get_estimator_for( + models: list[LayerModel], layer_idx: int, target_idx: int +) -> DecisionTreeClassifier: + """Fetch the per-output estimator for a given layer and column.""" + lm = next(m for m in models if m.layer_index == layer_idx) + if isinstance(lm.model, MultiOutputClassifier): + return lm.model.estimators_[target_idx] # type: ignore + return lm.model # type: ignore diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py new file mode 100644 index 000000000..9685a3695 
--- /dev/null +++ b/spd/clustering/ci_dt/plot.py @@ -0,0 +1,102 @@ +"""Plotting functions for causal importance decision trees.""" + +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +from sklearn.tree import plot_tree + +from spd.clustering.ci_dt.core import LayerModel, get_estimator_for + + +def plot_activations(layers_true: list[np.ndarray], layers_pred: list[np.ndarray]) -> None: + """Show true and predicted activations as heatmaps.""" + A_true: np.ndarray = np.concatenate(layers_true, axis=1) + A_pred: np.ndarray = np.concatenate([layers_pred[0]] + layers_pred[1:], axis=1) + fig1 = plt.figure(figsize=(10, 6)) + ax1 = fig1.add_subplot(2, 1, 1) + ax1.set_title("Activations (True)") + ax1.imshow(A_true, aspect="auto", interpolation="nearest") + ax1.set_xlabel("components (all layers concatenated)") + ax1.set_ylabel("samples") + ax2 = fig1.add_subplot(2, 1, 2) + ax2.set_title("Activations (Predicted)") + ax2.imshow(A_pred, aspect="auto", interpolation="nearest") + ax2.set_xlabel("components (all layers concatenated)") + ax2.set_ylabel("samples") + fig1.tight_layout() + + +def plot_covariance(layers_true: list[np.ndarray]) -> None: + """Plot covariance between all components across layers.""" + A: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) + C: np.ndarray = np.cov(A, rowvar=False) + fig2 = plt.figure(figsize=(6, 6)) + ax = fig2.add_subplot(1, 1, 1) + ax.set_title("Covariance of components (all layers)") + ax.imshow(C, aspect="auto", interpolation="nearest") + ax.set_xlabel("component index") + ax.set_ylabel("component index") + fig2.tight_layout() + + +def plot_layer_metrics(per_layer_stats: list[dict[str, Any]]) -> None: + """Plot summary metrics per layer and per-target AP vs prevalence.""" + L: int = len(per_layer_stats) + mean_ap: np.ndarray = np.array([d["mean_ap"] for d in per_layer_stats]) + mean_acc: np.ndarray = np.array([d["mean_acc"] for d in per_layer_stats]) + mean_bacc: np.ndarray = 
np.array([d["mean_bacc"] for d in per_layer_stats]) + + # bar: mean AP, ACC, BACC per layer (three separate figures to respect one-plot rule) + fig3 = plt.figure(figsize=(8, 3)) + ax3 = fig3.add_subplot(1, 1, 1) + ax3.set_title("Mean Average Precision per layer") + ax3.bar(np.arange(1, L + 1), mean_ap) + ax3.set_xlabel("layer index (target)") + ax3.set_ylabel("mean AP") + fig3.tight_layout() + + fig4 = plt.figure(figsize=(8, 3)) + ax4 = fig4.add_subplot(1, 1, 1) + ax4.set_title("Mean Accuracy per layer") + ax4.bar(np.arange(1, L + 1), mean_acc) + ax4.set_xlabel("layer index (target)") + ax4.set_ylabel("mean accuracy") + fig4.tight_layout() + + fig5 = plt.figure(figsize=(8, 3)) + ax5 = fig5.add_subplot(1, 1, 1) + ax5.set_title("Mean Balanced Accuracy per layer") + ax5.bar(np.arange(1, L + 1), mean_bacc) + ax5.set_xlabel("layer index (target)") + ax5.set_ylabel("mean balanced accuracy") + fig5.tight_layout() + + # scatter: prevalence vs AP for all targets across layers + fig6 = plt.figure(figsize=(6, 5)) + ax6 = fig6.add_subplot(1, 1, 1) + ax6.set_title("Per-target AP vs prevalence") + x_list: list[float] = [] + y_list: list[float] = [] + for d in per_layer_stats: + x_list.extend(list(d["prev"])) + y_list.extend(list(d["ap"])) + ax6.scatter(x_list, y_list, alpha=0.6) + ax6.set_xlabel("prevalence") + ax6.set_ylabel("average precision") + fig6.tight_layout() + + +def plot_selected_trees( + picks: list[tuple[int, int, float]], + title_prefix: str, + models: list[LayerModel], +) -> None: + """Plot a list of selected trees by (layer, target_idx, score).""" + for layer_idx, target_idx, score in picks: + est = get_estimator_for(models, layer_idx, target_idx) + fig = plt.figure(figsize=(10, 6)) + ax = fig.add_subplot(1, 1, 1) + ax.set_title(f"{title_prefix}: layer {layer_idx}, target {target_idx}, AP={score:.3f}") + plot_tree(est, ax=ax, filled=False) # default styling + fig.tight_layout() diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py new file mode 
100644 index 000000000..cec065266 --- /dev/null +++ b/spd/clustering/ci_dt/run.py @@ -0,0 +1,170 @@ +# %% +"""Main execution script for causal importance decision tree training.""" + +from typing import Any + +import numpy as np +import torch +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.ci_dt.config import CIDTConfig +from spd.clustering.ci_dt.core import ( + LayerModel, + build_xy, + layer_metrics, + proba_for_layer, + predict_all, + train_trees, +) +from spd.clustering.ci_dt.plot import ( + plot_activations, + plot_covariance, + plot_layer_metrics, + plot_selected_trees, +) +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, SPDRunInfo + +# ----------------------- configuration ----------------------- + +config = CIDTConfig( + n_samples=10, + activation_threshold=0.01, + filter_dead_threshold=0.001, + max_depth=8, + random_state=42, +) +device: str = "cuda" if torch.cuda.is_available() else "cpu" + +# ----------------------- load model ----------------------- + +wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" + +spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) +model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) +model.to(device) +cfg: Config = spd_run.config + +print(f"Loaded model from {wandb_run_path}") + +# ----------------------- load dataset ----------------------- + +# Create LM dataset and dataloader +assert isinstance(cfg.task_config, LMTaskConfig) +pretrained_model_name = cfg.pretrained_model_name +assert pretrained_model_name is not None + +dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + 
column_name=cfg.task_config.column_name, + is_tokenized=False, + streaming=False, + seed=0, +) +dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=config.n_samples, + buffer_size=cfg.task_config.buffer_size, + global_seed=cfg.seed, + ddp_rank=0, + ddp_world_size=1, +) +batch_data = next(iter(dataloader)) +batch: Tensor = batch_data["input_ids"] +print(f"Created LM dataset with {cfg.task_config.dataset_name}, batch shape: {batch.shape}") + +# ----------------------- get activations ----------------------- + +# Get component activations (on device) +print("Computing component activations...") +component_acts: dict[str, Tensor] = component_activations( + model=model, + device=device, + batch=batch, +) + +# Process activations (filter dead components, concatenate) +print("Processing activations...") +processed_acts: ProcessedActivations = process_activations( + component_acts, + filter_dead_threshold=config.filter_dead_threshold, + seq_mode="seq_mean", # LM task needs seq_mean +) + +print(f"Total components (before filtering): {processed_acts.n_components_original}") +print(f"Alive components: {processed_acts.n_components_alive}") +print(f"Dead components: {processed_acts.n_components_dead}") +print(f"Module keys: {processed_acts.module_keys}") + +# ----------------------- convert to layers ----------------------- + +# Move to CPU and convert to numpy for sklearn +# Group by module to create "layers" for decision trees +print("\nConverting to boolean layers...") +layers_true: list[np.ndarray] = [] +for module_key in processed_acts.module_keys: + # Get the activations for this module from activations_raw, move to CPU + module_acts_cpu = processed_acts.activations_raw[module_key].cpu().numpy() + module_acts_bool = (module_acts_cpu >= config.activation_threshold).astype(bool) + layers_true.append(module_acts_bool) + print(f"Layer {len(layers_true) - 1} ({module_key}): {module_acts_bool.shape[1]} components") + +print(f"\nCreated 
{len(layers_true)} layers for decision tree training") + +# ----------------------- fit and predict ----------------------- + +print("\nTraining decision trees...") +models: list[LayerModel] = train_trees( + layers_true, max_depth=config.max_depth, random_state=config.random_state +) +layers_pred: list[np.ndarray] = predict_all(models, [layers_true[0]]) + +# ----------------------- metrics ----------------------- + +XYs_demo = build_xy(layers_true) +per_layer_stats: list[dict[str, Any]] = [] +all_triplets: list[tuple[int, int, float]] = [] # (layer, target_idx, AP) + +for lm, (Xk, Yk) in zip(models, XYs_demo, strict=True): + Pk: np.ndarray = proba_for_layer(lm, Xk) + Yhat_k: np.ndarray = Pk >= 0.5 + ap, acc, bacc, prev = layer_metrics(Yk, Pk, Yhat_k) + per_layer_stats.append( + { + "ap": ap, + "acc": acc, + "bacc": bacc, + "prev": prev, + "mean_ap": float(np.nanmean(ap)), + "mean_acc": float(np.nanmean(acc)), + "mean_bacc": float(np.nanmean(bacc)), + } + ) + for j, apj in enumerate(ap): + all_triplets.append((lm.layer_index, j, float(apj))) + +# identify best and worst trees across all outputs by AP +sorted_triplets = sorted(all_triplets, key=lambda t: (np.isnan(t[2]), t[2])) +worst_list = [t for t in sorted_triplets if not np.isnan(t[2])][:2] +best_list = [t for t in sorted_triplets if not np.isnan(t[2])][-2:] + +# ----------------------- plotting ----------------------- + +# Run the plots +plot_activations(layers_true, layers_pred) +plot_covariance(layers_true) +plot_layer_metrics(per_layer_stats) +plot_selected_trees(worst_list, "Worst", models) +plot_selected_trees(best_list, "Best", models) + +print("Plots generated.") From e35a7b18c2e7d02e1557980b035e4daf4e5988b8 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 13:11:55 +0100 Subject: [PATCH 17/77] wip --- spd/clustering/ci_dt/core.py | 15 +++++++++++---- spd/clustering/ci_dt/run.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git 
a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index bfdaa27fd..03d2cd511 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -1,5 +1,6 @@ """Core library functions for causal importance decision trees.""" +from collections import Counter from collections.abc import Sequence from dataclasses import dataclass from typing import Literal @@ -14,6 +15,7 @@ ) from sklearn.multioutput import MultiOutputClassifier from sklearn.tree import DecisionTreeClassifier +from muutils.dbg import dbg_auto, dbg, dbg_tensor @dataclass @@ -84,12 +86,17 @@ def predict_k( """Predict layer k activations from layers[:k].""" lm: LayerModel = next(m for m in models if m.layer_index == k) X: np.ndarray = concat_cols(prefix_layers) + dbg_auto(X) proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - if isinstance(proba, list): - P: np.ndarray = np.stack([p[:, 1] for p in proba], axis=1) - else: - P = proba[..., 1] # type: ignore + dbg_auto(proba) + dbg_auto(proba[0]) + + dbg(Counter(tuple(p.shape) for p in proba)) + + P: np.ndarray = np.stack(proba, axis=1) + dbg_auto(P) Y_hat: np.ndarray = (float(threshold) <= P).astype(bool) + dbg_auto(Y_hat) return Y_hat diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index cec065266..770ad7f3b 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -32,6 +32,12 @@ from spd.experiments.lm.configs import LMTaskConfig from spd.models.component_model import ComponentModel, SPDRunInfo + +# magic autoreload +%load_ext autoreload +%autoreload 2 + +# %% # ----------------------- configuration ----------------------- config = CIDTConfig( @@ -43,6 +49,7 @@ ) device: str = "cuda" if torch.cuda.is_available() else "cpu" +# %% # ----------------------- load model ----------------------- wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" @@ -54,6 +61,7 @@ print(f"Loaded model from {wandb_run_path}") +# %% # ----------------------- load dataset 
----------------------- # Create LM dataset and dataloader @@ -83,6 +91,7 @@ batch: Tensor = batch_data["input_ids"] print(f"Created LM dataset with {cfg.task_config.dataset_name}, batch shape: {batch.shape}") +# %% # ----------------------- get activations ----------------------- # Get component activations (on device) @@ -106,6 +115,7 @@ print(f"Dead components: {processed_acts.n_components_dead}") print(f"Module keys: {processed_acts.module_keys}") +# %% # ----------------------- convert to layers ----------------------- # Move to CPU and convert to numpy for sklearn @@ -121,6 +131,7 @@ print(f"\nCreated {len(layers_true)} layers for decision tree training") +# %% # ----------------------- fit and predict ----------------------- print("\nTraining decision trees...") @@ -129,6 +140,7 @@ ) layers_pred: list[np.ndarray] = predict_all(models, [layers_true[0]]) +# %% # ----------------------- metrics ----------------------- XYs_demo = build_xy(layers_true) @@ -158,6 +170,7 @@ worst_list = [t for t in sorted_triplets if not np.isnan(t[2])][:2] best_list = [t for t in sorted_triplets if not np.isnan(t[2])][-2:] +# %% # ----------------------- plotting ----------------------- # Run the plots From c6cadd46e3d92f0227054a685c03db1a5a51edf0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 14:00:06 +0100 Subject: [PATCH 18/77] wip --- spd/clustering/ci_dt/__init__.py | 2 + spd/clustering/ci_dt/core.py | 67 ++++++++++++++++++++------------ 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/spd/clustering/ci_dt/__init__.py b/spd/clustering/ci_dt/__init__.py index 1451a7f59..7dde8ab07 100644 --- a/spd/clustering/ci_dt/__init__.py +++ b/spd/clustering/ci_dt/__init__.py @@ -5,6 +5,7 @@ LayerModel, build_xy, concat_cols, + extract_prob_class_1, get_estimator_for, layer_metrics, predict_all, @@ -27,6 +28,7 @@ "concat_cols", "build_xy", "train_trees", + "extract_prob_class_1", "predict_k", "predict_all", "layer_metrics", diff --git 
a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 03d2cd511..8e6a46342 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -3,11 +3,10 @@ from collections import Counter from collections.abc import Sequence from dataclasses import dataclass -from typing import Literal import numpy as np from jaxtyping import Bool, Float -from sklearn.base import ClassifierMixin +from muutils.dbg import dbg, dbg_auto, dbg_tensor from sklearn.metrics import ( accuracy_score, average_precision_score, @@ -15,7 +14,6 @@ ) from sklearn.multioutput import MultiOutputClassifier from sklearn.tree import DecisionTreeClassifier -from muutils.dbg import dbg_auto, dbg, dbg_tensor @dataclass @@ -23,7 +21,7 @@ class LayerModel: """Holds a trained per-layer model.""" layer_index: int - model: ClassifierMixin + model: MultiOutputClassifier feature_dim: int target_dim: int @@ -56,12 +54,11 @@ def build_xy( def train_trees( layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], *, - strategy: Literal["one_vs_all", "single_tree"] = "one_vs_all", max_depth: int | None = None, min_samples_leaf: int = 1, random_state: int | None = 0, ) -> list[LayerModel]: - """Train one model per target layer using previous layers as features.""" + """Train one decision tree per component per target layer using previous layers as features.""" XYs = build_xy(layers) models: list[LayerModel] = [] for k, (X_k, Y_k) in enumerate(XYs, start=1): @@ -70,12 +67,40 @@ def train_trees( min_samples_leaf=min_samples_leaf, random_state=random_state, ) - model: ClassifierMixin = MultiOutputClassifier(base) if strategy == "one_vs_all" else base - _ = model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) + model = MultiOutputClassifier(base) + model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) models.append(LayerModel(k, model, int(X_k.shape[1]), int(Y_k.shape[1]))) return models +def extract_prob_class_1( + proba_list: list[np.ndarray], + model: MultiOutputClassifier, +) -> 
np.ndarray: + """Extract P(y=1) for each output, handling constant components. + + When a component is always 0 or always 1 in training data, + sklearn only returns probabilities for the observed class. + This function handles all cases correctly. + """ + result: list[np.ndarray] = [] + for i, p in enumerate(proba_list): + estimator = model.estimators_[i] + classes = estimator.classes_ + if len(classes) == 1: + # Only one class observed during training + if classes[0] == 0: + # Only saw class 0, so P(y=1) = 0 + result.append(np.zeros(p.shape[0])) + else: # classes[0] == 1 + # Only saw class 1, so P(y=1) = 1 + result.append(np.ones(p.shape[0])) + else: + # Saw both classes, extract P(y=1) from second column + result.append(p[:, 1]) + return np.stack(result, axis=1) + + def predict_k( models: Sequence[LayerModel], prefix_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], @@ -86,17 +111,15 @@ def predict_k( """Predict layer k activations from layers[:k].""" lm: LayerModel = next(m for m in models if m.layer_index == k) X: np.ndarray = concat_cols(prefix_layers) - dbg_auto(X) + # dbg_auto(X) proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - dbg_auto(proba) - dbg_auto(proba[0]) - + # dbg_auto(proba) + # dbg_auto(proba[0]) dbg(Counter(tuple(p.shape) for p in proba)) - - P: np.ndarray = np.stack(proba, axis=1) - dbg_auto(P) - Y_hat: np.ndarray = (float(threshold) <= P).astype(bool) - dbg_auto(Y_hat) + P: np.ndarray = extract_prob_class_1(proba, lm.model) + # dbg_auto(P) + Y_hat: np.ndarray = (P >= threshold).astype(bool) + # dbg_auto(Y_hat) return Y_hat @@ -148,10 +171,8 @@ def layer_metrics( def proba_for_layer(lm: LayerModel, X: np.ndarray) -> np.ndarray: """Return P(y=1) per target column.""" - pr = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - if isinstance(pr, list): - return np.stack([p[:, 1] for p in pr], axis=1) - return pr[..., 1] # type: ignore + proba_list = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + 
return extract_prob_class_1(proba_list, lm.model) def get_estimator_for( @@ -159,6 +180,4 @@ def get_estimator_for( ) -> DecisionTreeClassifier: """Fetch the per-output estimator for a given layer and column.""" lm = next(m for m in models if m.layer_index == layer_idx) - if isinstance(lm.model, MultiOutputClassifier): - return lm.model.estimators_[target_idx] # type: ignore - return lm.model # type: ignore + return lm.model.estimators_[target_idx] # type: ignore From 2b74497965b4c1f846e30b948d13d7d89b3d2745 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 14:12:46 +0100 Subject: [PATCH 19/77] wip --- spd/clustering/ci_dt/core.py | 4 +-- spd/clustering/ci_dt/run.py | 67 ++++++++++++++++++++++++++++++------ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 8e6a46342..29988ca2c 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -6,7 +6,7 @@ import numpy as np from jaxtyping import Bool, Float -from muutils.dbg import dbg, dbg_auto, dbg_tensor +from muutils.dbg import dbg from sklearn.metrics import ( accuracy_score, average_precision_score, @@ -118,7 +118,7 @@ def predict_k( dbg(Counter(tuple(p.shape) for p in proba)) P: np.ndarray = extract_prob_class_1(proba, lm.model) # dbg_auto(P) - Y_hat: np.ndarray = (P >= threshold).astype(bool) + Y_hat: np.ndarray = (threshold <= P).astype(bool) # dbg_auto(Y_hat) return Y_hat diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 770ad7f3b..eb1ebc356 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -5,6 +5,7 @@ import numpy as np import torch +from jaxtyping import Bool, Float from torch import Tensor from spd.clustering.activations import ( @@ -121,13 +122,36 @@ # Move to CPU and convert to numpy for sklearn # Group by module to create "layers" for decision trees print("\nConverting to boolean layers...") -layers_true: list[np.ndarray] = [] 
+layers_true: list[Bool[np.ndarray, "n_samples n_components"]] = [] for module_key in processed_acts.module_keys: # Get the activations for this module from activations_raw, move to CPU - module_acts_cpu = processed_acts.activations_raw[module_key].cpu().numpy() - module_acts_bool = (module_acts_cpu >= config.activation_threshold).astype(bool) - layers_true.append(module_acts_bool) - print(f"Layer {len(layers_true) - 1} ({module_key}): {module_acts_bool.shape[1]} components") + module_acts_cpu: Float[np.ndarray, "n_samples n_components"] = ( + processed_acts.activations_raw[module_key].cpu().numpy() + ) + module_acts_bool: Bool[np.ndarray, "n_samples n_components"] = ( + module_acts_cpu >= config.activation_threshold + ).astype(bool) + + # Filter out components that are always dead or always alive + # (they provide no information for decision trees) + n_before: int = module_acts_bool.shape[1] + component_variance: Float[np.ndarray, "n_components"] = module_acts_bool.var(axis=0) + varying_mask: Bool[np.ndarray, "n_components"] = component_variance > 0 + + # Count always-dead and always-alive components for diagnostics + always_dead_mask: Bool[np.ndarray, "n_components"] = ~module_acts_bool.any(axis=0) + always_alive_mask: Bool[np.ndarray, "n_components"] = module_acts_bool.all(axis=0) + n_always_dead: int = always_dead_mask.sum() + n_always_alive: int = always_alive_mask.sum() + + module_acts_filtered: Bool[np.ndarray, "n_samples n_varying"] = module_acts_bool[:, varying_mask] + n_after: int = module_acts_filtered.shape[1] + + layers_true.append(module_acts_filtered) + print( + f"Layer {len(layers_true) - 1} ({module_key}): {n_after} varying components " + f"({n_always_dead} always dead, {n_always_alive} always alive removed)" + ) print(f"\nCreated {len(layers_true)} layers for decision tree training") @@ -171,13 +195,36 @@ best_list = [t for t in sorted_triplets if not np.isnan(t[2])][-2:] # %% -# ----------------------- plotting ----------------------- +# 
----------------------- plot: layer metrics ----------------------- +# Simplest - just bar charts and scatter plot of summary statistics + +plot_layer_metrics(per_layer_stats) +print("Layer metrics plots generated.") + +# %% +# ----------------------- plot: activations ----------------------- +# Simple heatmaps of true vs predicted activations -# Run the plots plot_activations(layers_true, layers_pred) +print("Activation plots generated.") + +# %% +# ----------------------- plot: covariance ----------------------- +# Covariance matrix - can be slow with many components + plot_covariance(layers_true) -plot_layer_metrics(per_layer_stats) +print("Covariance plot generated.") + +# %% +# ----------------------- plot: worst trees ----------------------- +# Decision tree visualization for worst performing trees + plot_selected_trees(worst_list, "Worst", models) -plot_selected_trees(best_list, "Best", models) +print("Worst trees plots generated.") + +# %% +# ----------------------- plot: best trees ----------------------- +# Decision tree visualization for best performing trees -print("Plots generated.") +plot_selected_trees(best_list, "Best", models) +print("Best trees plots generated.") From 73a483543e3b1339d25d6ba40331c24332bba0ac Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 20 Oct 2025 15:35:07 +0100 Subject: [PATCH 20/77] wip --- pyproject.toml | 1 + spd/clustering/ci_dt/core.py | 1 - spd/clustering/ci_dt/plot.py | 133 ++++++++++++++++++++++++++++ spd/clustering/ci_dt/run.py | 10 ++- uv.lock | 162 +++++++++++++++++++++++++++++++++++ 5 files changed, 305 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 86348e2ba..cb18a3b93 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dev = [ "ruff", "basedpyright", "pre-commit", + "nbconvert", ] [project.scripts] diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 29988ca2c..a5ceac916 100644 --- a/spd/clustering/ci_dt/core.py +++ 
b/spd/clustering/ci_dt/core.py @@ -115,7 +115,6 @@ def predict_k( proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore # dbg_auto(proba) # dbg_auto(proba[0]) - dbg(Counter(tuple(p.shape) for p in proba)) P: np.ndarray = extract_prob_class_1(proba, lm.model) # dbg_auto(P) Y_hat: np.ndarray = (threshold <= P).astype(bool) diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 9685a3695..92cedb84b 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import numpy as np +from jaxtyping import Float, Int from sklearn.tree import plot_tree from spd.clustering.ci_dt.core import LayerModel, get_estimator_for @@ -100,3 +101,135 @@ def plot_selected_trees( ax.set_title(f"{title_prefix}: layer {layer_idx}, target {target_idx}, AP={score:.3f}") plot_tree(est, ax=ax, filled=False) # default styling fig.tight_layout() + + +def extract_tree_stats( + models: list[LayerModel], + per_layer_stats: list[dict[str, Any]], +) -> dict[str, Float[np.ndarray, "n_trees"]]: + """Extract depth, leaf count, and accuracy for all trees across all layers.""" + depths: list[int] = [] + leaf_counts: list[int] = [] + accuracies: list[float] = [] + balanced_accuracies: list[float] = [] + aps: list[float] = [] + + for lm, stats in zip(models, per_layer_stats, strict=True): + for i, estimator in enumerate(lm.model.estimators_): + depths.append(int(estimator.tree_.max_depth)) + leaf_counts.append(int(estimator.tree_.n_leaves)) + accuracies.append(float(stats["acc"][i])) + balanced_accuracies.append(float(stats["bacc"][i])) + aps.append(float(stats["ap"][i])) + + return { + "depth": np.array(depths), + "n_leaves": np.array(leaf_counts), + "accuracy": np.array(accuracies), + "balanced_accuracy": np.array(balanced_accuracies), + "ap": np.array(aps), + } + + +def plot_tree_statistics( + models: list[LayerModel], per_layer_stats: list[dict[str, Any]] +) -> None: + """Plot distributions of tree depth, 
leaf count, and their correlations with accuracy.""" + stats = extract_tree_stats(models, per_layer_stats) + + # Distribution of tree depths + fig1, ax1 = plt.subplots() + ax1.hist(stats["depth"], bins=range(int(stats["depth"].max()) + 2)) + ax1.set_yscale("log") + ax1.set_xlabel("Tree depth") + ax1.set_ylabel("Count (log scale)") + + # Distribution of leaf counts + fig2, ax2 = plt.subplots() + ax2.hist(stats["n_leaves"], bins=50) + ax2.set_yscale("log") + ax2.set_xlabel("Number of leaves") + ax2.set_ylabel("Count (log scale)") + + # Distribution of accuracies + fig3, ax3 = plt.subplots() + ax3.hist(stats["accuracy"][~np.isnan(stats["accuracy"])], bins=30) + ax3.set_yscale("log") + ax3.set_xlabel("Accuracy") + ax3.set_ylabel("Count (log scale)") + + # Heatmap: depth vs accuracy + valid_mask: np.ndarray = ~np.isnan(stats["accuracy"]) + depth_bins: Int[np.ndarray, "n_bins"] = np.arange( + int(stats["depth"].min()), int(stats["depth"].max()) + 2 + ) + acc_bins: Float[np.ndarray, "n_bins"] = np.linspace(0, 1, 11) + heatmap_depth_acc: Float[np.ndarray, "depth_bins acc_bins"] + heatmap_depth_acc, _, _ = np.histogram2d( + stats["depth"][valid_mask], stats["accuracy"][valid_mask], bins=[depth_bins, acc_bins] + ) + + fig4, ax4 = plt.subplots() + heatmap_log: Float[np.ndarray, "depth_bins acc_bins"] = np.log10( + heatmap_depth_acc.T + 1 + ) # +1 to avoid log(0) + im = ax4.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") + ax4.set_xticks(range(len(depth_bins) - 1)) + ax4.set_xticklabels(depth_bins[:-1]) + ax4.set_yticks(range(len(acc_bins) - 1)) + ax4.set_yticklabels([f"{x:.1f}" for x in acc_bins[:-1]]) + ax4.set_xlabel("Tree depth") + ax4.set_ylabel("Accuracy") + for i in range(len(depth_bins) - 1): + for j in range(len(acc_bins) - 1): + count: int = int(heatmap_depth_acc[i, j]) + if count > 0: + ax4.text(i, j, str(count), ha="center", va="center") + plt.colorbar(im, ax=ax4, label="log10(count+1)") + + # Heatmap: leaf count vs accuracy + leaf_bins: 
Int[np.ndarray, "n_bins"] = np.linspace( + int(stats["n_leaves"].min()), int(stats["n_leaves"].max()) + 1, 11, dtype=int + ) + heatmap_leaf_acc: Float[np.ndarray, "leaf_bins acc_bins"] + heatmap_leaf_acc, _, _ = np.histogram2d( + stats["n_leaves"][valid_mask], stats["accuracy"][valid_mask], bins=[leaf_bins, acc_bins] + ) + + fig5, ax5 = plt.subplots() + heatmap_log = np.log10(heatmap_leaf_acc.T + 1) + im = ax5.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") + ax5.set_xticks(range(len(leaf_bins) - 1)) + ax5.set_xticklabels(leaf_bins[:-1]) + ax5.set_yticks(range(len(acc_bins) - 1)) + ax5.set_yticklabels([f"{x:.1f}" for x in acc_bins[:-1]]) + ax5.set_xlabel("Number of leaves") + ax5.set_ylabel("Accuracy") + for i in range(len(leaf_bins) - 1): + for j in range(len(acc_bins) - 1): + count: int = int(heatmap_leaf_acc[i, j]) + if count > 0: + ax5.text(i, j, str(count), ha="center", va="center") + plt.colorbar(im, ax=ax5, label="log10(count+1)") + + # Heatmap: depth vs leaf count + heatmap_depth_leaf: Float[np.ndarray, "depth_bins leaf_bins"] + heatmap_depth_leaf, _, _ = np.histogram2d( + stats["depth"][valid_mask], stats["n_leaves"][valid_mask], bins=[depth_bins, leaf_bins] + ) + + fig6, ax6 = plt.subplots() + heatmap_log = np.log10(heatmap_depth_leaf.T + 1) + im = ax6.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") + ax6.set_xticks(range(len(depth_bins) - 1)) + ax6.set_xticklabels(depth_bins[:-1]) + ax6.set_yticks(range(len(leaf_bins) - 1)) + ax6.set_yticklabels(leaf_bins[:-1]) + ax6.set_xlabel("Tree depth") + ax6.set_ylabel("Number of leaves") + for i in range(len(depth_bins) - 1): + for j in range(len(leaf_bins) - 1): + count: int = int(heatmap_depth_leaf[i, j]) + if count > 0: + ax6.text(i, j, str(count), ha="center", va="center") + plt.colorbar(im, ax=ax6, label="log10(count+1)") diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index eb1ebc356..cc4333d2f 100644 --- a/spd/clustering/ci_dt/run.py +++ 
b/spd/clustering/ci_dt/run.py @@ -27,6 +27,7 @@ plot_covariance, plot_layer_metrics, plot_selected_trees, + plot_tree_statistics, ) from spd.configs import Config from spd.data import DatasetConfig, create_data_loader @@ -42,7 +43,7 @@ # ----------------------- configuration ----------------------- config = CIDTConfig( - n_samples=10, + n_samples=64, # batch size 64 -> 16GB vram activation_threshold=0.01, filter_dead_threshold=0.001, max_depth=8, @@ -201,6 +202,13 @@ plot_layer_metrics(per_layer_stats) print("Layer metrics plots generated.") +# %% +# ----------------------- plot: tree statistics ----------------------- +# Distributions of tree depth, leaf counts, and correlations with accuracy + +plot_tree_statistics(models, per_layer_stats) +print("Tree statistics plots generated.") + # %% # ----------------------- plot: activations ----------------------- # Simple heatmaps of true vs predicted activations diff --git a/uv.lock b/uv.lock index 91098f2ea..0fde83f87 100644 --- a/uv.lock +++ b/uv.lock @@ -138,6 +138,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, ] +[[package]] +name = "beautifulsoup4" +version = "4.14.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/e9/df2358efd7659577435e2177bfa69cba6c33216681af51a707193dec162a/beautifulsoup4-4.14.2.tar.gz", hash = "sha256:2a98ab9f944a11acee9cc848508ec28d9228abfd522ef0fad6a02a72e0ded69e", size = 625822, upload-time = "2025-09-29T10:05:42.613Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/94/fe/3aed5d0be4d404d12d36ab97e2f1791424d9ca39c2f754a6285d59a3b01d/beautifulsoup4-4.14.2-py3-none-any.whl", hash = "sha256:5ef6fa3a8cbece8488d66985560f97ed091e22bbc4e9c2338508a9d5de6d4515", size = 106392, upload-time = "2025-09-29T10:05:43.771Z" }, +] + +[[package]] +name = "bleach" +version = "6.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/9a/0e33f5054c54d349ea62c277191c020c2d6ef1d65ab2cb1993f91ec846d1/bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f", size = 203083, upload-time = "2024-10-29T18:30:40.477Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e", size = 163406, upload-time = "2024-10-29T18:30:38.186Z" }, +] + +[package.optional-dependencies] +css = [ + { name = "tinycss2" }, +] + [[package]] name = "blinker" version = "1.9.0" @@ -376,6 +406,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "dill" version = "0.4.0" @@ -421,6 +460,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, ] +[[package]] +name = "fastjsonschema" +version = "2.21.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/b5/23b216d9d985a956623b6bd12d4086b60f0059b27799f23016af04a74ea1/fastjsonschema-2.21.2.tar.gz", hash = "sha256:b1eb43748041c880796cd077f1a07c3d94e93ae84bba5ed36800a33554ae05de", size = 374130, upload-time = "2025-08-14T18:49:36.666Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" }, +] + [[package]] name = "filelock" version = "3.20.0" @@ -794,6 +842,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] +[[package]] +name = "jupyterlab-pygments" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/90/51/9187be60d989df97f5f0aba133fa54e7300f17616e065d1ada7d7646b6d6/jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d", size = 512900, upload-time = "2023-11-23T09:26:37.44Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884, upload-time = "2023-11-23T09:26:34.325Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -902,6 +959,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "mistune" +version = "3.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/02/a7fb8b21d4d55ac93cdcde9d3638da5dd0ebdd3a4fed76c7725e10b81cbe/mistune-3.1.4.tar.gz", hash = "sha256:b5a7f801d389f724ec702840c11d8fc48f2b33519102fc7ee739e8177b672164", size = 94588, upload-time = "2025-08-29T07:20:43.594Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/f0/8282d9641415e9e33df173516226b404d367a0fc55e1a60424a152913abc/mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d", size = 53481, upload-time = "2025-08-29T07:20:42.218Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -990,6 +1056,61 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1d/86/ac808ecb94322a3f1ea31627d13ab3e50dd4333564d711e0e481ad0f4586/narwhals-2.8.0-py3-none-any.whl", hash = "sha256:6304856676ba4a79fd34148bda63aed8060dd6edb1227edf3659ce5e091de73c", size = 
415852, upload-time = "2025-10-13T08:44:25.421Z" }, ] +[[package]] +name = "nbclient" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "nbformat" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/66/7ffd18d58eae90d5721f9f39212327695b749e23ad44b3881744eaf4d9e8/nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193", size = 62424, upload-time = "2024-12-19T10:32:27.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/6d/e7fa07f03a4a7b221d94b4d586edb754a9b0dc3c9e2c93353e9fa4e0d117/nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d", size = 25434, upload-time = "2024-12-19T10:32:24.139Z" }, +] + +[[package]] +name = "nbconvert" +version = "7.16.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4" }, + { name = "bleach", extra = ["css"] }, + { name = "defusedxml" }, + { name = "jinja2" }, + { name = "jupyter-core" }, + { name = "jupyterlab-pygments" }, + { name = "markupsafe" }, + { name = "mistune" }, + { name = "nbclient" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pandocfilters" }, + { name = "pygments" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/59/f28e15fc47ffb73af68a8d9b47367a8630d76e97ae85ad18271b9db96fdf/nbconvert-7.16.6.tar.gz", hash = "sha256:576a7e37c6480da7b8465eefa66c17844243816ce1ccc372633c6b71c3c0f582", size = 857715, upload-time = "2025-01-28T09:29:14.724Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/9a/cd673b2f773a12c992f41309ef81b99da1690426bd2f96957a7ade0d3ed7/nbconvert-7.16.6-py3-none-any.whl", hash = "sha256:1375a7b67e0c2883678c48e506dc320febb57685e5ee67faa51b18a90f3a712b", size = 258525, upload-time = 
"2025-01-28T09:29:12.551Z" }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastjsonschema" }, + { name = "jsonschema" }, + { name = "jupyter-core" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749, upload-time = "2024-04-04T11:20:37.371Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454, upload-time = "2024-04-04T11:20:34.895Z" }, +] + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -1225,6 +1346,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233, upload-time = "2025-09-29T23:24:24.876Z" }, ] +[[package]] +name = "pandocfilters" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/6f/3dd4940bbe001c06a65f88e36bad298bc7a0de5036115639926b0c5c0458/pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e", size = 8454, upload-time = "2024-01-18T20:08:13.726Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc", size = 8663, upload-time = "2024-01-18T20:08:11.28Z" }, +] + [[package]] 
name = "parso" version = "0.8.5" @@ -1917,6 +2047,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "soupsieve" +version = "2.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/e6/21ccce3262dd4889aa3332e5a119a3491a95e8f60939870a3a035aabac0d/soupsieve-2.8.tar.gz", hash = "sha256:e2dd4a40a628cb5f28f6d4b0db8800b8f581b65bb380b97de22ba5ca8d72572f", size = 103472, upload-time = "2025-08-27T15:39:51.78Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl", hash = "sha256:0cc76456a30e20f5d7f2e14a98a4ae2ee4e5abdc7c5ea0aafe795f344bc7984c", size = 36679, upload-time = "2025-08-27T15:39:50.179Z" }, +] + [[package]] name = "spd" version = "0.0.1" @@ -1949,6 +2088,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "basedpyright" }, + { name = "nbconvert" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -1985,6 +2125,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "basedpyright" }, + { name = "nbconvert" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -2111,6 +2252,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, ] +[[package]] +name = "tinycss2" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { 
url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085, upload-time = "2024-10-24T14:58:29.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, +] + [[package]] name = "tokenizers" version = "0.22.1" @@ -2408,6 +2561,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, ] +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + [[package]] name = "xxhash" version = "3.6.0" From d80ba3f0480c9dc8c80f0c0e2048eaf261483f93 Mon Sep 17 00:00:00 2001 From: mivanit Date: Tue, 21 Oct 2025 03:03:24 -0700 Subject: [PATCH 21/77] [clustering] distance computation (#213) add `matching_dist` and `matching_dist_vec` distance computations * wip * better 
debugging plots * wip * wip * wip * format * pyright fixes * wip * wip * [!!!] remove old testing code * integrate distances with the rest of the pipeline * wip * wip * wip * wip * dev recipe * fix pyright warning * wip * wip * wip * wip * make format * cli control for calc_distances `time` command doesnt work in CI or on andromeda * rework configs, make tests simpler and faster * fix config typo distances_method -> distances_methods --- Makefile | 8 +- .../configs/pipeline-dev-simplestories.yaml | 9 + .../configs/pipeline-test-resid_mlp1.yaml | 4 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/configs/resid_mlp1.json | 6 +- spd/clustering/configs/simplestories_dev.json | 14 +- spd/clustering/configs/test-resid_mlp1.json | 6 +- .../configs/test-simplestories.json | 2 +- spd/clustering/consts.py | 2 +- spd/clustering/math/jaccard.py | 71 ------- spd/clustering/math/jaccard_test.py | 194 ------------------ spd/clustering/math/matching_dist.py | 47 +++++ spd/clustering/math/merge_distances.py | 30 +-- spd/clustering/merge_history.py | 2 + spd/clustering/plotting/merge.py | 30 +++ spd/clustering/scripts/calc_distances.py | 16 +- spd/clustering/scripts/run_pipeline.py | 173 ++++++++++++---- spd/utils/command_utils.py | 118 +++++++++-- uv.lock | 82 ++++---- 20 files changed, 411 insertions(+), 407 deletions(-) create mode 100644 spd/clustering/configs/pipeline-dev-simplestories.yaml delete mode 100644 spd/clustering/math/jaccard.py delete mode 100644 spd/clustering/math/jaccard_test.py create mode 100644 spd/clustering/math/matching_dist.py diff --git a/Makefile b/Makefile index 85a0a4d8a..2bbfee873 100644 --- a/Makefile +++ b/Makefile @@ -82,4 +82,10 @@ clean: @echo "Cleaning Python cache and build artifacts..." find . -type d -name "__pycache__" -exec rm -rf {} + find . 
-type d -name "*.egg-info" -exec rm -rf {} + - rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage \ No newline at end of file + rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage + + +.PHONY: clustering-dev +clustering-dev: + uv run spd-cluster --local --config spd/clustering/configs/pipeline-dev-simplestories.yaml + diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml new file mode 100644 index 000000000..6909c5841 --- /dev/null +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -0,0 +1,9 @@ +run_clustering_config_path: "spd/clustering/configs/simplestories_dev.json" +n_runs: 4 +distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] +base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: null # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index e6680b8d0..a413a5438 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,6 +1,6 @@ run_clustering_config_path: "spd/clustering/configs/test-resid_mlp1.json" -n_runs: 2 -distances_method: "perm_invariant_hamming" +n_runs: 3 +distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" slurm_job_name_prefix: null slurm_partition: null diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index a2fc9ec9c..e406628c4 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,6 +1,6 @@ run_clustering_config_path: "spd/clustering/configs/test-simplestories.json" n_runs: 2 -distances_method: "perm_invariant_hamming" +distances_methods: 
["matching_dist"] base_output_dir: "tests/.temp/clustering" slurm_job_name_prefix: null slurm_partition: null diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 3d2085c6b..6a40c9b29 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,6 +1,6 @@ run_clustering_config_path: "spd/clustering/configs/example.yaml" n_runs: 2 -distances_method: "perm_invariant_hamming" +distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" slurm_job_name_prefix: "spd" slurm_partition: "h100-reserved" diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/resid_mlp1.json index e825215ee..a7d118ac7 100644 --- a/spd/clustering/configs/resid_mlp1.json +++ b/spd/clustering/configs/resid_mlp1.json @@ -2,7 +2,7 @@ "merge_config": { "activation_threshold": 0.01, "alpha": 1, - "iters": null, + "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, "pop_component_prob": 0, @@ -10,9 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp1", - "distances_method": "perm_invariant_hamming", + "distances_methods": ["perm_invariant_hamming"], "n_batches": 8, - "batch_size": 1024, + "batch_size": 128, "wandb_enabled": true, "wandb_project": "spd-cluster", "intervals": { diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/simplestories_dev.json index 89cbfde06..f585e848f 100644 --- a/spd/clustering/configs/simplestories_dev.json +++ b/spd/clustering/configs/simplestories_dev.json @@ -2,21 +2,17 @@ "merge_config": { "activation_threshold": 0.1, "alpha": 1.0, - "iters": null, + "iters": 100, "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "merge_pair_sampling_kwargs": {"threshold": 0.01}, "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null }, - "model_path": 
"wandb:goodfire/spd/runs/rn9klzfs", - "task_name": "lm", - "distances_method": "jaccard", - "n_batches": 1, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", "batch_size": 32, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { + "wandb_project": null, + "logging_intervals": { "stat": 1, "tensor": 200, "plot": 2000, diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/test-resid_mlp1.json index 6dd7fb12b..01b510200 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/test-resid_mlp1.json @@ -1,8 +1,8 @@ { "merge_config": { - "activation_threshold": 0.1, + "activation_threshold": 0.5, "alpha": 1, - "iters": 140, + "iters": 16, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, "pop_component_prob": 0, @@ -10,7 +10,7 @@ "module_name_filter": null }, "experiment_key": "resid_mlp1", - "batch_size": 100, + "batch_size": 128, "wandb_project": null, "logging_intervals": { "stat": 1, diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/test-simplestories.json index 891177ab1..147634edb 100644 --- a/spd/clustering/configs/test-simplestories.json +++ b/spd/clustering/configs/test-simplestories.json @@ -1,6 +1,6 @@ { "merge_config": { - "activation_threshold": 0.5, + "activation_threshold": 0.9, "alpha": 1.0, "iters": 5, "merge_pair_sampling_method": "range", diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py index ab824b8d7..8a9647dc8 100644 --- a/spd/clustering/consts.py +++ b/spd/clustering/consts.py @@ -11,7 +11,7 @@ # Merge arrays and distances (numpy-based for storage/analysis) MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] -DistancesMethod = Literal["perm_invariant_hamming", "jaccard"] +DistancesMethod = Literal["perm_invariant_hamming", "matching_dist", "matching_dist_vec"] DistancesArray = Float[np.ndarray, "n_iters n_ens 
n_ens"] # Component and label types (NewType for stronger type safety) diff --git a/spd/clustering/math/jaccard.py b/spd/clustering/math/jaccard.py deleted file mode 100644 index 0c0b1a284..000000000 --- a/spd/clustering/math/jaccard.py +++ /dev/null @@ -1,71 +0,0 @@ -"""jaccard index between clusterings - - -we start with a matrix X: Int[np.ndarray, "s n"] where each of the s rows is a label vector of length n -we want to compute a Float["s s"] matrix $J$ of pairwise jaccard indices between the rows of X - -jaccard index between two partitions A and B is defined as: -J(A, B) = M11 / (M11 + M10 + M01) - -where: -- M11 = number of pairs clustered together in both partitions -- M10 = number of pairs clustered together in A but not in B -- M01 = number of pairs clustered together in B but not in A - -""" - -# %% -import matplotlib.pyplot as plt -import torch -from jaxtyping import Bool, Float, Int -from muutils.dbg import dbg_auto -from torch import Tensor - - -def jaccard_index( - X: Int[Tensor, "s n"], -) -> Float[Tensor, "s s"]: - """Compute the pairwise jaccard index between rows of X""" - - s_ensemble, _n_components = X.shape - dbg_auto(X) - matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] - dbg_auto(matches) - - _jaccard: Float[Tensor, "s s"] = torch.full((s_ensemble, s_ensemble), torch.nan) - - for i in range(s_ensemble): - plt.matshow(matches[i].cpu().numpy()) - plt.title(f"matches for row {i}") - plt.show() - - # for i in range(s_ensemble): - # for j in range(i, s_ensemble): - # M11: int = int((matches[i] & matches[j]).sum() - n_components) // 2 - # M10: int = int((matches[i] & ~matches[j]).sum()) // 2 - # M01: int = int((~matches[i] & matches[j]).sum()) // 2 - # if M11 + M10 + M01 == 0: - # jaccard[i, j] = float("nan") - # else: - # jaccard[i, j] = M11 / (M11 + M10 + M01) - # jaccard[j, i] = jaccard[i, j] - # dbg_auto(i, j, M11, M10, M01, jaccard[i, j]) - - return _jaccard - - -jaccard_index( - torch.tensor( - [ - # [1, 2, 3, 3], - [0, 1, 1, 
2, 3, 3], - [3, 0, 0, 1, 2, 2], - [0, 3, 1, 1, 2, 2], - [0, 3, 0, 0, 1, 1], - [0, 0, 0, 0, 0, 0], - # [0, 1, 2, 3], - ] - ) -) - -# dbg(X - z[0]) diff --git a/spd/clustering/math/jaccard_test.py b/spd/clustering/math/jaccard_test.py deleted file mode 100644 index 9322abe3c..000000000 --- a/spd/clustering/math/jaccard_test.py +++ /dev/null @@ -1,194 +0,0 @@ -"""jaccard index between clusterings - - -we start with a matrix X: Int[np.ndarray, "k n"] where each of the k rows is a label vector of length n -we want to compute a Float["k k"] matrix $J$ of pairwise jaccard indices between the rows of X - -jaccard index between two partitions A and B is defined as: -J(A, B) = M11 / (M11 + M10 + M01) - -where: -- M11 = number of pairs clustered together in both partitions -- M10 = number of pairs clustered together in A but not in B -- M01 = number of pairs clustered together in B but not in A - -""" - -# %% -import matplotlib.pyplot as plt -import torch -from jaxtyping import Bool, Float, Int -from muutils.dbg import dbg -from torch import Tensor - -# def per_row_label_counts(X: Int[Tensor, "k n"]) -> list[Tensor]: -# """Return a list of 1D count arrays, one per row.""" -# return [ -# torch.bincount(x) -# for x in X -# ] - - -def process_singletons( - x: Int[Tensor, " n"], -) -> tuple[Int[Tensor, " n"], int]: - """relabel anything in a singleton cluster to -1, relabel other clusters to minimize labels""" - assert (x >= 0).all(), "input labels must be non-negative" - # figure out where the singletons are - counts: Int[Tensor, " k"] = torch.bincount(x) - singleton_mask: Bool[Tensor, " k"] = counts == 1 - - x_relabel: Int[Tensor, " n"] = x.clone() - dbg(x) - dbg(singleton_mask) - dbg(singleton_mask[x]) - dbg(x_relabel) - dbg(x_relabel[singleton_mask[x]]) - - # map singletons to -1 - x_relabel[singleton_mask[x]] = -1 - dbg(x_relabel) - - # map every non `-1` label to a new label - non_singleton_labels: Int[Tensor, " m"] = x_relabel[~singleton_mask[x]].unique() - 
dbg(non_singleton_labels) - n_unique_nonsingleton_labels: int = non_singleton_labels.shape[0] - dbg(n_unique_nonsingleton_labels) - old_to_new: dict[int, int] = { - old: new for new, old in enumerate(sorted(non_singleton_labels.tolist())) - } - dbg(old_to_new) - - for old, new in old_to_new.items(): - x_relabel[x == old] = new - dbg(x_relabel) - - return x_relabel, n_unique_nonsingleton_labels - - -# X_1 = torch.tensor([0, 3, 3, 2, 4, 0, 5, 6, 7, 7, 7]) -# X_2 = torch.tensor([1, 1, 2, 3, 3, 1, 4, 5, 6, 6, 6]) -# dbg(X_1) -# process_singletons(X_1) - - -# def to_matrix( -# self, device: torch.device | None = None -# ) -> Bool[Tensor, "k_groups n_components"]: -# if device is None: -# device = self.group_idxs.device -# mat: Bool[Tensor, "k_groups n_components"] = torch.zeros( -# (self.k_groups, self._n_components), dtype=torch.bool, device=device -# ) -# idxs: Int[Tensor, " n_components"] = torch.arange( -# self._n_components, device=device, dtype=torch.int -# ) -# mat[self.group_idxs.to(dtype=torch.int), idxs] = True -# return mat - - -def expand_to_onehot( - x: Int[Tensor, " n"], - k_groups: int, -) -> Bool[Tensor, " k_groups+1 n_components"]: - """expand a label (possibly having -1s) vector to a one-hot matrix""" - n_components: int = x.shape[0] - - # add 1 as -1 will map to last index and be ignored - mat: Bool[Tensor, " k_groups n_components"] = torch.zeros( - (k_groups + 1, n_components), dtype=torch.bool - ) - idxs: Int[Tensor, " n_components"] = torch.arange(n_components, dtype=torch.int) - mat[x.to(dtype=torch.int), idxs] = True - return mat - - -def show_matrix(mat: Tensor, title: str = "", cmap: str = "viridis") -> None: - """Display a matrix with values annotated on each cell.""" - mat_np = mat.cpu().numpy() - _fig, ax = plt.subplots() - im = ax.matshow(mat_np, cmap=cmap) - - # Add text annotations - for i in range(mat_np.shape[0]): - for j in range(mat_np.shape[1]): - ax.text( - j, - i, - f"{mat_np[i, j]:.2f}", - ha="center", - va="center", - 
color="white" if mat_np[i, j] < mat_np.max() / 2 else "black", - ) - - if title: - plt.title(title) - plt.colorbar(im, ax=ax) - plt.show() - - -# plt.imshow(expand_to_onehot(*process_singletons(X_1))) -# plt.show() -# plt.imshow(expand_to_onehot(*process_singletons(X_2))) -# plt.show() - - -def jaccard_index( - X: Int[Tensor, " s n"], -) -> Float[Tensor, " s s"]: - """compute pairwise jaccard indices between rows of X""" - s: int - _n: int - s, _n = X.shape - - X_expanded_list: list[Int[Tensor, " k n"]] = [ - expand_to_onehot(*process_singletons(X[i])) for i in range(s) - ] - - # compute jaccard for each pair of rows - # jaccard: dict[ - # tuple[int, int], # key is (i, j) from the rows of X - # fl - # # Int[Tensor, " k_i k_j"], # value at (p, q) is jaccard index between two clusters - # ] = {} - jaccard: Float[Tensor, " s s"] = torch.full((s, s), fill_value=torch.nan) - for i in range(s): - for j in range(i, s): - X_i: Int[Tensor, " k_i n"] = X_expanded_list[i].to(torch.int16) - X_j: Int[Tensor, " k_j n"] = X_expanded_list[j].to(torch.int16) - intersects: Int[Tensor, " k_i k_j"] = X_i @ X_j.T - unions: Int[Tensor, " k_i k_j"] = ( - X_i.sum(dim=1, keepdim=True) + X_j.sum(dim=1, keepdim=True).T - intersects - ) - jaccard_mat: Int[Tensor, " k_i k_j"] = intersects / unions - - show_matrix( - X_i, title=f"One-hot matrix for row {i} of X\nshape={X_i.shape}", cmap="Blues" - ) - show_matrix( - X_j, title=f"One-hot matrix for row {j} of X\nshape={X_j.shape}", cmap="Blues" - ) - show_matrix( - jaccard_mat, - title=f"Gram matrix between row {i} and row {j}\n$[{jaccard_mat.min():.2f}, {jaccard_mat.max():.2f}]$", - ) - - # jaccard[i, j] = jaccard_mat.mean() - - return jaccard - - -jaccard_index( - torch.tensor( - [ - # [1, 2, 3, 3], - [0, 1, 1, 2, 3, 3], - [0, 1, 1, 1, 2, 2], - # [0, 0, 0, 0, 1, 1], - [0, 0, 0, 0, 0, 0], - # [0, 1, 2, 3], - ] - ) -) - -# dbg(X - z[0]) diff --git a/spd/clustering/math/matching_dist.py b/spd/clustering/math/matching_dist.py new file mode 100644 
index 000000000..1991e9ba0 --- /dev/null +++ b/spd/clustering/math/matching_dist.py @@ -0,0 +1,47 @@ +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +_DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def matching_dist( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + s_ensemble, _n_components = X.shape + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + + dists: Float[Tensor, "s s"] = torch.full((s_ensemble, s_ensemble), torch.nan) + + for i in range(s_ensemble): + for j in range(i + 1, s_ensemble): + dist_mat = matches[i].float() - matches[j].float() + dists[i, j] = torch.tril(dist_mat, diagonal=-1).abs().sum() + + return dists + + +def matching_dist_vec( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + diffs: Bool[Tensor, "s s n n"] = matches[:, None, :, :] ^ matches[None, :, :, :] + + dists_int: torch.Tensor = diffs.sum(dim=(-1, -2)) + dists: Float[Tensor, "s s"] = dists_int.to(torch.float32) + return dists + + +def matching_dist_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist(torch.tensor(X, device=device)).cpu().numpy() + + +def matching_dist_vec_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist_vec(torch.tensor(X, device=device)).cpu().numpy() diff --git a/spd/clustering/math/merge_distances.py b/spd/clustering/math/merge_distances.py index 3d9215972..d3644cd68 100644 --- a/spd/clustering/math/merge_distances.py +++ b/spd/clustering/math/merge_distances.py @@ -10,13 +10,12 @@ MergesArray, MergesAtIterArray, ) - -# from spd.clustering.math.jaccard import jaccard_partition_matrix +from spd.clustering.math.matching_dist import matching_dist_np, matching_dist_vec_np from spd.clustering.math.perm_invariant_hamming import 
perm_invariant_hamming_matrix DISTANCES_METHODS: dict[DistancesMethod, Callable[[MergesAtIterArray], DistancesArray]] = { "perm_invariant_hamming": perm_invariant_hamming_matrix, - # "jaccard": jaccard_partition_matrix, + "matching_dist": matching_dist_np, } # pyright: reportUnnecessaryComparison=false, reportUnreachable=false @@ -40,14 +39,21 @@ def compute_distances( ) return np.stack(distances_list, axis=0) - case "jaccard": - raise NotImplementedError("Jaccard distance computation is not implemented.") - # merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] - # distances_list = run_maybe_parallel( - # func=jaccard_partition_matrix, - # iterable=merges_array_list, - # parallel=True, - # ) - # return np.stack(distances_list, axis=0) + case "matching_dist": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = run_maybe_parallel( + func=matching_dist_np, + iterable=merges_array_list, + parallel=True, + ) + return np.stack(distances_list, axis=0) + case "matching_dist_vec": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = run_maybe_parallel( + func=matching_dist_vec_np, + iterable=merges_array_list, + parallel=True, + ) + return np.stack(distances_list, axis=0) case _: raise ValueError(f"Unknown distance method: {method}") diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py index 5ba3226ce..bbff78893 100644 --- a/spd/clustering/merge_history.py +++ b/spd/clustering/merge_history.py @@ -439,6 +439,8 @@ def normalized(self) -> tuple[MergesArray, dict[str, Any]]: else: history_metadatas.append(None) + dbg_tensor(merges_array) + return ( # TODO: dataclass this merges_array, diff --git a/spd/clustering/plotting/merge.py b/spd/clustering/plotting/merge.py index 049f06e29..b213e1724 100644 --- a/spd/clustering/plotting/merge.py +++ b/spd/clustering/plotting/merge.py @@ -186,6 +186,8 @@ def plot_dists_distribution( ax: plt.Axes 
| None = None, kwargs_fig: dict[str, Any] | None = None, kwargs_plot: dict[str, Any] | None = None, + use_symlog: bool = True, + linthresh: float = 1.0, ) -> plt.Axes: n_iters: int = distances.shape[0] n_ens: int = distances.shape[1] @@ -278,6 +280,34 @@ def plot_dists_distribution( ax_.set_ylabel("distance") ax_.set_title("Distribution of pairwise distances between group merges in an ensemble") + if use_symlog: + from matplotlib.ticker import FuncFormatter + + ax_.set_yscale("symlog", linthresh=linthresh, linscale=0.2) + + # Custom formatter for y-axis ticks + def custom_format(y: float, _pos: int) -> str: + if abs(y) < linthresh: + # Show exact values in the linear range + return f"{y:.1f}" + elif abs(y) == 1: + return "1" + elif abs(y) == 10: + return "10" + else: + # Use scientific notation for larger values + exponent = int(np.log10(abs(y))) + return f"$10^{{{exponent}}}$" + + ax_.yaxis.set_major_formatter(FuncFormatter(custom_format)) + + # Add a visual indicator for the linear region (0 to linthresh) + ax_.axhspan(0, linthresh, alpha=0.05, color="gray", zorder=-10) + # Add subtle lines at linthresh boundaries + ax_.axhline(linthresh, color="gray", linestyle="--", linewidth=0.5, alpha=0.3) + if linthresh > 0: + ax_.axhline(0, color="gray", linestyle="-", linewidth=0.5, alpha=0.3) + return ax_ diff --git a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py index 6ee277759..709d3c1c6 100644 --- a/spd/clustering/scripts/calc_distances.py +++ b/spd/clustering/scripts/calc_distances.py @@ -12,10 +12,13 @@ import argparse import json +import multiprocessing import numpy as np +import torch from matplotlib import pyplot as plt from matplotlib.axes import Axes +from muutils.dbg import dbg_tensor from spd.clustering.consts import DistancesArray, DistancesMethod from spd.clustering.ensemble_registry import get_clustering_runs @@ -25,6 +28,15 @@ from spd.log import logger from spd.settings import SPD_CACHE_DIR +# Set spawn method for 
CUDA compatibility with multiprocessing +# Must be done before any CUDA operations +if torch.cuda.is_available(): + try: # noqa: SIM105 + multiprocessing.set_start_method("spawn") + except RuntimeError: + # Already set, ignore + pass + def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: """Calculate distances between clustering runs in an ensemble. @@ -76,6 +88,8 @@ def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: method=distances_method, ) + dbg_tensor(distances) + distances_path = pipeline_dir / f"distances_{distances_method}.npz" np.savez_compressed(distances_path, distances=distances) logger.info(f"Distances computed and saved: shape={distances.shape}, path={distances_path}") @@ -109,7 +123,7 @@ def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: ) parser.add_argument( "--distances-method", - choices=["perm_invariant_hamming", "jaccard"], + choices=DistancesMethod.__args__, default="perm_invariant_hamming", help="Method for calculating distances", ) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 7b6af0e82..cde83ffa1 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -20,13 +20,12 @@ import argparse import os import shlex -import subprocess import tempfile from pathlib import Path from typing import Any import wandb_workspaces.workspaces as ws -from pydantic import Field, PositiveInt +from pydantic import Field, PositiveInt, field_validator from spd.base_config import BaseConfig from spd.clustering.consts import DistancesMethod @@ -72,7 +71,9 @@ class ClusteringPipelineConfig(BaseConfig): run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") - distances_method: DistancesMethod = Field(description="Method to use for calculating distances") + distances_methods: 
list[DistancesMethod] = Field( + description="List of method(s) to use for calculating distances" + ) base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") slurm_partition: str | None = Field(description="SLURM partition to use") @@ -83,6 +84,16 @@ class ClusteringPipelineConfig(BaseConfig): wandb_entity: str = Field(description="WandB entity (team/user) name") create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + @field_validator("distances_methods") + @classmethod + def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesMethod]: + """Validate that distances_methods is non-empty and contains valid methods.""" + assert all(method in DistancesMethod.__args__ for method in v), ( + f"Invalid distances_methods: {v}" + ) + + return v + def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: """Create WandB workspace view for clustering runs. @@ -155,29 +166,42 @@ def generate_clustering_commands( return commands -def generate_calc_distances_command(pipeline_run_id: str, distances_method: DistancesMethod) -> str: - """Generate command for calculating distances. +def generate_calc_distances_commands( + pipeline_run_id: str, distances_methods: list[DistancesMethod] +) -> list[str]: + """Generate commands for calculating distances. 
Args: pipeline_run_id: Pipeline run ID (will query registry for clustering runs) - distances_method: Method for calculating distances + distances_methods: List of methods for calculating distances + + Returns: + List of shell-safe command strings, one per method """ - return shlex.join( - [ - "python", - "spd/clustering/scripts/calc_distances.py", - "--pipeline-run-id", - pipeline_run_id, - "--distances-method", - distances_method, - ] - ) + commands: list[str] = [] + for method in distances_methods: + commands.append( + shlex.join( + [ + "python", + "spd/clustering/scripts/calc_distances.py", + "--pipeline-run-id", + pipeline_run_id, + "--distances-method", + method, + ] + ) + ) + return commands def main( pipeline_config: ClusteringPipelineConfig, local: bool = False, + local_clustering_parallel: bool = False, + local_calc_distances_parallel: bool = False, dataset_streaming: bool = False, + track_resources_calc_distances: bool = False, ) -> None: """Submit clustering runs to SLURM. @@ -190,6 +214,13 @@ def main( logger.set_format("console", "terse") + if local_clustering_parallel or local_calc_distances_parallel or track_resources_calc_distances: + assert local, ( + "local_clustering_parallel, local_calc_distances_parallel, track_resources_calc_distances " + "can only be set when running locally\n" + f"{local_clustering_parallel=}, {local_calc_distances_parallel=}, {track_resources_calc_distances=}, {local=}" + ) + # Create ExecutionStamp for pipeline execution_stamp: ExecutionStamp = ExecutionStamp.create( run_type="ensemble", @@ -222,10 +253,10 @@ def main( dataset_streaming=dataset_streaming, ) - # Generate command for calculating distances - calc_distances_command = generate_calc_distances_command( + # Generate commands for calculating distances + calc_distances_commands = generate_calc_distances_commands( pipeline_run_id=pipeline_run_id, - distances_method=pipeline_config.distances_method, + distances_methods=pipeline_config.distances_methods, ) # Submit to 
SLURM @@ -233,23 +264,31 @@ def main( # submit clustering array job run_script_array_local( commands=clustering_commands, + parallel=local_clustering_parallel, ) - # submit calc_distances job + # submit calc_distances jobs in parallel logger.info("Calculating distances...") - logger.info(f"Command: {calc_distances_command}") - subprocess.run(shlex.split(calc_distances_command), shell=False, check=True) + run_script_array_local( + commands=calc_distances_commands, + parallel=local_calc_distances_parallel, + track_resources=track_resources_calc_distances, + ) logger.section("complete!") - distances_plot_path = ( - storage.plots_dir / f"distances_{pipeline_config.distances_method}.png" - ) + + # Build distances plot paths dict + distances_plots = { + f"distances via {method}": str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + logger.values( { "Total clustering runs": len(clustering_commands), "Pipeline run ID": pipeline_run_id, "Pipeline output dir": str(storage.base_dir), - "Distances plot": str(distances_plot_path), + **distances_plots, } ) @@ -275,36 +314,52 @@ def main( ) array_job_id = submit_slurm_script(clustering_script_path) - # Submit calc_distances job with dependency on array job - calc_distances_script_path = Path(temp_dir) / f"calc_distances_{pipeline_run_id}.sh" - - create_slurm_script( - script_path=calc_distances_script_path, - job_name=f"{pipeline_config.slurm_job_name_prefix}_distances", - command=calc_distances_command, - snapshot_branch=execution_stamp.snapshot_branch, - n_gpus=1, # Always 1 GPU for distances calculation - partition=pipeline_config.slurm_partition, - dependency_job_id=array_job_id, - ) - calc_distances_job_id = submit_slurm_script(calc_distances_script_path) + # Submit calc_distances jobs (one per method) with dependency on array job + calc_distances_job_ids: list[str] = [] + calc_distances_logs: list[str] = [] + + for _i, (method, cmd) in enumerate( + 
zip(pipeline_config.distances_methods, calc_distances_commands, strict=True) + ): + calc_distances_script_path = ( + Path(temp_dir) / f"calc_distances_{method}_{pipeline_run_id}.sh" + ) + + create_slurm_script( + script_path=calc_distances_script_path, + job_name=f"{pipeline_config.slurm_job_name_prefix}_dist_{method}", + command=cmd, + snapshot_branch=execution_stamp.snapshot_branch, + n_gpus=1, # Always 1 GPU for distances calculation + partition=pipeline_config.slurm_partition, + dependency_job_id=array_job_id, + ) + job_id = submit_slurm_script(calc_distances_script_path) + calc_distances_job_ids.append(job_id) + calc_distances_logs.append(f"~/slurm_logs/slurm-{job_id}.out") logger.section("Jobs submitted successfully!") - distances_plot_path = ( - storage.plots_dir / f"distances_{pipeline_config.distances_method}.png" - ) + + # Build distances plot paths dict + distances_plots = { + method: str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + logger.values( { "Clustering Array Job ID": array_job_id, - "Calc Distances Job ID": calc_distances_job_id, + "Calc Distances Job IDs": ", ".join(calc_distances_job_ids), "Total clustering runs": len(clustering_commands), "Pipeline run ID": pipeline_run_id, "Pipeline output dir": str(storage.base_dir), "Clustering logs": f"~/slurm_logs/slurm-{array_job_id}_*.out", - "Calc Distances log": f"~/slurm_logs/slurm-{calc_distances_job_id}.out", - "Distances plot will be saved to": str(distances_plot_path), + "Calc Distances logs": ", ".join(calc_distances_logs), } ) + logger.info("Distances plots will be saved to:") + for method, path in distances_plots.items(): + logger.info(f" {method}: {path}") def cli(): @@ -338,12 +393,33 @@ def cli(): default=None, help="WandB entity name (user or team)", ) + parser.add_argument( + "--distances-methods", + type=str, + default=None, + help="Comma-separated list of distance methods (e.g., 'perm_invariant_hamming,matching_dist')", + ) 
parser.add_argument( "--local", action=argparse.BooleanOptionalAction, default=False, help="Run locally instead of submitting to SLURM (required if slurm_job_name_prefix and slurm_partition are None in config)", ) + parser.add_argument( + "--local-clustering-parallel", + action="store_true", + help="If running locally, whether to run clustering runs in parallel", + ) + parser.add_argument( + "--local-calc-distances-parallel", + action="store_true", + help="If running locally, whether to run distance calculations in parallel", + ) + parser.add_argument( + "--track-resources-calc-distances", + action="store_true", + help="If running locally, whether to track resource usage during distance calculations", + ) parser.add_argument( "--dataset-streaming", action="store_true", @@ -361,6 +437,10 @@ def cli(): overrides["wandb_project"] = args.wandb_project if args.wandb_entity is not None: overrides["wandb_entity"] = args.wandb_entity + if args.distances_methods is not None: + # Parse comma-separated list of distance methods + methods = [method.strip() for method in args.distances_methods.split(",")] + overrides["distances_methods"] = methods pipeline_config = replace_pydantic_model(pipeline_config, overrides) @@ -368,6 +448,9 @@ def cli(): pipeline_config=pipeline_config, local=args.local, dataset_streaming=args.dataset_streaming, + local_clustering_parallel=args.local_clustering_parallel, + local_calc_distances_parallel=args.local_calc_distances_parallel, + track_resources_calc_distances=args.track_resources_calc_distances, ) diff --git a/spd/utils/command_utils.py b/spd/utils/command_utils.py index 6b79ad6ed..b6b74f3b3 100644 --- a/spd/utils/command_utils.py +++ b/spd/utils/command_utils.py @@ -1,37 +1,113 @@ """Minimal utilities for running shell-safe commands locally.""" -import shlex import subprocess +import tempfile +from pathlib import Path from spd.log import logger -def run_script_array_local(commands: list[str], parallel: bool = False) -> None: +def 
run_script_array_local( + commands: list[str], parallel: bool = False, track_resources: bool = False +) -> dict[str, dict[str, float]] | None: """Run multiple shell-safe command strings locally. Args: commands: List of shell-safe command strings (built with shlex.join()) parallel: If True, run all commands in parallel. If False, run sequentially. + track_resources: If True, track and return resource usage for each command using /usr/bin/time. + + Returns: + If track_resources is True, returns dict mapping commands to resource metrics dict. + Resource metrics include: K (avg memory KB), M (max memory KB), P (CPU %), + S (system CPU sec), U (user CPU sec), e (wall time sec). + Otherwise returns None. """ n_commands = len(commands) + resources: dict[str, dict[str, float]] = {} + resource_files: list[Path] = [] - if not parallel: - logger.section(f"LOCAL EXECUTION: Running {n_commands} tasks serially") - for i, cmd in enumerate(commands, 1): - logger.info(f"[{i}/{n_commands}] Running: {cmd}") - subprocess.run(shlex.split(cmd), shell=False, check=True) - logger.section("LOCAL EXECUTION COMPLETE") + # Wrap commands with /usr/bin/time if resource tracking is requested + if track_resources: + wrapped_commands: list[str] = [] + for cmd in commands: + resource_file = Path(tempfile.mktemp(suffix=".resources")) # pyright: ignore[reportDeprecated] + resource_files.append(resource_file) + # Use /usr/bin/time to track comprehensive resource usage + # K=avg total mem, M=max resident, P=CPU%, S=system time, U=user time, e=wall time + wrapped_cmd = ( + f'/usr/bin/time -f "K:%K M:%M P:%P S:%S U:%U e:%e" -o {resource_file} {cmd}' + ) + wrapped_commands.append(wrapped_cmd) + commands_to_run = wrapped_commands else: - logger.section(f"LOCAL EXECUTION: Starting {n_commands} tasks in parallel") - procs: list[subprocess.Popen[bytes]] = [] - for i, cmd in enumerate(commands, 1): - logger.info(f"[{i}/{n_commands}] Starting: {cmd}") - proc = subprocess.Popen(shlex.split(cmd), shell=False) - 
procs.append(proc) - - logger.section("WAITING FOR ALL TASKS TO COMPLETE") - for proc in procs: - proc.wait() - if proc.returncode != 0: - logger.error(f"Process {proc.pid} failed with exit code {proc.returncode}") - logger.section("LOCAL EXECUTION COMPLETE") + commands_to_run = commands + + try: + if not parallel: + logger.section(f"LOCAL EXECUTION: Running {n_commands} tasks serially") + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Running: {commands[i - 1]}") + subprocess.run(cmd, shell=True, check=True) + logger.section("LOCAL EXECUTION COMPLETE") + else: + logger.section(f"LOCAL EXECUTION: Starting {n_commands} tasks in parallel") + procs: list[subprocess.Popen[bytes]] = [] + + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Starting: {commands[i - 1]}") + proc = subprocess.Popen(cmd, shell=True) + procs.append(proc) + + logger.section("WAITING FOR ALL TASKS TO COMPLETE") + for proc, cmd in zip(procs, commands, strict=True): # noqa: B007 + proc.wait() + if proc.returncode != 0: + logger.error(f"Process {proc.pid} failed with exit code {proc.returncode}") + logger.section("LOCAL EXECUTION COMPLETE") + + # Read resource usage results + if track_resources: + for cmd, resource_file in zip(commands, resource_files, strict=True): + if resource_file.exists(): + # Parse format: "K:123 M:456 P:78% S:1.23 U:4.56 e:7.89" + output = resource_file.read_text().strip() + metrics: dict[str, float] = {} + + for part in output.split(): + if ":" in part: + key, value = part.split(":", 1) + # Remove % sign from CPU percentage + value = value.rstrip("%") + try: + metrics[key] = float(value) + except ValueError: + logger.warning(f"Could not parse {key}:{value} for command: {cmd}") + + resources[cmd] = metrics + else: + logger.warning(f"Resource file not found for: {cmd}") + + # Log comprehensive resource usage table + logger.section("RESOURCE USAGE RESULTS") + for cmd, metrics in resources.items(): + 
logger.info(f"Command: {cmd}") + logger.info( + f" Time: {metrics.get('e', 0):.2f}s wall, " + f"{metrics.get('U', 0):.2f}s user, " + f"{metrics.get('S', 0):.2f}s system" + ) + logger.info( + f" Memory: {metrics.get('M', 0) / 1024:.1f} MB peak, " + f"{metrics.get('K', 0) / 1024:.1f} MB avg" + ) + logger.info(f" CPU: {metrics.get('P', 0):.1f}%") + + finally: + # Clean up temp files + if track_resources: + for resource_file in resource_files: + if resource_file.exists(): + resource_file.unlink() + + return resources if track_resources else None diff --git a/uv.lock b/uv.lock index 560ef3f01..26bfb9af8 100644 --- a/uv.lock +++ b/uv.lock @@ -17,7 +17,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.0" +version = "3.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -28,25 +28,25 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/f1/8515650ac3121a9e55c7b217c60e7fae3e0134b5acfe65691781b5356929/aiohttp-3.13.0.tar.gz", hash = "sha256:378dbc57dd8cf341ce243f13fa1fa5394d68e2e02c15cd5f28eae35a70ec7f67", size = 7832348, upload-time = "2025-10-06T19:58:48.089Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/fa/3ae643cd525cf6844d3dc810481e5748107368eb49563c15a5fb9f680750/aiohttp-3.13.1.tar.gz", hash = "sha256:4b7ee9c355015813a6aa085170b96ec22315dabc3d866fd77d147927000e9464", size = 7835344, upload-time = "2025-10-17T14:03:29.337Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/86/2c/ac53efdc9c10e41399acc2395af98f835b86d0141d5c3820857eb9f6a14a/aiohttp-3.13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:00243e51f16f6ec0fb021659d4af92f675f3cf9f9b39efd142aa3ad641d8d1e6", size = 730090, upload-time = "2025-10-06T19:56:16.858Z" }, - { url = "https://files.pythonhosted.org/packages/13/18/1ac95683e1c1d48ef4503965c96f5401618a04c139edae12e200392daae8/aiohttp-3.13.0-cp313-cp313-macosx_10_13_x86_64.whl", 
hash = "sha256:059978d2fddc462e9211362cbc8446747ecd930537fa559d3d25c256f032ff54", size = 488041, upload-time = "2025-10-06T19:56:18.659Z" }, - { url = "https://files.pythonhosted.org/packages/fd/79/ef0d477c771a642d1a881b92d226314c43d3c74bc674c93e12e679397a97/aiohttp-3.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:564b36512a7da3b386143c611867e3f7cfb249300a1bf60889bd9985da67ab77", size = 486989, upload-time = "2025-10-06T19:56:20.371Z" }, - { url = "https://files.pythonhosted.org/packages/37/b4/0e440481a0e77a551d6c5dcab5d11f1ff6b2b2ddb8dedc24f54f5caad732/aiohttp-3.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4aa995b9156ae499393d949a456a7ab0b994a8241a96db73a3b73c7a090eff6a", size = 1718331, upload-time = "2025-10-06T19:56:22.188Z" }, - { url = "https://files.pythonhosted.org/packages/e6/59/76c421cc4a75bb1aceadb92f20ee6f05a990aa6960c64b59e8e0d340e3f5/aiohttp-3.13.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55ca0e95a3905f62f00900255ed807c580775174252999286f283e646d675a49", size = 1686263, upload-time = "2025-10-06T19:56:24.393Z" }, - { url = "https://files.pythonhosted.org/packages/ec/ac/5095f12a79c7775f402cfc3e83651b6e0a92ade10ddf7f2c78c4fed79f71/aiohttp-3.13.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:49ce7525853a981fc35d380aa2353536a01a9ec1b30979ea4e35966316cace7e", size = 1754265, upload-time = "2025-10-06T19:56:26.365Z" }, - { url = "https://files.pythonhosted.org/packages/05/d7/a48e4989bd76cc70600c505bbdd0d90ca1ad7f9053eceeb9dbcf9345a9ec/aiohttp-3.13.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2117be9883501eaf95503bd313eb4c7a23d567edd44014ba15835a1e9ec6d852", size = 1856486, upload-time = "2025-10-06T19:56:28.438Z" }, - { url = 
"https://files.pythonhosted.org/packages/1e/02/45b388b49e37933f316e1fb39c0de6fb1d77384b0c8f4cf6af5f2cbe3ea6/aiohttp-3.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d169c47e40c911f728439da853b6fd06da83761012e6e76f11cb62cddae7282b", size = 1737545, upload-time = "2025-10-06T19:56:30.688Z" }, - { url = "https://files.pythonhosted.org/packages/6c/a7/4fde058f1605c34a219348a83a99f14724cc64e68a42480fc03cf40f9ea3/aiohttp-3.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:703ad3f742fc81e543638a7bebddd35acadaa0004a5e00535e795f4b6f2c25ca", size = 1552958, upload-time = "2025-10-06T19:56:32.528Z" }, - { url = "https://files.pythonhosted.org/packages/d1/12/0bac4d29231981e3aa234e88d1931f6ba38135ff4c2cf3afbb7895527630/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5bf635c3476f4119b940cc8d94ad454cbe0c377e61b4527f0192aabeac1e9370", size = 1681166, upload-time = "2025-10-06T19:56:34.81Z" }, - { url = "https://files.pythonhosted.org/packages/71/95/b829eb5f8ac1ca1d8085bb8df614c8acf3ff32e23ad5ad1173c7c9761daa/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:cfe6285ef99e7ee51cef20609be2bc1dd0e8446462b71c9db8bb296ba632810a", size = 1710516, upload-time = "2025-10-06T19:56:36.787Z" }, - { url = "https://files.pythonhosted.org/packages/47/6d/15ccf4ef3c254d899f62580e0c7fc717014f4d14a3ac31771e505d2c736c/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:34d8af6391c5f2e69749d7f037b614b8c5c42093c251f336bdbfa4b03c57d6c4", size = 1731354, upload-time = "2025-10-06T19:56:38.659Z" }, - { url = "https://files.pythonhosted.org/packages/46/6a/8acf6c57e03b6fdcc8b4c06392e66abaff3213ea275e41db3edb20738d91/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:12f5d820fadc5848d4559ea838aef733cf37ed2a1103bba148ac2f5547c14c29", size = 1548040, upload-time = "2025-10-06T19:56:40.578Z" }, - { url = 
"https://files.pythonhosted.org/packages/75/7d/fbfd59ab2a83fe2578ce79ac3db49727b81e9f4c3376217ad09c03c6d279/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:0f1338b61ea66f4757a0544ed8a02ccbf60e38d9cfb3225888888dd4475ebb96", size = 1756031, upload-time = "2025-10-06T19:56:42.492Z" }, - { url = "https://files.pythonhosted.org/packages/99/e7/cc9f0fdf06cab3ca61e6b62bff9a4b978b8ca736e9d76ddf54365673ab19/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:582770f82513419512da096e8df21ca44f86a2e56e25dc93c5ab4df0fe065bf0", size = 1714933, upload-time = "2025-10-06T19:56:45.542Z" }, - { url = "https://files.pythonhosted.org/packages/db/43/7abbe1de94748a58a71881163ee280fd3217db36e8344d109f63638fe16a/aiohttp-3.13.0-cp313-cp313-win32.whl", hash = "sha256:3194b8cab8dbc882f37c13ef1262e0a3d62064fa97533d3aa124771f7bf1ecee", size = 423799, upload-time = "2025-10-06T19:56:47.779Z" }, - { url = "https://files.pythonhosted.org/packages/c9/58/afab7f2b9e7df88c995995172eb78cae8a3d5a62d5681abaade86b3f0089/aiohttp-3.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:7897298b3eedc790257fef8a6ec582ca04e9dbe568ba4a9a890913b925b8ea21", size = 450138, upload-time = "2025-10-06T19:56:49.49Z" }, + { url = "https://files.pythonhosted.org/packages/16/6d/d267b132342e1080f4c1bb7e1b4e96b168b3cbce931ec45780bff693ff95/aiohttp-3.13.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:55785a7f8f13df0c9ca30b5243d9909bd59f48b274262a8fe78cee0828306e5d", size = 730727, upload-time = "2025-10-17T14:00:39.681Z" }, + { url = "https://files.pythonhosted.org/packages/92/c8/1cf495bac85cf71b80fad5f6d7693e84894f11b9fe876b64b0a1e7cbf32f/aiohttp-3.13.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4bef5b83296cebb8167707b4f8d06c1805db0af632f7a72d7c5288a84667e7c3", size = 488678, upload-time = "2025-10-17T14:00:41.541Z" }, + { url = 
"https://files.pythonhosted.org/packages/a8/19/23c6b81cca587ec96943d977a58d11d05a82837022e65cd5502d665a7d11/aiohttp-3.13.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:27af0619c33f9ca52f06069ec05de1a357033449ab101836f431768ecfa63ff5", size = 487637, upload-time = "2025-10-17T14:00:43.527Z" }, + { url = "https://files.pythonhosted.org/packages/48/58/8f9464afb88b3eed145ad7c665293739b3a6f91589694a2bb7e5778cbc72/aiohttp-3.13.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a47fe43229a8efd3764ef7728a5c1158f31cdf2a12151fe99fde81c9ac87019c", size = 1718975, upload-time = "2025-10-17T14:00:45.496Z" }, + { url = "https://files.pythonhosted.org/packages/e1/8b/c3da064ca392b2702f53949fd7c403afa38d9ee10bf52c6ad59a42537103/aiohttp-3.13.1-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6e68e126de5b46e8b2bee73cab086b5d791e7dc192056916077aa1e2e2b04437", size = 1686905, upload-time = "2025-10-17T14:00:47.707Z" }, + { url = "https://files.pythonhosted.org/packages/0a/a4/9c8a3843ecf526daee6010af1a66eb62579be1531d2d5af48ea6f405ad3c/aiohttp-3.13.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e65ef49dd22514329c55970d39079618a8abf856bae7147913bb774a3ab3c02f", size = 1754907, upload-time = "2025-10-17T14:00:49.702Z" }, + { url = "https://files.pythonhosted.org/packages/a4/80/1f470ed93e06436e3fc2659a9fc329c192fa893fb7ed4e884d399dbfb2a8/aiohttp-3.13.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0e425a7e0511648b3376839dcc9190098671a47f21a36e815b97762eb7d556b0", size = 1857129, upload-time = "2025-10-17T14:00:51.822Z" }, + { url = "https://files.pythonhosted.org/packages/cc/e6/33d305e6cce0a8daeb79c7d8d6547d6e5f27f4e35fa4883fc9c9eb638596/aiohttp-3.13.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:010dc9b7110f055006acd3648d5d5955bb6473b37c3663ec42a1b4cba7413e6b", size = 1738189, upload-time = "2025-10-17T14:00:53.976Z" }, + { url = "https://files.pythonhosted.org/packages/ac/42/8df03367e5a64327fe0c39291080697795430c438fc1139c7cc1831aa1df/aiohttp-3.13.1-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:1b5c722d0ca5f57d61066b5dfa96cdb87111e2519156b35c1f8dd17c703bee7a", size = 1553608, upload-time = "2025-10-17T14:00:56.144Z" }, + { url = "https://files.pythonhosted.org/packages/96/17/6d5c73cd862f1cf29fddcbb54aac147037ff70a043a2829d03a379e95742/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:93029f0e9b77b714904a281b5aa578cdc8aa8ba018d78c04e51e1c3d8471b8ec", size = 1681809, upload-time = "2025-10-17T14:00:58.603Z" }, + { url = "https://files.pythonhosted.org/packages/be/31/8926c8ab18533f6076ce28d2c329a203b58c6861681906e2d73b9c397588/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:d1824c7d08d8ddfc8cb10c847f696942e5aadbd16fd974dfde8bd2c3c08a9fa1", size = 1711161, upload-time = "2025-10-17T14:01:01.744Z" }, + { url = "https://files.pythonhosted.org/packages/f2/36/2f83e1ca730b1e0a8cf1c8ab9559834c5eec9f5da86e77ac71f0d16b521d/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:8f47d0ff5b3eb9c1278a2f56ea48fda667da8ebf28bd2cb378b7c453936ce003", size = 1731999, upload-time = "2025-10-17T14:01:04.626Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ec/1f818cc368dfd4d5ab4e9efc8f2f6f283bfc31e1c06d3e848bcc862d4591/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:8a396b1da9b51ded79806ac3b57a598f84e0769eaa1ba300655d8b5e17b70c7b", size = 1548684, upload-time = "2025-10-17T14:01:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/d3/ad/33d36efd16e4fefee91b09a22a3a0e1b830f65471c3567ac5a8041fac812/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:d9c52a65f54796e066b5d674e33b53178014752d28bca555c479c2c25ffcec5b", size = 
1756676, upload-time = "2025-10-17T14:01:09.517Z" }, + { url = "https://files.pythonhosted.org/packages/3c/c4/4a526d84e77d464437713ca909364988ed2e0cd0cdad2c06cb065ece9e08/aiohttp-3.13.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a89da72d18d6c95a653470b78d8ee5aa3c4b37212004c103403d0776cbea6ff0", size = 1715577, upload-time = "2025-10-17T14:01:11.958Z" }, + { url = "https://files.pythonhosted.org/packages/a2/21/e39638b7d9c7f1362c4113a91870f89287e60a7ea2d037e258b81e8b37d5/aiohttp-3.13.1-cp313-cp313-win32.whl", hash = "sha256:02e0258b7585ddf5d01c79c716ddd674386bfbf3041fbbfe7bdf9c7c32eb4a9b", size = 424468, upload-time = "2025-10-17T14:01:14.344Z" }, + { url = "https://files.pythonhosted.org/packages/cc/00/f3a92c592a845ebb2f47d102a67f35f0925cb854c5e7386f1a3a1fdff2ab/aiohttp-3.13.1-cp313-cp313-win_amd64.whl", hash = "sha256:ef56ffe60e8d97baac123272bde1ab889ee07d3419606fae823c80c2b86c403e", size = 450806, upload-time = "2025-10-17T14:01:16.437Z" }, ] [[package]] @@ -629,11 +629,11 @@ wheels = [ [[package]] name = "iniconfig" -version = "2.1.0" +version = "2.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = 
"2025-03-19T20:10:01.071Z" }, + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] [[package]] @@ -974,11 +974,11 @@ wheels = [ [[package]] name = "narwhals" -version = "2.8.0" +version = "2.9.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/05/79a5b5a795f36c1aaa002d194c1ef71e5d95f7e1900155bbfde734815ab9/narwhals-2.8.0.tar.gz", hash = "sha256:52e0b22d54718264ae703bd9293af53b04abc995a1414908c3b807ba8c913858", size = 574277, upload-time = "2025-10-13T08:44:28.81Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/95/aa46616f5e567ff5d262f4c207d5ca79cb2766010c786c351b8e7f930ef4/narwhals-2.9.0.tar.gz", hash = "sha256:d8cde40a6a8a7049d8e66608b7115ab19464acc6f305d136a8dc8ba396c4acfe", size = 584098, upload-time = "2025-10-20T12:19:16.893Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/86/ac808ecb94322a3f1ea31627d13ab3e50dd4333564d711e0e481ad0f4586/narwhals-2.8.0-py3-none-any.whl", hash = "sha256:6304856676ba4a79fd34148bda63aed8060dd6edb1227edf3659ce5e091de73c", size = 415852, upload-time = "2025-10-13T08:44:25.421Z" }, + { url = "https://files.pythonhosted.org/packages/13/34/00c7ae8194074ed82b64e0bb7c24220eac5f77ac90c16e23cf0d2cfd2a03/narwhals-2.9.0-py3-none-any.whl", hash = "sha256:c59f7de4763004ae81691ce16df71b4e55aead0ead7ccde8c8f2ef8c9559c765", size = 422255, upload-time = "2025-10-20T12:19:15.228Z" }, ] [[package]] @@ -1372,18 +1372,18 @@ wheels = [ [[package]] name = "psutil" -version = "7.1.0" +version = "7.1.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b3/31/4723d756b59344b643542936e37a31d1d3204bcdc42a7daa8ee9eb06fb50/psutil-7.1.0.tar.gz", hash = 
"sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2", size = 497660, upload-time = "2025-09-17T20:14:52.902Z" } +sdist = { url = "https://files.pythonhosted.org/packages/89/fc/889242351a932d6183eec5df1fc6539b6f36b6a88444f1e63f18668253aa/psutil-7.1.1.tar.gz", hash = "sha256:092b6350145007389c1cfe5716050f02030a05219d90057ea867d18fe8d372fc", size = 487067, upload-time = "2025-10-19T15:43:59.373Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/62/ce4051019ee20ce0ed74432dd73a5bb087a6704284a470bb8adff69a0932/psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13", size = 245242, upload-time = "2025-09-17T20:14:56.126Z" }, - { url = "https://files.pythonhosted.org/packages/38/61/f76959fba841bf5b61123fbf4b650886dc4094c6858008b5bf73d9057216/psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5", size = 246682, upload-time = "2025-09-17T20:14:58.25Z" }, - { url = "https://files.pythonhosted.org/packages/88/7a/37c99d2e77ec30d63398ffa6a660450b8a62517cabe44b3e9bae97696e8d/psutil-7.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e4454970b32472ce7deaa45d045b34d3648ce478e26a04c7e858a0a6e75ff3", size = 287994, upload-time = "2025-09-17T20:14:59.901Z" }, - { url = "https://files.pythonhosted.org/packages/9d/de/04c8c61232f7244aa0a4b9a9fbd63a89d5aeaf94b2fc9d1d16e2faa5cbb0/psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70e113920d51e89f212dd7be06219a9b88014e63a4cec69b684c327bc474e3", size = 291163, upload-time = "2025-09-17T20:15:01.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/58/c4f976234bf6d4737bc8c02a81192f045c307b72cf39c9e5c5a2d78927f6/psutil-7.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:7d4a113425c037300de3ac8b331637293da9be9713855c4fc9d2d97436d7259d", size = 293625, upload-time = "2025-09-17T20:15:04.492Z" }, - { url = "https://files.pythonhosted.org/packages/79/87/157c8e7959ec39ced1b11cc93c730c4fb7f9d408569a6c59dbd92ceb35db/psutil-7.1.0-cp37-abi3-win32.whl", hash = "sha256:09ad740870c8d219ed8daae0ad3b726d3bf9a028a198e7f3080f6a1888b99bca", size = 244812, upload-time = "2025-09-17T20:15:07.462Z" }, - { url = "https://files.pythonhosted.org/packages/bf/e9/b44c4f697276a7a95b8e94d0e320a7bf7f3318521b23de69035540b39838/psutil-7.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:57f5e987c36d3146c0dd2528cd42151cf96cd359b9d67cfff836995cc5df9a3d", size = 247965, upload-time = "2025-09-17T20:15:09.673Z" }, - { url = "https://files.pythonhosted.org/packages/26/65/1070a6e3c036f39142c2820c4b52e9243246fcfc3f96239ac84472ba361e/psutil-7.1.0-cp37-abi3-win_arm64.whl", hash = "sha256:6937cb68133e7c97b6cc9649a570c9a18ba0efebed46d8c5dae4c07fa1b67a07", size = 244971, upload-time = "2025-09-17T20:15:12.262Z" }, + { url = "https://files.pythonhosted.org/packages/51/30/f97f8fb1f9ecfbeae4b5ca738dcae66ab28323b5cfbc96cb5565f3754056/psutil-7.1.1-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:8fa59d7b1f01f0337f12cd10dbd76e4312a4d3c730a4fedcbdd4e5447a8b8460", size = 244221, upload-time = "2025-10-19T15:44:03.145Z" }, + { url = "https://files.pythonhosted.org/packages/7b/98/b8d1f61ebf35f4dbdbaabadf9208282d8adc820562f0257e5e6e79e67bf2/psutil-7.1.1-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a95104eae85d088891716db676f780c1404fc15d47fde48a46a5d61e8f5ad2c", size = 245660, upload-time = "2025-10-19T15:44:05.657Z" }, + { url = "https://files.pythonhosted.org/packages/f0/4a/b8015d7357fefdfe34bc4a3db48a107bae4bad0b94fb6eb0613f09a08ada/psutil-7.1.1-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:98629cd8567acefcc45afe2f4ba1e9290f579eacf490a917967decce4b74ee9b", size = 286963, upload-time = "2025-10-19T15:44:08.877Z" 
}, + { url = "https://files.pythonhosted.org/packages/3d/3c/b56076bb35303d0733fc47b110a1c9cce081a05ae2e886575a3587c1ee76/psutil-7.1.1-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92ebc58030fb054fa0f26c3206ef01c31c29d67aee1367e3483c16665c25c8d2", size = 290118, upload-time = "2025-10-19T15:44:11.897Z" }, + { url = "https://files.pythonhosted.org/packages/dc/af/c13d360c0adc6f6218bf9e2873480393d0f729c8dd0507d171f53061c0d3/psutil-7.1.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:146a704f224fb2ded2be3da5ac67fc32b9ea90c45b51676f9114a6ac45616967", size = 292587, upload-time = "2025-10-19T15:44:14.67Z" }, + { url = "https://files.pythonhosted.org/packages/90/2d/c933e7071ba60c7862813f2c7108ec4cf8304f1c79660efeefd0de982258/psutil-7.1.1-cp37-abi3-win32.whl", hash = "sha256:295c4025b5cd880f7445e4379e6826f7307e3d488947bf9834e865e7847dc5f7", size = 243772, upload-time = "2025-10-19T15:44:16.938Z" }, + { url = "https://files.pythonhosted.org/packages/be/f3/11fd213fff15427bc2853552138760c720fd65032d99edfb161910d04127/psutil-7.1.1-cp37-abi3-win_amd64.whl", hash = "sha256:9b4f17c5f65e44f69bd3a3406071a47b79df45cf2236d1f717970afcb526bcd3", size = 246936, upload-time = "2025-10-19T15:44:18.663Z" }, + { url = "https://files.pythonhosted.org/packages/0a/8d/8a9a45c8b655851f216c1d44f68e3533dc8d2c752ccd0f61f1aa73be4893/psutil-7.1.1-cp37-abi3-win_arm64.whl", hash = "sha256:5457cf741ca13da54624126cd5d333871b454ab133999a9a103fb097a7d7d21a", size = 243944, upload-time = "2025-10-19T15:44:20.666Z" }, ] [[package]] @@ -1437,7 +1437,7 @@ wheels = [ [[package]] name = "pydantic" -version = "2.12.2" +version = "2.12.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-types" }, @@ -1445,9 +1445,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/8d/35/d319ed522433215526689bad428a94058b6dd12190ce7ddd78618ac14b28/pydantic-2.12.2.tar.gz", hash = "sha256:7b8fa15b831a4bbde9d5b84028641ac3080a4ca2cbd4a621a661687e741624fd", size = 816358, upload-time = "2025-10-14T15:02:21.842Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/1e/4f0a3233767010308f2fd6bd0814597e3f63f1dc98304a9112b8759df4ff/pydantic-2.12.3.tar.gz", hash = "sha256:1da1c82b0fc140bb0103bc1441ffe062154c8d38491189751ee00fd8ca65ce74", size = 819383, upload-time = "2025-10-17T15:04:21.222Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6c/98/468cb649f208a6f1279448e6e5247b37ae79cf5e4041186f1e2ef3d16345/pydantic-2.12.2-py3-none-any.whl", hash = "sha256:25ff718ee909acd82f1ff9b1a4acfd781bb23ab3739adaa7144f19a6a4e231ae", size = 460628, upload-time = "2025-10-14T15:02:19.623Z" }, + { url = "https://files.pythonhosted.org/packages/a1/6b/83661fa77dcefa195ad5f8cd9af3d1a7450fd57cc883ad04d65446ac2029/pydantic-2.12.3-py3-none-any.whl", hash = "sha256:6986454a854bc3bc6e5443e1369e06a3a456af9d339eda45510f517d9ea5c6bf", size = 462431, upload-time = "2025-10-17T15:04:19.346Z" }, ] [[package]] @@ -1817,15 +1817,15 @@ wheels = [ [[package]] name = "sentry-sdk" -version = "2.42.0" +version = "2.42.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/b2/7481156cf42b7f66cffb371e504b7ace12b4f016b8872ffcf0873ae9534b/sentry_sdk-2.42.0.tar.gz", hash = "sha256:91c69c9372fb5fb4df0ac39456ccf7286f0428b3ee1cdd389f9dd36c04e0f5c9", size = 351242, upload-time = "2025-10-15T07:41:15.577Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/04/ec8c1dd9250847303d98516e917978cb1c7083024770d86d657d2ccb5a70/sentry_sdk-2.42.1.tar.gz", hash = "sha256:8598cc6edcfe74cb8074ba6a7c15338cdee93d63d3eb9b9943b4b568354ad5b6", size = 354839, upload-time = "2025-10-20T12:38:40.45Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/58/4a/9810a246ec5d1df2ae066efefeecfa91d3c548fa2bd5390184e016112887/sentry_sdk-2.42.0-py2.py3-none-any.whl", hash = "sha256:1a7986e638306ff158f52dd47d9480a4055e6c289388caa90628acb2563fe7bd", size = 379496, upload-time = "2025-10-15T07:41:13.802Z" }, + { url = "https://files.pythonhosted.org/packages/0f/cb/c21b96ff379923310b4fb2c06e8d560d801e24aeb300faa72a04776868fc/sentry_sdk-2.42.1-py2.py3-none-any.whl", hash = "sha256:f8716b50c927d3beb41bc88439dc6bcd872237b596df5b14613e2ade104aee02", size = 380952, upload-time = "2025-10-20T12:38:38.88Z" }, ] [[package]] From 3a206ed64c65c7854157d0cc1126929c9b767c28 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:23:25 +0100 Subject: [PATCH 22/77] allow specifying either config path or mrc cfg in pipeline cfg --- spd/clustering/merge_run_config.py | 19 +++++++ spd/clustering/scripts/run_pipeline.py | 74 ++++++++++++++++++++++++-- 2 files changed, 90 insertions(+), 3 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 60a5244d6..6671127b9 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -1,5 +1,8 @@ """ClusteringRunConfig""" +import base64 +import hashlib +import json from pathlib import Path from typing import Any, Self @@ -127,3 +130,19 @@ def model_dump_with_properties(self) -> dict[str, Any]: ) return base_dump + + def stable_hash_b64(self) -> str: + """Generate a stable, deterministic base64-encoded hash of this config. + + Uses SHA256 hash of the JSON representation with sorted keys for determinism. + Returns URL-safe base64 encoding without padding. 
+ + Returns: + URL-safe base64-encoded hash (without padding) + """ + config_dict: dict[str, Any] = self.model_dump(mode="json") + config_json: str = json.dumps(config_dict, indent=2, sort_keys=True) + hash_digest: bytes = hashlib.sha256(config_json.encode()).digest() + # Use base64 URL-safe encoding and strip padding for filesystem safety + hash_b64: str = base64.urlsafe_b64encode(hash_digest).decode().rstrip("=") + return hash_b64 diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index cde83ffa1..34a7ef2f8 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -25,12 +25,14 @@ from typing import Any import wandb_workspaces.workspaces as ws -from pydantic import Field, PositiveInt, field_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.consts import DistancesMethod +from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.storage import StorageBase from spd.log import logger +from spd.settings import SPD_CACHE_DIR from spd.utils.command_utils import run_script_array_local from spd.utils.general_utils import replace_pydantic_model from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str @@ -69,7 +71,14 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") + run_clustering_config_path: Path | None = Field( + default=None, + description="Path to ClusteringRunConfig file. Mutually exclusive with run_clustering_config.", + ) + run_clustering_config: ClusteringRunConfig | None = Field( + default=None, + description="Inline ClusteringRunConfig. 
Mutually exclusive with run_clustering_config_path.", + ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") distances_methods: list[DistancesMethod] = Field( description="List of method(s) to use for calculating distances" @@ -84,6 +93,25 @@ class ClusteringPipelineConfig(BaseConfig): wandb_entity: str = Field(description="WandB entity (team/user) name") create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + @model_validator(mode="after") + def validate_config_fields(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" + has_path: bool = self.run_clustering_config_path is not None + has_inline: bool = self.run_clustering_config is not None + + if not has_path and not has_inline: + raise ValueError( + "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" + ) + + if has_path and has_inline: + raise ValueError( + "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " + "Use only one." + ) + + return self + @field_validator("distances_methods") @classmethod def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesMethod]: @@ -94,6 +122,46 @@ def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesM return v + def get_config_path(self) -> Path: + """Get the path to the ClusteringRunConfig file. + + - If run_clustering_config_path is provided, returns it directly. + - If run_clustering_config is provided, caches it to a deterministic path + based on its content hash and returns that path. + - if the config file already exists in the cache, assert that it is identical. 
+ + Returns: + Path to the (potentially newly created) ClusteringRunConfig file + """ + if self.run_clustering_config_path is not None: + return self.run_clustering_config_path + + assert self.run_clustering_config is not None, ( + "Either run_clustering_config_path or run_clustering_config must be set" + ) + + # Generate deterministic hash from config + hash_b64: str = self.run_clustering_config.stable_hash_b64() + + # Create cache directory + cache_dir: Path = SPD_CACHE_DIR / "merge_run_configs" + cache_dir.mkdir(parents=True, exist_ok=True) + + # Write config to cache if it doesn't exist + config_path: Path = cache_dir / f"{hash_b64}.json" + if not config_path.exists(): + self.run_clustering_config.to_file(config_path) + logger.info(f"Cached inline config to {config_path}") + else: + # Verify that existing file matches + existing_config = ClusteringRunConfig.from_file(config_path) + if existing_config != self.run_clustering_config: + raise ValueError( + f"Hash collision detected for config hash {hash_b64} at {config_path}\n{existing_config=}\n{self.run_clustering_config=}" + ) + + return config_path + def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: """Create WandB workspace view for clustering runs. 
@@ -148,7 +216,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.run_clustering_config_path.as_posix(), + pipeline_config.get_config_path().as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", From eb831c0eff389fea68a836a59db8ea5160ef454a Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:26:56 +0100 Subject: [PATCH 23/77] [wip] reorg configs --- spd/clustering/configs/{ => mrc}/example.yaml | 0 spd/clustering/configs/{ => mrc}/resid_mlp1.json | 0 spd/clustering/configs/{ => mrc}/resid_mlp2.json | 0 spd/clustering/configs/{ => mrc}/resid_mlp3.json | 0 spd/clustering/configs/{ => mrc}/simplestories_dev.json | 0 spd/clustering/configs/{ => mrc}/test-resid_mlp1.json | 0 spd/clustering/configs/{ => mrc}/test-simplestories.json | 0 7 files changed, 0 insertions(+), 0 deletions(-) rename spd/clustering/configs/{ => mrc}/example.yaml (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp1.json (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp2.json (100%) rename spd/clustering/configs/{ => mrc}/resid_mlp3.json (100%) rename spd/clustering/configs/{ => mrc}/simplestories_dev.json (100%) rename spd/clustering/configs/{ => mrc}/test-resid_mlp1.json (100%) rename spd/clustering/configs/{ => mrc}/test-simplestories.json (100%) diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/mrc/example.yaml similarity index 100% rename from spd/clustering/configs/example.yaml rename to spd/clustering/configs/mrc/example.yaml diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/mrc/resid_mlp1.json similarity index 100% rename from spd/clustering/configs/resid_mlp1.json rename to spd/clustering/configs/mrc/resid_mlp1.json diff --git a/spd/clustering/configs/resid_mlp2.json b/spd/clustering/configs/mrc/resid_mlp2.json similarity index 100% rename from spd/clustering/configs/resid_mlp2.json rename to 
spd/clustering/configs/mrc/resid_mlp2.json diff --git a/spd/clustering/configs/resid_mlp3.json b/spd/clustering/configs/mrc/resid_mlp3.json similarity index 100% rename from spd/clustering/configs/resid_mlp3.json rename to spd/clustering/configs/mrc/resid_mlp3.json diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/mrc/simplestories_dev.json similarity index 100% rename from spd/clustering/configs/simplestories_dev.json rename to spd/clustering/configs/mrc/simplestories_dev.json diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/mrc/test-resid_mlp1.json similarity index 100% rename from spd/clustering/configs/test-resid_mlp1.json rename to spd/clustering/configs/mrc/test-resid_mlp1.json diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/mrc/test-simplestories.json similarity index 100% rename from spd/clustering/configs/test-simplestories.json rename to spd/clustering/configs/mrc/test-simplestories.json From 89e5c36aef43139e7c19364f85a39f0d1a7b1ea0 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:52:03 +0100 Subject: [PATCH 24/77] added default `None` for slurm partition and job name prefix --- spd/clustering/scripts/run_pipeline.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 34a7ef2f8..d970cc31d 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -84,14 +84,18 @@ class ClusteringPipelineConfig(BaseConfig): description="List of method(s) to use for calculating distances" ) base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") - slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") - slurm_partition: str | None = Field(description="SLURM partition to use") + slurm_job_name_prefix: str | None = Field( + 
default=None, description="Prefix for SLURM job names" + ) + slurm_partition: str | None = Field(default=None, description="SLURM partition to use") wandb_project: str | None = Field( default=None, description="Weights & Biases project name (set to None to disable WandB logging)", ) - wandb_entity: str = Field(description="WandB entity (team/user) name") - create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + create_git_snapshot: bool = Field( + default=False, description="Create a git snapshot for the run" + ) @model_validator(mode="after") def validate_config_fields(self) -> "ClusteringPipelineConfig": From 8910bb479081dd8be312707b88fed20f5b18b69d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 11:54:33 +0100 Subject: [PATCH 25/77] refactor configs, add config tests --- spd/clustering/configs/README.md | 1 + spd/clustering/configs/mrc/resid_mlp1.json | 5 +- spd/clustering/configs/mrc/resid_mlp2.json | 4 +- spd/clustering/configs/mrc/resid_mlp3.json | 23 - .../configs/pipeline-dev-simplestories.yaml | 2 +- .../configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- tests/clustering/test_pipeline_config.py | 473 ++++++++++++++++++ 9 files changed, 480 insertions(+), 34 deletions(-) create mode 100644 spd/clustering/configs/README.md delete mode 100644 spd/clustering/configs/mrc/resid_mlp3.json create mode 100644 tests/clustering/test_pipeline_config.py diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md new file mode 100644 index 000000000..ed6efe090 --- /dev/null +++ b/spd/clustering/configs/README.md @@ -0,0 +1 @@ +this folder `configs/` contains files with `ClusteringPipelineConfig`s. 
the folder `configs/mrc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/mrc/resid_mlp1.json b/spd/clustering/configs/mrc/resid_mlp1.json index a7d118ac7..506717282 100644 --- a/spd/clustering/configs/mrc/resid_mlp1.json +++ b/spd/clustering/configs/mrc/resid_mlp1.json @@ -10,12 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp1", - "distances_methods": ["perm_invariant_hamming"], - "n_batches": 8, "batch_size": 128, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/mrc/resid_mlp2.json b/spd/clustering/configs/mrc/resid_mlp2.json index 2be350979..af645f3bd 100644 --- a/spd/clustering/configs/mrc/resid_mlp2.json +++ b/spd/clustering/configs/mrc/resid_mlp2.json @@ -10,11 +10,9 @@ "module_name_filter": null }, "experiment_key": "resid_mlp2", - "n_batches": 16, "batch_size": 1024, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/mrc/resid_mlp3.json b/spd/clustering/configs/mrc/resid_mlp3.json deleted file mode 100644 index 5d87e08d5..000000000 --- a/spd/clustering/configs/mrc/resid_mlp3.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "merge_config": { - "activation_threshold": 0.01, - "alpha": 1, - "iters": 350, - "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, - "filter_dead_threshold": 0.01, - "module_name_filter": null - }, - "experiment_key": "resid_mlp3", - "n_batches": 4, - "batch_size": 1024, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { - "stat": 1, - "tensor": 32, - "plot": 32, - "artifact": 32 - } -} \ No newline at end of file diff --git 
a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 6909c5841..7eef9cfc9 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/simplestories_dev.json" +run_clustering_config_path: "spd/clustering/configs/mrc/simplestories_dev.json" n_runs: 4 distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index a413a5438..a3c02da5e 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-resid_mlp1.json" +run_clustering_config_path: "spd/clustering/configs/mrc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index e406628c4..c98895ab4 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-simplestories.json" +run_clustering_config_path: "spd/clustering/configs/mrc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 6a40c9b29..42db7ac84 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/example.yaml" 
+run_clustering_config_path: "spd/clustering/configs/mrc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py new file mode 100644 index 000000000..f64ff15c6 --- /dev/null +++ b/tests/clustering/test_pipeline_config.py @@ -0,0 +1,473 @@ +"""Tests for ClusteringPipelineConfig and ClusteringRunConfig with inline config support.""" + +from pathlib import Path + +import pytest + +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_run_config import ClusteringRunConfig +from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig +from spd.settings import REPO_ROOT, SPD_CACHE_DIR + + +class TestClusteringRunConfigStableHash: + """Test ClusteringRunConfig.stable_hash_b64() method.""" + + def test_deterministic_hash(self): + """Test that stable_hash_b64 is deterministic for identical configs.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + + assert hash1 == hash2 + assert isinstance(hash1, str) + assert len(hash1) > 0 + + def test_different_configs_different_hashes(self): + """Test that different configs produce different hashes.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different model_path + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + config3 = ClusteringRunConfig( + 
model_path="wandb:test/project/run1", + batch_size=64, # Different batch_size + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + hash3 = config3.stable_hash_b64() + + assert hash1 != hash2 + assert hash1 != hash3 + assert hash2 != hash3 + + def test_hash_is_url_safe(self): + """Test that hash is URL-safe base64 (no padding, URL-safe chars).""" + config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + hash_value = config.stable_hash_b64() + + # Should not contain padding + assert "=" not in hash_value + + # Should only contain URL-safe base64 characters + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in hash_value) + + def test_nested_config_changes_hash(self): + """Test that changes in nested merge_config affect the hash.""" + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(activation_threshold=0.1), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(activation_threshold=0.2), # Different threshold + ) + + assert config1.stable_hash_b64() != config2.stable_hash_b64() + + +class TestClusteringPipelineConfigValidation: + """Test ClusteringPipelineConfig validation logic.""" + + def test_error_when_neither_field_provided(self): + """Test that error is raised when neither path nor inline config is provided.""" + with pytest.raises(ValueError, match="Must specify exactly one"): + ClusteringPipelineConfig( + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + 
create_git_snapshot=False, + ) + + def test_error_when_both_fields_provided(self): + """Test that error is raised when both path and inline config are provided.""" + with pytest.raises(ValueError, match="Cannot specify both"): + ClusteringPipelineConfig( + run_clustering_config_path=Path("some/path.json"), + run_clustering_config=ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + merge_config=MergeConfig(), + dataset_seed=0, + idx_in_ensemble=0, + ), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + def test_success_with_only_path(self): + """Test that config validates successfully with only path provided.""" + config = ClusteringPipelineConfig( + run_clustering_config_path=Path("some/path.json"), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.run_clustering_config_path == Path("some/path.json") + assert config.run_clustering_config is None + + def test_success_with_only_inline_config(self): + """Test that config validates successfully with only inline config provided.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.run_clustering_config_path is None + assert config.run_clustering_config == inline_config + + +class TestClusteringPipelineConfigGetConfigPath: + """Test ClusteringPipelineConfig.get_config_path() method.""" + + def test_returns_path_directly_when_using_path_field(self): + 
"""Test that get_config_path returns the path directly when using run_clustering_config_path.""" + expected_path = Path("some/path.json") + + config = ClusteringPipelineConfig( + run_clustering_config_path=expected_path, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.get_config_path() == expected_path + + def test_creates_cached_file_when_using_inline_config(self): + """Test that get_config_path creates a cached file when using inline config.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + config_path = config.get_config_path() + + # Check that file exists + assert config_path.exists() + + # Check that it's in the expected directory + expected_cache_dir = SPD_CACHE_DIR / "merge_run_configs" + assert config_path.parent == expected_cache_dir + + # Check that filename is the hash + expected_hash = inline_config.stable_hash_b64() + assert config_path.name == f"{expected_hash}.json" + + # Check that file contents match the config + loaded_config = ClusteringRunConfig.from_file(config_path) + assert loaded_config == inline_config + + # Clean up + config_path.unlink() + + def test_reuses_existing_cached_file(self): + """Test that get_config_path reuses existing cached file with same hash.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config1 = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + 
base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + # First call creates the file + config_path1 = config1.get_config_path() + assert config_path1.exists() + + # Record modification time + mtime1 = config_path1.stat().st_mtime + + # Create another config with same inline config + config2 = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=3, # Different n_runs shouldn't matter + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + # Second call should reuse the file + config_path2 = config2.get_config_path() + + assert config_path1 == config_path2 + assert config_path2.stat().st_mtime == mtime1 # File not modified + + # Clean up + config_path1.unlink() + + def test_hash_collision_detection(self): + """Test that hash collision is detected when existing file differs.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + # Create a fake collision by manually creating a file with same hash + hash_value = inline_config.stable_hash_b64() + cache_dir = SPD_CACHE_DIR / "merge_run_configs" + cache_dir.mkdir(parents=True, exist_ok=True) + collision_path = cache_dir / f"{hash_value}.json" + + # Write a different config to the file + different_config = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different! 
+ batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + different_config.to_file(collision_path) + + try: + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + # Should raise ValueError about hash collision + with pytest.raises(ValueError, match="Hash collision detected"): + config.get_config_path() + finally: + # Clean up + if collision_path.exists(): + collision_path.unlink() + + def test_cache_directory_created_if_not_exists(self): + """Test that cache directory is created if it doesn't exist.""" + inline_config = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + idx_in_ensemble=0, + merge_config=MergeConfig(), + ) + + config = ClusteringPipelineConfig( + run_clustering_config=inline_config, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + cache_dir = SPD_CACHE_DIR / "merge_run_configs" + + # Even if cache dir doesn't exist, get_config_path should create it + config_path = config.get_config_path() + + assert cache_dir.exists() + assert config_path.exists() + + # Clean up + config_path.unlink() + + +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" + + def test_all_pipeline_configs_valid(self): + """Test that all pipeline config files are valid.""" + configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" + + # Find all YAML/YML files in the configs directory (not subdirectories) + pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) + + # Should have at least some configs + assert len(pipeline_config_files) > 0, "No pipeline config files found" + + 
errors: list[tuple[Path, Exception]] = [] + + for config_file in pipeline_config_files: + try: + config = ClusteringPipelineConfig.from_file(config_file) + # Basic sanity checks + assert config.n_runs > 0 + assert len(config.distances_methods) > 0 + assert config.wandb_entity is not None + except Exception as e: + errors.append((config_file, e)) + + # Report all errors at once + if errors: + error_msg = "Failed to validate pipeline configs:\n" + for path, exc in errors: + error_msg += f" - {path.name}: {exc}\n" + pytest.fail(error_msg) + + def test_all_merge_run_configs_valid(self): + """Test that all merge run config files are valid.""" + mrc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "mrc" + + # Find all JSON/YAML/YML files in the mrc directory + mrc_files = ( + list(mrc_dir.glob("*.json")) + + list(mrc_dir.glob("*.yaml")) + + list(mrc_dir.glob("*.yml")) + ) + + # Should have at least some configs + assert len(mrc_files) > 0, "No merge run config files found" + + errors: list[tuple[Path, Exception]] = [] + + for config_file in mrc_files: + try: + config = ClusteringRunConfig.from_file(config_file) + # Basic sanity checks + assert config.batch_size > 0 + assert config.model_path.startswith("wandb:") + assert config.merge_config is not None + except Exception as e: + errors.append((config_file, e)) + + # Report all errors at once + if errors: + error_msg = "Failed to validate merge run configs:\n" + for path, exc in errors: + error_msg += f" - {path.name}: {exc}\n" + pytest.fail(error_msg) + + def test_pipeline_configs_reference_valid_mrc_files(self): + """Test that pipeline configs reference merge run config files that exist.""" + configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" + pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) + + errors: list[tuple[Path, str]] = [] + + for config_file in pipeline_config_files: + try: + config = ClusteringPipelineConfig.from_file(config_file) + + # Skip configs that 
use inline config + if config.run_clustering_config is not None: + continue + + # Check that referenced file exists + assert config.run_clustering_config_path is not None + mrc_path = REPO_ROOT / config.run_clustering_config_path + + if not mrc_path.exists(): + errors.append( + ( + config_file, + f"References non-existent file: {config.run_clustering_config_path}", + ) + ) + else: + # Try to load the referenced config + try: + ClusteringRunConfig.from_file(mrc_path) + except Exception as e: + errors.append( + ( + config_file, + f"Referenced file {mrc_path.name} is invalid: {e}", + ) + ) + except Exception as e: + errors.append((config_file, f"Failed to load pipeline config: {e}")) + + if errors: + error_msg = "Pipeline configs with invalid merge run config references:\n" + for path, msg in errors: + error_msg += f" - {path.name}: {msg}\n" + pytest.fail(error_msg) From 0b957f5fbf1ef97102c8aef2a8c83816ed9b4635 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:09:40 +0100 Subject: [PATCH 26/77] fix tests --- spd/clustering/merge_run_config.py | 2 + spd/clustering/scripts/run_pipeline.py | 23 ++- tests/clustering/test_pipeline_config.py | 209 +++-------------------- 3 files changed, 45 insertions(+), 189 deletions(-) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 6671127b9..19450ff66 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -54,6 +54,8 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) + # TODO: allow idx_in_ensemble to be `None` if ensemble_id is `None`? + # TODO: allow idx_in_ensemble to be auto-assigned by reading from db if -1? 
idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") merge_config: MergeConfig = Field(description="Merge algorithm configuration") diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index d970cc31d..334f5b418 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -98,7 +98,7 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_config_fields(self) -> "ClusteringPipelineConfig": + def validate_mrc_fields(self) -> "ClusteringPipelineConfig": """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" has_path: bool = self.run_clustering_config_path is not None has_inline: bool = self.run_clustering_config is not None @@ -108,11 +108,19 @@ def validate_config_fields(self) -> "ClusteringPipelineConfig": "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" ) - if has_path and has_inline: - raise ValueError( - "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " - "Use only one." - ) + if has_path: + if has_inline: + raise ValueError( + "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " + "Use only one." 
+ ) + else: + # Ensure the path exists + # pyright ignore because it doesn't recognize that has_path implies not None + if not self.run_clustering_config_path.exists(): # pyright: ignore[reportOptionalMemberAccess] + raise ValueError( + f"run_clustering_config_path does not exist: {self.run_clustering_config_path = }" + ) return self @@ -138,6 +146,9 @@ def get_config_path(self) -> Path: Path to the (potentially newly created) ClusteringRunConfig file """ if self.run_clustering_config_path is not None: + assert self.run_clustering_config_path.exists(), ( + f"no file at run_clustering_config_path: {self.run_clustering_config_path = }" + ) return self.run_clustering_config_path assert self.run_clustering_config is not None, ( diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index f64ff15c6..826a7fc28 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -13,8 +13,9 @@ class TestClusteringRunConfigStableHash: """Test ClusteringRunConfig.stable_hash_b64() method.""" - def test_deterministic_hash(self): - """Test that stable_hash_b64 is deterministic for identical configs.""" + def test_stable_hash_b64(self): + """Test that stable_hash_b64 is deterministic, unique, and URL-safe.""" + # Create 4 configs: 2 identical, 2 different config1 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, @@ -29,83 +30,44 @@ def test_deterministic_hash(self): idx_in_ensemble=0, merge_config=MergeConfig(), ) - - hash1 = config1.stable_hash_b64() - hash2 = config2.stable_hash_b64() - - assert hash1 == hash2 - assert isinstance(hash1, str) - assert len(hash1) > 0 - - def test_different_configs_different_hashes(self): - """Test that different configs produce different hashes.""" - config1 = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - config2 = ClusteringRunConfig( + 
config3 = ClusteringRunConfig( model_path="wandb:test/project/run2", # Different model_path batch_size=32, dataset_seed=0, idx_in_ensemble=0, merge_config=MergeConfig(), ) - config3 = ClusteringRunConfig( + config4 = ClusteringRunConfig( model_path="wandb:test/project/run1", - batch_size=64, # Different batch_size + batch_size=32, dataset_seed=0, idx_in_ensemble=0, - merge_config=MergeConfig(), + merge_config=MergeConfig( + activation_threshold=0.2 + ), # Different merge_config to test nested fields ) hash1 = config1.stable_hash_b64() hash2 = config2.stable_hash_b64() hash3 = config3.stable_hash_b64() + hash4 = config4.stable_hash_b64() - assert hash1 != hash2 - assert hash1 != hash3 - assert hash2 != hash3 - - def test_hash_is_url_safe(self): - """Test that hash is URL-safe base64 (no padding, URL-safe chars).""" - config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) + # Identical configs produce identical hashes + assert hash1 == hash2 - hash_value = config.stable_hash_b64() + # Different configs produce different hashes + assert hash1 != hash3 + assert hash1 != hash4 + assert hash3 != hash4 - # Should not contain padding - assert "=" not in hash_value + # Hashes are strings + assert isinstance(hash1, str) + assert len(hash1) > 0 - # Should only contain URL-safe base64 characters + # Hashes are URL-safe base64 (no padding, URL-safe chars only) + assert "=" not in hash1 valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") - assert all(c in valid_chars for c in hash_value) - - def test_nested_config_changes_hash(self): - """Test that changes in nested merge_config affect the hash.""" - config1 = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(activation_threshold=0.1), - ) - config2 = ClusteringRunConfig( - 
model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(activation_threshold=0.2), # Different threshold - ) - - assert config1.stable_hash_b64() != config2.stable_hash_b64() + assert all(c in valid_chars for c in hash1) class TestClusteringPipelineConfigValidation: @@ -145,49 +107,13 @@ def test_error_when_both_fields_provided(self): create_git_snapshot=False, ) - def test_success_with_only_path(self): - """Test that config validates successfully with only path provided.""" - config = ClusteringPipelineConfig( - run_clustering_config_path=Path("some/path.json"), - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - assert config.run_clustering_config_path == Path("some/path.json") - assert config.run_clustering_config is None - - def test_success_with_only_inline_config(self): - """Test that config validates successfully with only inline config provided.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - assert config.run_clustering_config_path is None - assert config.run_clustering_config == inline_config - class TestClusteringPipelineConfigGetConfigPath: """Test ClusteringPipelineConfig.get_config_path() method.""" def test_returns_path_directly_when_using_path_field(self): """Test that get_config_path returns the path directly when using run_clustering_config_path.""" - expected_path = Path("some/path.json") + expected_path = Path("spd/clustering/configs/mrc/resid_mlp1.json") config = ClusteringPipelineConfig( 
run_clustering_config_path=expected_path, @@ -330,36 +256,6 @@ def test_hash_collision_detection(self): if collision_path.exists(): collision_path.unlink() - def test_cache_directory_created_if_not_exists(self): - """Test that cache directory is created if it doesn't exist.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - idx_in_ensemble=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - cache_dir = SPD_CACHE_DIR / "merge_run_configs" - - # Even if cache dir doesn't exist, get_config_path should create it - config_path = config.get_config_path() - - assert cache_dir.exists() - assert config_path.exists() - - # Clean up - config_path.unlink() - class TestAllConfigsValidation: """Test that all existing config files can be loaded and validated.""" @@ -378,11 +274,8 @@ def test_all_pipeline_configs_valid(self): for config_file in pipeline_config_files: try: - config = ClusteringPipelineConfig.from_file(config_file) - # Basic sanity checks - assert config.n_runs > 0 - assert len(config.distances_methods) > 0 - assert config.wandb_entity is not None + _config = ClusteringPipelineConfig.from_file(config_file) + assert _config.get_config_path().exists() except Exception as e: errors.append((config_file, e)) @@ -411,11 +304,7 @@ def test_all_merge_run_configs_valid(self): for config_file in mrc_files: try: - config = ClusteringRunConfig.from_file(config_file) - # Basic sanity checks - assert config.batch_size > 0 - assert config.model_path.startswith("wandb:") - assert config.merge_config is not None + _config = ClusteringRunConfig.from_file(config_file) except Exception as e: errors.append((config_file, e)) @@ -425,49 +314,3 @@ def test_all_merge_run_configs_valid(self): for path, exc in 
errors: error_msg += f" - {path.name}: {exc}\n" pytest.fail(error_msg) - - def test_pipeline_configs_reference_valid_mrc_files(self): - """Test that pipeline configs reference merge run config files that exist.""" - configs_dir = REPO_ROOT / "spd" / "clustering" / "configs" - pipeline_config_files = list(configs_dir.glob("*.yaml")) + list(configs_dir.glob("*.yml")) - - errors: list[tuple[Path, str]] = [] - - for config_file in pipeline_config_files: - try: - config = ClusteringPipelineConfig.from_file(config_file) - - # Skip configs that use inline config - if config.run_clustering_config is not None: - continue - - # Check that referenced file exists - assert config.run_clustering_config_path is not None - mrc_path = REPO_ROOT / config.run_clustering_config_path - - if not mrc_path.exists(): - errors.append( - ( - config_file, - f"References non-existent file: {config.run_clustering_config_path}", - ) - ) - else: - # Try to load the referenced config - try: - ClusteringRunConfig.from_file(mrc_path) - except Exception as e: - errors.append( - ( - config_file, - f"Referenced file {mrc_path.name} is invalid: {e}", - ) - ) - except Exception as e: - errors.append((config_file, f"Failed to load pipeline config: {e}")) - - if errors: - error_msg = "Pipeline configs with invalid merge run config references:\n" - for path, msg in errors: - error_msg += f" - {path.name}: {msg}\n" - pytest.fail(error_msg) From 7de545b1e5502146623c83a16fed704d0eeff007 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:33:04 +0100 Subject: [PATCH 27/77] allow `None` or `-1` idx_in_ensemble - idx_in_ensemble is None iff ensemble_id is None - idx_in_ensemble == -1 will make register_clustering_run() auto-assign next avalible index - added tests for ensemble registry --- spd/clustering/ensemble_registry.py | 28 ++- spd/clustering/merge_run_config.py | 41 ++-- spd/clustering/scripts/run_clustering.py | 22 ++- tests/clustering/test_ensemble_registry.py | 215 
+++++++++++++++++++++ tests/clustering/test_pipeline_config.py | 9 - 5 files changed, 278 insertions(+), 37 deletions(-) create mode 100644 tests/clustering/test_ensemble_registry.py diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index 7756877d8..b3b1711ab 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,6 +6,7 @@ import sqlite3 from contextlib import contextmanager +from spd.clustering.merge_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path @@ -39,21 +40,42 @@ def _get_connection(): conn.close() -def register_clustering_run(pipeline_run_id: str, idx: int, clustering_run_id: str) -> None: +def register_clustering_run( + pipeline_run_id: str, idx: ClusteringEnsembleIndex, clustering_run_id: str +) -> int: """Register a clustering run as part of a pipeline ensemble. Args: pipeline_run_id: The ensemble/pipeline run ID - idx: Index of this run in the ensemble + idx: Index of this run in the ensemble. If -1, auto-assigns the next available index. 
clustering_run_id: The individual clustering run ID + + Returns: + The index assigned to this run (either the provided idx or the auto-assigned one) """ with _get_connection() as conn: + # Use BEGIN IMMEDIATE for thread-safe auto-increment + conn.execute("BEGIN IMMEDIATE") + + assigned_idx: int + if idx == -1: + # Auto-assign next available index + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx = cursor.fetchone()[0] + else: + assigned_idx = idx + conn.execute( "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", - (pipeline_run_id, idx, clustering_run_id), + (pipeline_run_id, assigned_idx, clustering_run_id), ) conn.commit() + return assigned_idx + def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: """Get all clustering runs for a pipeline ensemble. diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/merge_run_config.py index 19450ff66..f82e00203 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/merge_run_config.py @@ -4,9 +4,9 @@ import hashlib import json from pathlib import Path -from typing import Any, Self +from typing import Any, Literal, Self -from pydantic import Field, PositiveInt, model_validator +from pydantic import Field, NonNegativeInt, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.merge_config import MergeConfig @@ -31,6 +31,10 @@ class LoggingIntervals(BaseConfig): ) +ClusteringEnsembleIndex = NonNegativeInt | Literal[-1] +"index in an ensemble; -1 will cause register_clustering_run() to auto-assign the next available index" + + class ClusteringRunConfig(BaseConfig): """Configuration for a single clustering run. 
@@ -54,9 +58,11 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) - # TODO: allow idx_in_ensemble to be `None` if ensemble_id is `None`? - # TODO: allow idx_in_ensemble to be auto-assigned by reading from db if -1? - idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") + # TODO: given our use of `register_clustering_run()` and the atomic guarantees of that, do we even need this index? + # probably still nice to have for clarity + idx_in_ensemble: ClusteringEnsembleIndex | None = Field( + default=None, description="Index of this run in the ensemble" + ) merge_config: MergeConfig = Field(description="Merge algorithm configuration") logging_intervals: LoggingIntervals = Field( @@ -74,16 +80,6 @@ class ClusteringRunConfig(BaseConfig): description="Whether to use streaming dataset loading (if supported by the dataset). see https://github.com/goodfire-ai/spd/pull/199", ) - # TODO: no way to check this without knowing task - # @model_validator(mode="after") - # def validate_streaming_compatibility(self) -> Self: - # """Ensure dataset_streaming is only enabled for compatible tasks.""" - # if self.dataset_streaming and self.task_name != "lm": - # raise ValueError( - # f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" - # ) - # return self - @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: experiment_key: str | None = values.get("experiment_key") @@ -105,11 +101,18 @@ def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: return values - @model_validator(mode="after") - def validate_model_path(self) -> Self: + @field_validator("model_path") + def validate_model_path(cls, v: str) -> str: """Validate that model_path is a proper WandB path.""" - if not self.model_path.startswith("wandb:"): - raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}") + if not 
v.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {v}") + return v + + @model_validator(mode="after") + def validate_ensemble_id_index(self) -> Self: + assert (self.idx_in_ensemble is None) == (self.ensemble_id is None), ( + "If ensemble_id is None, idx_in_ensemble must also be None" + ) return self @property diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 7c614407a..04c48be7f 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -229,17 +229,29 @@ def main(run_config: ClusteringRunConfig) -> Path: # Register with ensemble if this is part of a pipeline if run_config.ensemble_id: assert run_config.idx_in_ensemble is not None, ( - "idx_in_ensemble must be set when ensemble_id is provided" + "idx_in_ensemble must be set when ensemble_id is provided! to auto-assign, set idx_in_ensemble = -1.\n" + f"{'!' * 50}\nNOTE: this should be an unreachable state -- such a case should have been caught by the pydantic validator.\n{'!' 
* 50}" ) - register_clustering_run( + assigned_idx: int = register_clustering_run( run_config.ensemble_id, run_config.idx_in_ensemble, clustering_run_id, ) + + # Update config if index was auto-assigned + if run_config.idx_in_ensemble == -1: + run_config = replace_pydantic_model(run_config, {"idx_in_ensemble": assigned_idx}) + logger.info(f"Auto-assigned ensemble index: {assigned_idx}") + logger.info( - f"Registered with pipeline {run_config.ensemble_id} at index {run_config.idx_in_ensemble} in {_ENSEMBLE_REGISTRY_DB}" + f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" ) + # save config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + + # start logger.info("Starting clustering run") logger.info(f"Output directory: {storage.base_dir}") device = get_device() @@ -347,9 +359,7 @@ def main(run_config: ClusteringRunConfig) -> Path: log_callback=log_callback, ) - # 8. Save merge history and config - run_config.to_file(storage.config_path) - logger.info(f"Config saved to {storage.config_path}") + # 8. 
Save merge history history.save(storage.history_path) logger.info(f"History saved to {storage.history_path}") diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py new file mode 100644 index 000000000..e71e8e228 --- /dev/null +++ b/tests/clustering/test_ensemble_registry.py @@ -0,0 +1,215 @@ +"""Tests for ensemble_registry module.""" + +import tempfile +from pathlib import Path +from typing import Any + +import pytest + +from spd.clustering.ensemble_registry import ( + get_clustering_runs, + register_clustering_run, +) + + +@pytest.fixture +def temp_registry_db(monkeypatch: Any): + """Create a temporary registry database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + temp_db_path = Path(tmpdir) / "test_registry.db" + monkeypatch.setattr("spd.clustering.ensemble_registry._ENSEMBLE_REGISTRY_DB", temp_db_path) + yield temp_db_path + + +class TestRegisterClusteringRun: + """Test register_clustering_run() function.""" + + def test_register_with_explicit_index(self, _temp_registry_db: Any): + """Test registering a run with an explicit index.""" + pipeline_id = "pipeline_001" + idx = 0 + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, idx, run_id) + + # Should return the same index + assert assigned_idx == idx + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_register_multiple_explicit_indices(self, _temp_registry_db: Any): + """Test registering multiple runs with explicit indices.""" + pipeline_id = "pipeline_002" + + idx0 = register_clustering_run(pipeline_id, 0, "run_001") + idx1 = register_clustering_run(pipeline_id, 1, "run_002") + idx2 = register_clustering_run(pipeline_id, 2, "run_003") + + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify order in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def 
test_auto_assign_single_index(self, _temp_registry_db: Any): + """Test auto-assigning a single index.""" + pipeline_id = "pipeline_003" + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, -1, run_id) + + # First auto-assigned index should be 0 + assert assigned_idx == 0 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): + """Test auto-assigning multiple indices sequentially.""" + pipeline_id = "pipeline_004" + + idx0 = register_clustering_run(pipeline_id, -1, "run_001") + idx1 = register_clustering_run(pipeline_id, -1, "run_002") + idx2 = register_clustering_run(pipeline_id, -1, "run_003") + + # Should auto-assign 0, 1, 2 + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def test_auto_assign_after_explicit_indices(self, _temp_registry_db: Any): + """Test that auto-assignment continues from max existing index.""" + pipeline_id = "pipeline_005" + + # Register explicit indices + register_clustering_run(pipeline_id, 0, "run_001") + register_clustering_run(pipeline_id, 1, "run_002") + + # Auto-assign should get index 2 + idx = register_clustering_run(pipeline_id, -1, "run_003") + assert idx == 2 + + # Auto-assign again should get index 3 + idx = register_clustering_run(pipeline_id, -1, "run_004") + assert idx == 3 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (1, "run_002"), + (2, "run_003"), + (3, "run_004"), + ] + + def test_auto_assign_with_gaps(self, _temp_registry_db: Any): + """Test that auto-assignment uses max+1, even with gaps.""" + pipeline_id = "pipeline_006" + + # Register with gaps: 0, 5, 10 + register_clustering_run(pipeline_id, 0, "run_001") + register_clustering_run(pipeline_id, 5, "run_002") + 
register_clustering_run(pipeline_id, 10, "run_003") + + # Auto-assign should get index 11 (max + 1) + idx = register_clustering_run(pipeline_id, -1, "run_004") + assert idx == 11 + + # Verify in database (ordered by idx) + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (5, "run_002"), + (10, "run_003"), + (11, "run_004"), + ] + + def test_mixed_explicit_and_auto_assign(self, _temp_registry_db: Any): + """Test mixing explicit and auto-assigned indices.""" + pipeline_id = "pipeline_007" + + # Mix of explicit and auto-assigned + idx0 = register_clustering_run(pipeline_id, -1, "run_001") # auto: 0 + idx1 = register_clustering_run(pipeline_id, 5, "run_002") # explicit: 5 + idx2 = register_clustering_run(pipeline_id, -1, "run_003") # auto: 6 + idx3 = register_clustering_run(pipeline_id, 2, "run_004") # explicit: 2 + idx4 = register_clustering_run(pipeline_id, -1, "run_005") # auto: 7 + + assert idx0 == 0 + assert idx1 == 5 + assert idx2 == 6 + assert idx3 == 2 + assert idx4 == 7 + + # Verify in database (ordered by idx) + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_001"), + (2, "run_004"), + (5, "run_002"), + (6, "run_003"), + (7, "run_005"), + ] + + def test_different_pipelines_independent(self, _temp_registry_db: Any): + """Test that different pipelines have independent index sequences.""" + pipeline_a = "pipeline_a" + pipeline_b = "pipeline_b" + + # Both should start at 0 when auto-assigning + idx_a0 = register_clustering_run(pipeline_a, -1, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, -1, "run_b1") + + assert idx_a0 == 0 + assert idx_b0 == 0 + + # Both should increment independently + idx_a1 = register_clustering_run(pipeline_a, -1, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, -1, "run_b2") + + assert idx_a1 == 1 + assert idx_b1 == 1 + + # Verify in database + runs_a = get_clustering_runs(pipeline_a) + runs_b = get_clustering_runs(pipeline_b) + + assert runs_a == [(0, "run_a1"), (1, 
"run_a2")] + assert runs_b == [(0, "run_b1"), (1, "run_b2")] + + +class TestGetClusteringRuns: + """Test get_clustering_runs() function.""" + + def test_get_empty_pipeline(self, _temp_registry_db: Any): + """Test getting runs from a pipeline that doesn't exist.""" + runs = get_clustering_runs("nonexistent_pipeline") + assert runs == [] + + def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): + """Test that runs are returned sorted by index.""" + pipeline_id = "pipeline_sort" + + # Register out of order + register_clustering_run(pipeline_id, 5, "run_005") + register_clustering_run(pipeline_id, 1, "run_001") + register_clustering_run(pipeline_id, 3, "run_003") + register_clustering_run(pipeline_id, 0, "run_000") + + # Should be returned in sorted order + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_000"), + (1, "run_001"), + (3, "run_003"), + (5, "run_005"), + ] diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 826a7fc28..010195694 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -20,28 +20,24 @@ def test_stable_hash_b64(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config2 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config3 = ClusteringRunConfig( model_path="wandb:test/project/run2", # Different model_path batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) config4 = ClusteringRunConfig( model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig( activation_threshold=0.2 ), # Different merge_config to test nested fields @@ -96,7 +92,6 @@ def test_error_when_both_fields_provided(self): batch_size=32, merge_config=MergeConfig(), dataset_seed=0, - idx_in_ensemble=0, 
), n_runs=2, distances_methods=["perm_invariant_hamming"], @@ -132,7 +127,6 @@ def test_creates_cached_file_when_using_inline_config(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -171,7 +165,6 @@ def test_reuses_existing_cached_file(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -216,7 +209,6 @@ def test_hash_collision_detection(self): model_path="wandb:test/project/run1", batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) @@ -231,7 +223,6 @@ def test_hash_collision_detection(self): model_path="wandb:test/project/run2", # Different! batch_size=32, dataset_seed=0, - idx_in_ensemble=0, merge_config=MergeConfig(), ) different_config.to_file(collision_path) From 3d45ac4f38ab1a15415384ba05955eb56aa4fa3f Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:40:14 +0100 Subject: [PATCH 28/77] whoops, wrong name on fixture --- tests/clustering/test_ensemble_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index e71e8e228..ff751d5c0 100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -13,7 +13,7 @@ @pytest.fixture -def temp_registry_db(monkeypatch: Any): +def _temp_registry_db(monkeypatch: Any): """Create a temporary registry database for testing.""" with tempfile.TemporaryDirectory() as tmpdir: temp_db_path = Path(tmpdir) / "test_registry.db" From 4adde100ca740571a13efd9309639c8958d09bcb Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:41:45 +0100 Subject: [PATCH 29/77] fix idx passed in tests when not needed --- tests/clustering/test_run_clustering_happy_path.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/clustering/test_run_clustering_happy_path.py 
b/tests/clustering/test_run_clustering_happy_path.py index 91a7cf2ad..12c12c8b0 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -16,7 +16,6 @@ def test_run_clustering_happy_path(): model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run batch_size=4, dataset_seed=0, - idx_in_ensemble=0, base_output_dir=Path(temp_dir), ensemble_id=None, merge_config=MergeConfig( From 189b64aee14189f5e2f566caf0bf3c1dc26799aa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:43:21 +0100 Subject: [PATCH 30/77] rename "mrc" -> "crc" in paths I forgot its no longer called "MergeRunConfig" --- spd/clustering/configs/README.md | 2 +- .../configs/{mrc => crc}/example.yaml | 0 .../configs/{mrc => crc}/resid_mlp1.json | 0 .../configs/{mrc => crc}/resid_mlp2.json | 0 .../{mrc => crc}/simplestories_dev.json | 0 .../configs/{mrc => crc}/test-resid_mlp1.json | 0 .../{mrc => crc}/test-simplestories.json | 0 .../configs/pipeline-dev-simplestories.yaml | 2 +- .../configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/test_pipeline_config.py | 18 +++++++++--------- 13 files changed, 15 insertions(+), 15 deletions(-) rename spd/clustering/configs/{mrc => crc}/example.yaml (100%) rename spd/clustering/configs/{mrc => crc}/resid_mlp1.json (100%) rename spd/clustering/configs/{mrc => crc}/resid_mlp2.json (100%) rename spd/clustering/configs/{mrc => crc}/simplestories_dev.json (100%) rename spd/clustering/configs/{mrc => crc}/test-resid_mlp1.json (100%) rename spd/clustering/configs/{mrc => crc}/test-simplestories.json (100%) diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md index ed6efe090..51db8e8a0 100644 --- a/spd/clustering/configs/README.md +++ b/spd/clustering/configs/README.md @@ -1 +1 @@ -this folder 
`configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/mrc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/mrc/example.yaml b/spd/clustering/configs/crc/example.yaml similarity index 100% rename from spd/clustering/configs/mrc/example.yaml rename to spd/clustering/configs/crc/example.yaml diff --git a/spd/clustering/configs/mrc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json similarity index 100% rename from spd/clustering/configs/mrc/resid_mlp1.json rename to spd/clustering/configs/crc/resid_mlp1.json diff --git a/spd/clustering/configs/mrc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json similarity index 100% rename from spd/clustering/configs/mrc/resid_mlp2.json rename to spd/clustering/configs/crc/resid_mlp2.json diff --git a/spd/clustering/configs/mrc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json similarity index 100% rename from spd/clustering/configs/mrc/simplestories_dev.json rename to spd/clustering/configs/crc/simplestories_dev.json diff --git a/spd/clustering/configs/mrc/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json similarity index 100% rename from spd/clustering/configs/mrc/test-resid_mlp1.json rename to spd/clustering/configs/crc/test-resid_mlp1.json diff --git a/spd/clustering/configs/mrc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json similarity index 100% rename from spd/clustering/configs/mrc/test-simplestories.json rename to spd/clustering/configs/crc/test-simplestories.json diff --git 
a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 7eef9cfc9..dc6e729d3 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/simplestories_dev.json" +run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" n_runs: 4 distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index a3c02da5e..db72fa3c0 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/test-resid_mlp1.json" +run_clustering_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index c98895ab4..24e686023 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/test-simplestories.json" +run_clustering_config_path: "spd/clustering/configs/crc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 42db7ac84..297b47d7b 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/mrc/example.yaml" 
+run_clustering_config_path: "spd/clustering/configs/crc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 334f5b418..7b04bcfc0 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -98,7 +98,7 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_mrc_fields(self) -> "ClusteringPipelineConfig": + def validate_crc_fields(self) -> "ClusteringPipelineConfig": """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" has_path: bool = self.run_clustering_config_path is not None has_inline: bool = self.run_clustering_config is not None diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 010195694..723192118 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -108,7 +108,7 @@ class TestClusteringPipelineConfigGetConfigPath: def test_returns_path_directly_when_using_path_field(self): """Test that get_config_path returns the path directly when using run_clustering_config_path.""" - expected_path = Path("spd/clustering/configs/mrc/resid_mlp1.json") + expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( run_clustering_config_path=expected_path, @@ -279,21 +279,21 @@ def test_all_pipeline_configs_valid(self): def test_all_merge_run_configs_valid(self): """Test that all merge run config files are valid.""" - mrc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "mrc" + crc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "crc" - # Find all JSON/YAML/YML files in the mrc directory - mrc_files = ( - list(mrc_dir.glob("*.json")) - + list(mrc_dir.glob("*.yaml")) - + list(mrc_dir.glob("*.yml")) + # Find all JSON/YAML/YML files in 
the crc directory + crc_files = ( + list(crc_dir.glob("*.json")) + + list(crc_dir.glob("*.yaml")) + + list(crc_dir.glob("*.yml")) ) # Should have at least some configs - assert len(mrc_files) > 0, "No merge run config files found" + assert len(crc_files) > 0, "No merge run config files found" errors: list[tuple[Path, Exception]] = [] - for config_file in mrc_files: + for config_file in crc_files: try: _config = ClusteringRunConfig.from_file(config_file) except Exception as e: From 57f445a1fb276a47652a6a070d1da68a4637cf03 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:44:25 +0100 Subject: [PATCH 31/77] rename merge_run_config.py -> clustering_run_config.py --- .../{merge_run_config.py => clustering_run_config.py} | 0 spd/clustering/ensemble_registry.py | 2 +- spd/clustering/scripts/run_clustering.py | 2 +- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/scripts/cluster_ss.py | 2 +- tests/clustering/test_pipeline_config.py | 2 +- tests/clustering/test_run_clustering_happy_path.py | 2 +- 7 files changed, 6 insertions(+), 6 deletions(-) rename spd/clustering/{merge_run_config.py => clustering_run_config.py} (100%) diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/clustering_run_config.py similarity index 100% rename from spd/clustering/merge_run_config.py rename to spd/clustering/clustering_run_config.py diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index b3b1711ab..540312d8e 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,7 +6,7 @@ import sqlite3 from contextlib import contextmanager -from spd.clustering.merge_run_config import ClusteringEnsembleIndex +from spd.clustering.clustering_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 04c48be7f..6b52a8bb3 100644 --- 
a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -31,6 +31,7 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( ActivationsTensor, BatchTensor, @@ -43,7 +44,6 @@ from spd.clustering.math.semilog import semilog from spd.clustering.merge import merge_iteration from spd.clustering.merge_history import MergeHistory -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 7b04bcfc0..5910599b7 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -28,8 +28,8 @@ from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import DistancesMethod -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.storage import StorageBase from spd.log import logger from spd.settings import SPD_CACHE_DIR diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index acb6f394e..173b8abe5 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,11 +16,11 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.dataset import load_dataset from spd.clustering.merge import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.merge_run_config 
import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution from spd.models.component_model import ComponentModel, SPDRunInfo diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 723192118..311981037 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -4,8 +4,8 @@ import pytest +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig from spd.settings import REPO_ROOT, SPD_CACHE_DIR diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 12c12c8b0..57bb5e1ff 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -3,8 +3,8 @@ import pytest +from spd.clustering.clustering_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.scripts.run_clustering import main From 91f53484f0fab58dca1b0f018fa09680ee6849fa Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:45:43 +0100 Subject: [PATCH 32/77] fix pyright --- tests/clustering/test_ensemble_registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index ff751d5c0..bb2936cfd 100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -13,7 +13,7 @@ @pytest.fixture -def _temp_registry_db(monkeypatch: Any): +def _temp_registry_db(monkeypatch: Any): # pyright: 
ignore[reportUnusedFunction] """Create a temporary registry database for testing.""" with tempfile.TemporaryDirectory() as tmpdir: temp_db_path = Path(tmpdir) / "test_registry.db" From 11e5501637625630fadb87ec7e67eadff53f6e3b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 12:50:22 +0100 Subject: [PATCH 33/77] fix idx_in_ensemble being passed in tests --- tests/clustering/scripts/cluster_ss.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 173b8abe5..3f5da34a0 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -51,7 +51,6 @@ model_path=MODEL_PATH, batch_size=2, dataset_seed=42, - idx_in_ensemble=0, dataset_streaming=True, # no effect since we do this manually ) From 1d96054af7ab1d86352acdfb81cac66eac42b801 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:17:17 +0100 Subject: [PATCH 34/77] rename cache dir 'merge_run_configs' -> 'clustering_run_configs' --- spd/clustering/scripts/run_pipeline.py | 2 +- tests/clustering/test_pipeline_config.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 5910599b7..cebc8fb06 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -159,7 +159,7 @@ def get_config_path(self) -> Path: hash_b64: str = self.run_clustering_config.stable_hash_b64() # Create cache directory - cache_dir: Path = SPD_CACHE_DIR / "merge_run_configs" + cache_dir: Path = SPD_CACHE_DIR / "clustering_run_configs" cache_dir.mkdir(parents=True, exist_ok=True) # Write config to cache if it doesn't exist diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 311981037..05dfa17b0 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ 
-145,7 +145,7 @@ def test_creates_cached_file_when_using_inline_config(self): assert config_path.exists() # Check that it's in the expected directory - expected_cache_dir = SPD_CACHE_DIR / "merge_run_configs" + expected_cache_dir = SPD_CACHE_DIR / "clustering_run_configs" assert config_path.parent == expected_cache_dir # Check that filename is the hash @@ -214,7 +214,7 @@ def test_hash_collision_detection(self): # Create a fake collision by manually creating a file with same hash hash_value = inline_config.stable_hash_b64() - cache_dir = SPD_CACHE_DIR / "merge_run_configs" + cache_dir = SPD_CACHE_DIR / "clustering_run_configs" cache_dir.mkdir(parents=True, exist_ok=True) collision_path = cache_dir / f"{hash_value}.json" @@ -277,7 +277,7 @@ def test_all_pipeline_configs_valid(self): error_msg += f" - {path.name}: {exc}\n" pytest.fail(error_msg) - def test_all_merge_run_configs_valid(self): + def test_all_clustering_run_configs_valid(self): """Test that all merge run config files are valid.""" crc_dir = REPO_ROOT / "spd" / "clustering" / "configs" / "crc" From a1f1146d0480b4ee08cfc2a7070be6170b9394d1 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:57:09 +0100 Subject: [PATCH 35/77] remove component popping changes brought in from PR https://github.com/goodfire-ai/spd/pull/206 branch clustering/refactor-multi-batch commit [9cbb52f](https://github.com/goodfire-ai/spd/pull/206/commits/9cbb52fd09cac8d79481a16de0a9e4c517960a33) --- spd/clustering/compute_costs.py | 110 +----------------- spd/clustering/configs/crc/example.yaml | 1 - spd/clustering/configs/crc/resid_mlp1.json | 1 - spd/clustering/configs/crc/resid_mlp2.json | 1 - .../configs/crc/simplestories_dev.json | 3 +- .../configs/crc/test-resid_mlp1.json | 1 - .../configs/crc/test-simplestories.json | 1 - spd/clustering/merge.py | 51 +------- spd/clustering/merge_config.py | 5 - tests/clustering/scripts/cluster_resid_mlp.py | 1 - tests/clustering/scripts/cluster_ss.py | 3 +- 
tests/clustering/test_calc_distances.py | 1 - tests/clustering/test_merge_config.py | 2 - tests/clustering/test_merge_integration.py | 36 ------ .../test_run_clustering_happy_path.py | 1 - 15 files changed, 4 insertions(+), 214 deletions(-) diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index ba1ff274c..f1b3425d1 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -1,7 +1,7 @@ import math import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from spd.clustering.consts import ClusterCoactivationShaped, MergePair @@ -187,111 +187,3 @@ def recompute_coacts_merge_pair( coact_new, activation_mask_new, ) - - -def recompute_coacts_pop_group( - coact: ClusterCoactivationShaped, - merges: GroupMerge, - component_idx: int, - activation_mask: Bool[Tensor, "n_samples k_groups"], - activation_mask_orig: Bool[Tensor, "n_samples n_components"], -) -> tuple[ - GroupMerge, - Float[Tensor, "k_groups+1 k_groups+1"], - Bool[Tensor, "n_samples k_groups+1"], -]: - # sanity check dims - # ================================================== - - k_groups: int = coact.shape[0] - n_samples: int = activation_mask.shape[0] - k_groups_new: int = k_groups + 1 - assert coact.shape[1] == k_groups, "Coactivation matrix must be square" - assert activation_mask.shape[1] == k_groups, ( - "Activation mask must match coactivation matrix shape" - ) - assert n_samples == activation_mask_orig.shape[0], ( - "Activation mask original must match number of samples" - ) - - # get the activations we need - # ================================================== - # which group does the component belong to? 
- group_idx: int = int(merges.group_idxs[component_idx].item()) - group_size_old: int = int(merges.components_per_group[group_idx].item()) - group_size_new: int = group_size_old - 1 - - # activations of component we are popping out - acts_pop: Bool[Tensor, " samples"] = activation_mask_orig[:, component_idx] - - # activations of the "remainder" -- everything other than the component we are popping out, - # in the group we're popping it out of - acts_remainder: Bool[Tensor, " samples"] = ( - activation_mask_orig[ - :, [i for i in merges.components_in_group(group_idx) if i != component_idx] - ] - .max(dim=-1) - .values - ) - - # assemble the new activation mask - # ================================================== - # first concat the popped-out component onto the end - activation_mask_new: Bool[Tensor, " samples k_groups+1"] = torch.cat( - [activation_mask, acts_pop.unsqueeze(1)], - dim=1, - ) - # then replace the group we are popping out of with the remainder - activation_mask_new[:, group_idx] = acts_remainder - - # assemble the new coactivation matrix - # ================================================== - coact_new: Float[Tensor, "k_groups+1 k_groups+1"] = torch.full( - (k_groups_new, k_groups_new), - fill_value=float("nan"), - dtype=coact.dtype, - device=coact.device, - ) - # copy in the old coactivation matrix - coact_new[:k_groups, :k_groups] = coact.clone() - # compute new coactivations we need - coact_pop: Float[Tensor, " k_groups"] = acts_pop.float() @ activation_mask_new.float() - coact_remainder: Float[Tensor, " k_groups"] = ( - acts_remainder.float() @ activation_mask_new.float() - ) - - # replace the relevant rows and columns - coact_new[group_idx, :] = coact_remainder - coact_new[:, group_idx] = coact_remainder - coact_new[-1, :] = coact_pop - coact_new[:, -1] = coact_pop - - # assemble the new group merge - # ================================================== - group_idxs_new: Int[Tensor, " k_groups+1"] = merges.group_idxs.clone() - # the 
popped-out component is now its own group - new_group_idx: int = k_groups_new - 1 - group_idxs_new[component_idx] = new_group_idx - merge_new: GroupMerge = GroupMerge( - group_idxs=group_idxs_new, - k_groups=k_groups_new, - ) - - # sanity check - assert merge_new.components_per_group.shape == (k_groups_new,), ( - "New merge must have k_groups+1 components" - ) - assert merge_new.components_per_group[new_group_idx] == 1, ( - "New group must have exactly one component" - ) - assert merge_new.components_per_group[group_idx] == group_size_new, ( - "Old group must have one less component" - ) - - # return - # ================================================== - return ( - merge_new, - coact_new, - activation_mask_new, - ) diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml index efa36d693..3729941ce 100644 --- a/spd/clustering/configs/crc/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -12,7 +12,6 @@ merge_config: merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from - pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules diff --git a/spd/clustering/configs/crc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json index 506717282..1e13ce23e 100644 --- a/spd/clustering/configs/crc/resid_mlp1.json +++ b/spd/clustering/configs/crc/resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json index af645f3bd..edc4849e2 100644 --- a/spd/clustering/configs/crc/resid_mlp2.json +++ b/spd/clustering/configs/crc/resid_mlp2.json @@ -5,7 +5,6 @@ "iters": 100, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.01, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json index f585e848f..e1647b6e4 100644 --- a/spd/clustering/configs/crc/simplestories_dev.json +++ b/spd/clustering/configs/crc/simplestories_dev.json @@ -4,8 +4,7 @@ "alpha": 1.0, "iters": 100, "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.01}, - "pop_component_prob": 0, + "merge_pair_sampling_kwargs": {"threshold": 0.001}, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json index 01b510200..4b3a26ff8 100644 --- a/spd/clustering/configs/crc/test-resid_mlp1.json +++ b/spd/clustering/configs/crc/test-resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 16, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/crc/test-simplestories.json 
b/spd/clustering/configs/crc/test-simplestories.json index 147634edb..911f71529 100644 --- a/spd/clustering/configs/crc/test-simplestories.json +++ b/spd/clustering/configs/crc/test-simplestories.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.9, "module_name_filter": "model.layers.0" }, diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index fd982b83f..dba55c878 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,7 +8,7 @@ from typing import Protocol import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from tqdm import tqdm @@ -16,7 +16,6 @@ compute_mdl_cost, compute_merge_costs, recompute_coacts_merge_pair, - recompute_coacts_pop_group, ) from spd.clustering.consts import ( ActivationsTensor, @@ -76,24 +75,6 @@ def merge_iteration( # determine number of iterations based on config and number of components num_iters: int = merge_config.get_num_iters(c_components) - # pop logic setup - # -------------------------------------------------- - # for speed, we precompute whether to pop components and which components to pop - # if we are not popping, we don't need these variables and can also delete other things - do_pop: bool = merge_config.pop_component_prob > 0.0 - if do_pop: - # at each iteration, we will pop a component with probability `pop_component_prob` - iter_pop: Bool[Tensor, " iters"] = ( - torch.rand(num_iters, device=coact.device) < merge_config.pop_component_prob - ) - # we pick a subcomponent at random, and if we decide to pop, we pop that one out of its group - # if the component is a singleton, nothing happens. 
this naturally biases towards popping - # less at the start and more at the end, since the effective probability of popping a component - # is actually something like `pop_component_prob * (c_components - k_groups) / c_components` - pop_component_idx: Int[Tensor, " iters"] = torch.randint( - 0, c_components, (num_iters,), device=coact.device - ) - # initialize vars # -------------------------------------------------- # start with an identity merge @@ -110,12 +91,6 @@ def merge_iteration( labels=component_labels, ) - # free up memory - if not do_pop: - del coact - del activation_mask_orig - activation_mask_orig = None - # merge iteration # ================================================== pbar: tqdm[int] = tqdm( @@ -124,30 +99,6 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: - # pop components - # -------------------------------------------------- - if do_pop and iter_pop[iter_idx]: # pyright: ignore[reportPossiblyUnboundVariable] - # we split up the group which our chosen component belongs to - pop_component_idx_i: int = int(pop_component_idx[iter_idx].item()) # pyright: ignore[reportPossiblyUnboundVariable] - n_components_in_pop_grp: int = int( - current_merge.components_per_group[ # pyright: ignore[reportArgumentType] - current_merge.group_idxs[pop_component_idx_i].item() - ] - ) - - # but, if the component is the only one in its group, there is nothing to do - if n_components_in_pop_grp > 1: - current_merge, current_coact, current_act_mask = recompute_coacts_pop_group( - coact=current_coact, - merges=current_merge, - component_idx=pop_component_idx_i, - activation_mask=current_act_mask, - # this complains if `activation_mask_orig is None`, but this is only the case - # if `do_pop` is False, which it won't be here. 
we do this to save memory - activation_mask_orig=activation_mask_orig, # pyright: ignore[reportArgumentType] - ) - k_groups = current_coact.shape[0] - # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 3bf8b6d5b..f471879b2 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -23,7 +23,6 @@ "iters", "merge_pair_sampling_method", "merge_pair_sampling_kwargs", - "pop_component_prob", "filter_dead_threshold", ] @@ -65,10 +64,6 @@ class MergeConfig(BaseConfig): default_factory=lambda: {"threshold": 0.05}, description="Keyword arguments for the merge pair sampling method.", ) - pop_component_prob: Probability = Field( - default=0, - description="Probability of popping a component in each iteration. If 0, no components are popped.", - ) filter_dead_threshold: float = Field( default=0.001, description="Threshold for filtering out dead components. 
If a component's activation is below this threshold, it is considered dead and not included in the merge.", diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 1d2a69c93..bbfb5259e 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -120,7 +120,6 @@ iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9), merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, - pop_component_prob=0, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index 3f5da34a0..45c142fa0 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -47,7 +47,7 @@ # Use load_dataset with RunConfig to get real data CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(), + merge_config=MergeConfig(batch_size=2), model_path=MODEL_PATH, batch_size=2, dataset_seed=42, @@ -103,7 +103,6 @@ iters=2, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, module_name_filter=FILTER_MODULES, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/test_calc_distances.py b/tests/clustering/test_calc_distances.py index d8971df05..b06350f4b 100644 --- a/tests/clustering/test_calc_distances.py +++ b/tests/clustering/test_calc_distances.py @@ -11,7 +11,6 @@ def test_merge_history_normalization_happy_path(): iters=3, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) histories = [] diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 9f191075b..63f4e88f7 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -74,7 +74,6 @@ def test_config_with_all_parameters(self): iters=200, merge_pair_sampling_method="mcmc", 
merge_pair_sampling_kwargs={"temperature": 0.5}, - pop_component_prob=0.1, filter_dead_threshold=0.001, module_name_filter="model.layers", ) @@ -84,7 +83,6 @@ def test_config_with_all_parameters(self): assert config.iters == 200 assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} - assert config.pop_component_prob == 0.1 assert config.filter_dead_threshold == 0.001 assert config.module_name_filter == "model.layers" diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 14811b7c5..8492300de 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -25,7 +25,6 @@ def test_merge_with_range_sampler(self): iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -59,7 +58,6 @@ def test_merge_with_mcmc_sampler(self): iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -77,37 +75,6 @@ def test_merge_with_mcmc_sampler(self): assert history.merges.k_groups[-1].item() < n_components assert history.merges.k_groups[-1].item() >= 2 - def test_merge_with_popping(self): - """Test merge iteration with component popping.""" - # Create test data - n_samples = 100 - n_components = 15 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) - - # Configure with popping enabled - config = MergeConfig( - activation_threshold=0.1, - alpha=1.0, - iters=10, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0.3, # 30% chance of popping - filter_dead_threshold=0.001, - ) - - # Run merge iteration - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels - 
) - - # Check results - assert history is not None - # First entry is after first merge, so should be n_components - 1 - assert history.merges.k_groups[0].item() == n_components - 1 - # Final group count depends on pops, but should be less than initial - assert history.merges.k_groups[-1].item() < n_components - def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure @@ -128,7 +95,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum - pop_component_prob=0, ) history_range = merge_iteration( @@ -144,7 +110,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp - pop_component_prob=0, ) history_mcmc = merge_iteration( @@ -173,7 +138,6 @@ def test_merge_with_small_components(self): iters=1, # Just one merge merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0}, - pop_component_prob=0, ) history = merge_iteration( diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 57bb5e1ff..5e2cbbd1c 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -24,7 +24,6 @@ def test_run_clustering_happy_path(): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0, ), wandb_project=None, wandb_entity="goodfire", From 1e3fbb292130164bbf278314dae101386ea667fc Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 13:59:30 +0100 Subject: [PATCH 36/77] dont pass batch size, change not brought in here --- tests/clustering/scripts/cluster_ss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/clustering/scripts/cluster_ss.py 
b/tests/clustering/scripts/cluster_ss.py index 45c142fa0..0b7f8de97 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -47,7 +47,7 @@ # Use load_dataset with RunConfig to get real data CONFIG: ClusteringRunConfig = ClusteringRunConfig( - merge_config=MergeConfig(batch_size=2), + merge_config=MergeConfig(), model_path=MODEL_PATH, batch_size=2, dataset_seed=42, From 866e28ce61af92462918adf62ad32dc0f8ce4c60 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 14:27:16 +0100 Subject: [PATCH 37/77] wip --- spd/clustering/ci_dt/config.py | 3 +- spd/clustering/ci_dt/run.py | 129 ++++++++++++++++--------- spd/clustering/scripts/run_pipeline.py | 2 +- 3 files changed, 86 insertions(+), 48 deletions(-) diff --git a/spd/clustering/ci_dt/config.py b/spd/clustering/ci_dt/config.py index a83c6adc4..4b4e9205b 100644 --- a/spd/clustering/ci_dt/config.py +++ b/spd/clustering/ci_dt/config.py @@ -7,7 +7,8 @@ class CIDTConfig: """Configuration for causal importance decision tree training.""" - n_samples: int = 250 + batch_size: int = 10 # Number of samples per batch for GPU inference + n_batches: int = 25 # Number of batches to process (total samples = batch_size * n_batches) activation_threshold: float = 0.01 # Threshold for boolean conversion filter_dead_threshold: float = 0.001 # Threshold for filtering dead components max_depth: int = 8 # Maximum depth for decision trees diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index cc4333d2f..7a0f7a53d 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -8,11 +8,7 @@ from jaxtyping import Bool, Float from torch import Tensor -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, -) +from spd.clustering.activations import component_activations from spd.clustering.ci_dt.config import CIDTConfig from spd.clustering.ci_dt.core import ( LayerModel, @@ -43,7 +39,8 @@ # 
----------------------- configuration ----------------------- config = CIDTConfig( - n_samples=64, # batch size 64 -> 16GB vram + batch_size=10, + n_batches=10, activation_threshold=0.01, filter_dead_threshold=0.001, max_depth=8, @@ -83,76 +80,116 @@ ) dataloader, _ = create_data_loader( dataset_config=dataset_config, - batch_size=config.n_samples, + batch_size=config.batch_size, buffer_size=cfg.task_config.buffer_size, global_seed=cfg.seed, ddp_rank=0, ddp_world_size=1, ) -batch_data = next(iter(dataloader)) -batch: Tensor = batch_data["input_ids"] -print(f"Created LM dataset with {cfg.task_config.dataset_name}, batch shape: {batch.shape}") +print(f"Created LM dataset with {cfg.task_config.dataset_name}") # %% # ----------------------- get activations ----------------------- -# Get component activations (on device) -print("Computing component activations...") -component_acts: dict[str, Tensor] = component_activations( - model=model, - device=device, - batch=batch, -) +# Loop over batches, accumulate on CPU +print(f"Computing activations for {config.n_batches} batches (batch_size={config.batch_size})...") +all_component_acts: list[dict[str, Tensor]] = [] -# Process activations (filter dead components, concatenate) -print("Processing activations...") -processed_acts: ProcessedActivations = process_activations( - component_acts, - filter_dead_threshold=config.filter_dead_threshold, - seq_mode="seq_mean", # LM task needs seq_mean -) +for batch_idx in range(config.n_batches): + batch_data = next(iter(dataloader)) + batch: Tensor = batch_data["input_ids"] + + # Get activations on GPU + component_acts_gpu: dict[str, Tensor] = component_activations( + model=model, device=device, batch=batch + ) -print(f"Total components (before filtering): {processed_acts.n_components_original}") -print(f"Alive components: {processed_acts.n_components_alive}") -print(f"Dead components: {processed_acts.n_components_dead}") -print(f"Module keys: {processed_acts.module_keys}") + # Move to 
CPU immediately and store + component_acts_cpu: dict[str, Tensor] = { + key: tensor.cpu() for key, tensor in component_acts_gpu.items() + } + all_component_acts.append(component_acts_cpu) + + print(f" Batch {batch_idx + 1}/{config.n_batches} processed") + +# Concatenate all batches on CPU +print("Concatenating batches...") +module_keys: list[str] = list(all_component_acts[0].keys()) +component_acts_concat: dict[str, Tensor] = { + key: torch.cat([batch[key] for batch in all_component_acts], dim=0) + for key in module_keys +} + +# Apply seq_mean if needed (LM task) +print("Applying seq_mean over sequence dimension...") +component_acts_concat = { + key: act.mean(dim=1) if act.ndim == 3 else act + for key, act in component_acts_concat.items() +} + +# Filter dead components (on CPU) +print("\nFiltering dead components...") +component_acts_filtered: dict[str, Float[np.ndarray, "n_samples n_components"]] = {} +n_total_original: int = 0 +n_total_alive: int = 0 +n_total_dead: int = 0 + +for module_key in module_keys: + acts_tensor: Tensor = component_acts_concat[module_key] + n_components_original: int = acts_tensor.shape[1] + n_total_original += n_components_original + + # Filter components where max activation < threshold + max_acts: Tensor = acts_tensor.max(dim=0).values + alive_mask: Bool[Tensor, "n_components"] = max_acts >= config.filter_dead_threshold + + acts_alive: Tensor = acts_tensor[:, alive_mask] + acts_np: Float[np.ndarray, "n_samples n_components"] = acts_alive.numpy() + + n_dead: int = (~alive_mask).sum().item() + n_alive: int = alive_mask.sum().item() + n_total_alive += n_alive + n_total_dead += n_dead + + component_acts_filtered[module_key] = acts_np + print(f" {module_key:30s} {n_alive:5d} alive, {n_dead:5d} dead") + +print(f"\nTotal components (before filtering): {n_total_original}") +print(f"Alive components: {n_total_alive}") +print(f"Dead components: {n_total_dead}") +print(f"Module keys: {module_keys}") # %% # ----------------------- convert to 
layers ----------------------- -# Move to CPU and convert to numpy for sklearn -# Group by module to create "layers" for decision trees +# Convert to boolean and filter constant components print("\nConverting to boolean layers...") layers_true: list[Bool[np.ndarray, "n_samples n_components"]] = [] -for module_key in processed_acts.module_keys: - # Get the activations for this module from activations_raw, move to CPU - module_acts_cpu: Float[np.ndarray, "n_samples n_components"] = ( - processed_acts.activations_raw[module_key].cpu().numpy() - ) + +for module_key in module_keys: + module_acts: Float[np.ndarray, "n_samples n_components"] = component_acts_filtered[module_key] + + # Convert to boolean module_acts_bool: Bool[np.ndarray, "n_samples n_components"] = ( - module_acts_cpu >= config.activation_threshold + module_acts >= config.activation_threshold ).astype(bool) # Filter out components that are always dead or always alive # (they provide no information for decision trees) - n_before: int = module_acts_bool.shape[1] component_variance: Float[np.ndarray, "n_components"] = module_acts_bool.var(axis=0) varying_mask: Bool[np.ndarray, "n_components"] = component_variance > 0 # Count always-dead and always-alive components for diagnostics always_dead_mask: Bool[np.ndarray, "n_components"] = ~module_acts_bool.any(axis=0) always_alive_mask: Bool[np.ndarray, "n_components"] = module_acts_bool.all(axis=0) - n_always_dead: int = always_dead_mask.sum() - n_always_alive: int = always_alive_mask.sum() + n_always_dead: int = int(always_dead_mask.sum()) + n_always_alive: int = int(always_alive_mask.sum()) - module_acts_filtered: Bool[np.ndarray, "n_samples n_varying"] = module_acts_bool[:, varying_mask] - n_after: int = module_acts_filtered.shape[1] + module_acts_varying: Bool[np.ndarray, "n_samples n_varying"] = module_acts_bool[:, varying_mask] - layers_true.append(module_acts_filtered) - print( - f"Layer {len(layers_true) - 1} ({module_key}): {n_after} varying components " 
- f"({n_always_dead} always dead, {n_always_alive} always alive removed)" - ) + layers_true.append(module_acts_varying) + n_varying: int = module_acts_varying.shape[1] + print(f" {module_key:30s} {n_varying:5d} varying, {n_always_dead:5d} dead, {n_always_alive:5d} const") print(f"\nCreated {len(layers_true)} layers for decision tree training") diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 7b6af0e82..975d57002 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -72,7 +72,7 @@ class ClusteringPipelineConfig(BaseConfig): run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") - distances_method: DistancesMethod = Field(description="Method to use for calculating distances") + distances_methods: list[DistancesMethod] = Field(description="Methods to use for calculating distances") base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") slurm_partition: str | None = Field(description="SLURM partition to use") From 12e54e001ac4efa4d6073d233fb9bed2c4f1e65e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Tue, 21 Oct 2025 15:14:54 +0100 Subject: [PATCH 38/77] add some js from feature/clustering-dashboard branch --- spd/clustering/ci_dt/js/cluster-detail.js | 740 + spd/clustering/ci_dt/js/cluster-selection.js | 841 ++ .../ci_dt/js/model-visualization.js | 222 + spd/clustering/ci_dt/js/pkg/jszip.js | 11577 ++++++++++++++++ spd/clustering/ci_dt/js/pkg/jszip.min.js | 13 + spd/clustering/ci_dt/js/token-display.js | 90 + spd/clustering/ci_dt/js/util/ColorUtil.js | 107 + spd/clustering/ci_dt/js/util/DataFrame.js | 269 + spd/clustering/ci_dt/js/util/array.js | 447 + spd/clustering/ci_dt/js/util/config.js | 614 + 
.../ci_dt/js/util/github-integration.js | 241 + spd/clustering/ci_dt/js/util/histogram.js | 57 + spd/clustering/ci_dt/js/util/notif.js | 269 + spd/clustering/ci_dt/js/util/sparklines.js | 364 + spd/clustering/ci_dt/js/util/table.js | 1116 ++ spd/clustering/ci_dt/js/util/yaml.js | 58 + 16 files changed, 17025 insertions(+) create mode 100644 spd/clustering/ci_dt/js/cluster-detail.js create mode 100644 spd/clustering/ci_dt/js/cluster-selection.js create mode 100644 spd/clustering/ci_dt/js/model-visualization.js create mode 100644 spd/clustering/ci_dt/js/pkg/jszip.js create mode 100644 spd/clustering/ci_dt/js/pkg/jszip.min.js create mode 100644 spd/clustering/ci_dt/js/token-display.js create mode 100644 spd/clustering/ci_dt/js/util/ColorUtil.js create mode 100644 spd/clustering/ci_dt/js/util/DataFrame.js create mode 100644 spd/clustering/ci_dt/js/util/array.js create mode 100644 spd/clustering/ci_dt/js/util/config.js create mode 100644 spd/clustering/ci_dt/js/util/github-integration.js create mode 100644 spd/clustering/ci_dt/js/util/histogram.js create mode 100644 spd/clustering/ci_dt/js/util/notif.js create mode 100644 spd/clustering/ci_dt/js/util/sparklines.js create mode 100644 spd/clustering/ci_dt/js/util/table.js create mode 100644 spd/clustering/ci_dt/js/util/yaml.js diff --git a/spd/clustering/ci_dt/js/cluster-detail.js b/spd/clustering/ci_dt/js/cluster-detail.js new file mode 100644 index 000000000..83abfb96e --- /dev/null +++ b/spd/clustering/ci_dt/js/cluster-detail.js @@ -0,0 +1,740 @@ +let clusterData = null; +let allClusters = null; +let textSamples = {}; +let activationsArray = null; +let activationsMap = {}; +let currentClusterHash = null; +let modelInfo = {}; +let explanations = {}; + +// Component-level data +let componentActivations = {}; // Map component labels to their activation data +let enabledComponents = new Set(); // Track which components are enabled +let combinationStrategy = 'max'; // How to combine component activations: 'max', 'sum', 
'mean' + +async function init() { + // Get cluster hash from URL + const urlParams = new URLSearchParams(window.location.search); + currentClusterHash = urlParams.get('id'); + + if (!currentClusterHash) { + const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + loading.textContent = 'No cluster ID specified'; + return; + } + + await loadData(); +} + +async function loadData() { + const progressBar = NOTIF.pbar('Loading cluster data...'); + + try { + progressBar.progress(0.1); + + // Load data in parallel + let clusters, samples, activationsMapResponse, modelInfoResponse; + + const clustersPath = CONFIG.getDataPath('clusters'); + const textSamplesPath = CONFIG.getDataPath('textSamples'); + const activationsMapPath = CONFIG.getDataPath('activationsMap'); + const modelInfoPath = CONFIG.getDataPath('modelInfo'); + const explanationsPath = CONFIG.getDataPath('explanations'); + + try { + [clusters, samples, activationsMapResponse, modelInfoResponse] = await Promise.all([ + loadJSONL(clustersPath, 'cluster_hash').catch(e => { + throw new Error(`Failed to load ${clustersPath}: ${e.message}`); + }), + loadJSONL(textSamplesPath, 'text_hash').catch(e => { + throw new Error(`Failed to load ${textSamplesPath}: ${e.message}`); + }), + fetch(activationsMapPath).catch(e => { + throw new Error(`Failed to load ${activationsMapPath}: ${e.message}`); + }), + fetch(modelInfoPath).catch(e => { + throw new Error(`Failed to load ${modelInfoPath}: ${e.message}`); + }) + ]); + + // Load explanations (non-critical, don't fail if missing) + explanations = await loadJSONL(explanationsPath, 'cluster_id').catch(() => ({})); + } catch (error) { + progressBar.complete(); + NOTIF.error(error.message, error, null); + const loading = document.getElementById('loading'); + if (loading) { + loading.textContent = error.message; + } else { + 
console.error('loading element not found, cannot display error message'); + } + throw error; + } + + progressBar.progress(0.4); + + if (!activationsMapResponse.ok) { + const msg = `Failed to load ${activationsMapPath} (HTTP ${activationsMapResponse.status})`; + NOTIF.error(msg, null, null); + throw new Error(msg); + } + if (!modelInfoResponse.ok) { + const msg = `Failed to load ${modelInfoPath} (HTTP ${modelInfoResponse.status})`; + NOTIF.error(msg, null, null); + throw new Error(msg); + } + + allClusters = clusters; + textSamples = samples; + + try { + activationsMap = await activationsMapResponse.json(); + } catch (error) { + const msg = `Failed to parse ${activationsMapPath} (invalid JSON)`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + try { + modelInfo = await modelInfoResponse.json(); + } catch (error) { + const msg = `Failed to parse ${modelInfoPath} (invalid JSON)`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + progressBar.progress(0.6); + + if (!allClusters[currentClusterHash]) { + const msg = 'Cluster not found'; + NOTIF.error(msg, null, null); + const loading = document.getElementById('loading'); + if (loading) { + loading.textContent = msg; + } else { + console.error('loading element not found, cannot display error message'); + } + progressBar.complete(); + return; + } + + clusterData = allClusters[currentClusterHash]; + + // Load activations (float16 compressed npz) + const activationsPath = CONFIG.getDataPath('activations'); + try { + activationsArray = await NDArray.load(activationsPath); + } catch (error) { + const msg = `Failed to load ${activationsPath}`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + progressBar.progress(0.9); + + displayCluster(); + progressBar.complete(); + const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + 
loading.style.display = 'none'; + } catch (error) { + progressBar.complete(); + console.error('Load error:', error); + console.error('Stack:', error.stack); + } +} + +function displayCluster() { + // Update title + const clusterTitle = document.getElementById('clusterTitle'); + if (!clusterTitle) { + const msg = 'Fatal error: clusterTitle element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + clusterTitle.textContent = `Cluster ${currentClusterHash}`; + + // Display component count + const componentCount = document.getElementById('componentCount'); + if (!componentCount) { + const msg = 'Fatal error: componentCount element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + componentCount.textContent = clusterData.components.length; + + // Display explanation and setup copy handler + displayExplanation(); + setupCopyHandler(); + + // Initialize component data + initializeComponentData(); + + // Display model visualization + displayModelVisualization(); + + // Setup components table + setupComponentsTable(); + + // Setup hover highlighting between model view and components table + setupModelViewHighlighting(); + + // Display histogram plots + displayHistograms(); + + // Display token activation stats if available + if (clusterData.stats && clusterData.stats.token_activations) { + displayTokenActivations(); + } + + // Display samples + displaySamples(); +} + +function displayExplanation() { + const explanationSpan = document.getElementById('clusterExplanation'); + if (!explanationSpan) return; + + const explanationData = explanations[currentClusterHash]; + if (explanationData && explanationData.explanation) { + explanationSpan.textContent = explanationData.explanation; + explanationSpan.style.fontStyle = 'normal'; + explanationSpan.style.color = '#000'; + } else { + explanationSpan.textContent = 'No explanation'; + explanationSpan.style.fontStyle = 'italic'; + explanationSpan.style.color 
= '#666'; + } +} + +function setupCopyHandler() { + const copyBtn = document.getElementById('copyTemplateBtn'); + if (!copyBtn) return; + + copyBtn.addEventListener('click', async () => { + const template = JSON.stringify({ + cluster_id: currentClusterHash, + explanation: "" + }) + '\n'; + + try { + await navigator.clipboard.writeText(template); + NOTIF.success('Template copied to clipboard!'); + } catch (err) { + // Fallback for older browsers + const textArea = document.createElement('textarea'); + textArea.value = template; + textArea.style.position = 'fixed'; + textArea.style.left = '-999999px'; + document.body.appendChild(textArea); + textArea.select(); + try { + document.execCommand('copy'); + NOTIF.success('Template copied to clipboard!'); + } catch (e) { + NOTIF.error('Failed to copy template', e, null); + } + document.body.removeChild(textArea); + } + }); +} + +function initializeComponentData() { + // Load component activations if available + if (clusterData.component_activations) { + componentActivations = clusterData.component_activations; + } + + // Enable all components by default + enabledComponents.clear(); + clusterData.components.forEach(comp => { + enabledComponents.add(comp.label); + }); +} + +function displayModelVisualization() { + const modelViewDiv = document.getElementById('modelView'); + if (!modelViewDiv) { + const msg = 'Fatal error: modelView element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + renderModelView(modelViewDiv, currentClusterHash, allClusters, modelInfo, CONFIG.visualization.colormap, CONFIG.visualization.modelViewCellSize); +} + +function displayHistograms() { + const stats = clusterData.stats; + if (!stats) return; + + const histogramPlots = document.getElementById('histogramPlots'); + if (!histogramPlots) { + const msg = 'Fatal error: histogramPlots element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + histogramPlots.innerHTML = ''; 
+ + // Color mapping for different histogram types + const statColors = { + 'all_activations': '#4169E1', + 'max_activation-max-16': '#DC143C', + 'max_activation-max-32': '#DC143C', + 'mean_activation-max-16': '#228B22', + 'median_activation-max-16': '#FF8C00', + 'min_activation-max-16': '#9370DB', + 'max_activation_position': '#FF6347' + }; + + // Discover all histogram stats + const histogramStats = []; + for (const [key, value] of Object.entries(stats)) { + if (value && typeof value === 'object' && 'bin_counts' in value && 'bin_edges' in value) { + histogramStats.push(key); + } + } + + // Create a plot for each histogram stat + histogramStats.forEach(statKey => { + const histData = stats[statKey]; + const color = statColors[statKey] || '#808080'; + const label = statKey.replace(/-/g, ' ').replace(/_/g, ' ') + .split(' ') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' '); + + // Create container for this plot + const plotContainer = document.createElement('div'); + plotContainer.style.display = 'flex'; + plotContainer.style.flexDirection = 'column'; + plotContainer.style.alignItems = 'center'; + plotContainer.style.minWidth = '250px'; + + // Add label + const plotLabel = document.createElement('div'); + plotLabel.textContent = label; + plotLabel.style.fontSize = '12px'; + plotLabel.style.fontWeight = 'bold'; + plotLabel.style.marginBottom = '5px'; + plotLabel.style.textAlign = 'center'; + plotContainer.appendChild(plotLabel); + + // Create sparkline + const sparklineContainer = document.createElement('div'); + sparklineContainer.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = histData.bin_edges[0]; + const max = histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? 
[0, 1] : null; + + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth || 200, + height: CONFIG.visualization.sparklineHeight || 60, + color: color, + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin || 35} + }); + + sparklineContainer.innerHTML = svg; + + // Add tooltip with statistics + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + const totalCount = histData.bin_counts.reduce((a, b) => a + b, 0); + sparklineContainer.title = `${label} (n=${totalCount})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + plotContainer.appendChild(sparklineContainer); + histogramPlots.appendChild(plotContainer); + }); +} + +function displayTokenActivations() { + const tokenStats = clusterData.stats.token_activations; + + // Show the section + const tokenActivations = document.getElementById('tokenActivations'); + if (!tokenActivations) { + const msg = 'Fatal error: tokenActivations element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + tokenActivations.style.display = 'block'; + + // Setup top tokens table + if (tokenStats.top_tokens && tokenStats.top_tokens.length > 0) { + const tableData = tokenStats.top_tokens.map((item, idx) => ({ + rank: idx + 1, + token: item.token, + count: item.count, + percentage: ((item.count / tokenStats.total_activations) * 100) + })); + + const maxPercentage = tableData.length > 0 ? 
tableData[0].percentage : 0; + + const tableConfig = { + data: tableData, + columns: [ + { + key: 'rank', + label: '#', + type: 'number', + width: '40px', + align: 'right' + }, + { + key: 'token', + label: 'Token', + type: 'string', + width: '120px', + renderer: (value) => { + // Show token in a monospace box with visual formatting + const tokenDisplay = value.replace(/ /g, '·').replace(/\n/g, '↵'); + return `${tokenDisplay}`; + } + }, + { + key: 'percentage', + label: '%', + type: 'number', + width: '70px', + align: 'right', + renderer: (value) => { + const percentageValue = value; + const percentage = percentageValue.toFixed(1); + + // Color based on percentage (normalized by max percentage) + const normalizedPct = maxPercentage > 0 ? percentageValue / maxPercentage : 0; + const intensity = Math.floor((1 - normalizedPct) * 255); + const bgColor = `rgb(255, ${intensity}, ${intensity})`; + + const span = document.createElement('span'); + span.textContent = `${percentage}%`; + span.style.backgroundColor = bgColor; + span.style.padding = '2px 4px'; + span.style.borderRadius = '2px'; + + return span; + }, + infoFunction: () => { + return `Unique: ${tokenStats.total_unique_tokens.toLocaleString()} | Total: ${tokenStats.total_activations.toLocaleString()} | Entropy: ${tokenStats.entropy.toFixed(2)} | Conc: ${(tokenStats.concentration_ratio * 100).toFixed(1)}%`; + } + } + ], + pageSize: 10, + showFilters: false, + showInfo: true + }; + + new DataTable('#topTokensTable', tableConfig); + } +} + +function setupComponentsTable() { + const tableData = clusterData.components.map(comp => ({ + label: comp.label, + module: comp.module, + index: comp.index, + enabled: enabledComponents.has(comp.label) + })); + + const tableConfig = { + data: tableData, + columns: [ + { + key: 'enabled', + label: '✓', + type: 'boolean', + width: '40px', + align: 'center', + renderer: (value, row) => { + const checkbox = document.createElement('input'); + checkbox.type = 'checkbox'; + 
checkbox.checked = value; + checkbox.style.cursor = 'pointer'; + checkbox.addEventListener('change', (e) => { + onComponentToggle(row.label, e.target.checked); + }); + return checkbox; + }, + filterable: false + }, + { + key: 'module', + label: 'Module', + type: 'string', + width: '250px' + }, + { + key: 'index', + label: 'Index', + type: 'number', + width: '80px', + align: 'right' + } + ], + pageSize: CONFIG.clusterPage.pageSize, + showFilters: false + }; + + new DataTable('#componentsTable', tableConfig); +} + +function onComponentToggle(componentLabel, isEnabled) { + if (isEnabled) { + enabledComponents.add(componentLabel); + } else { + enabledComponents.delete(componentLabel); + } + + // Recompute and redisplay activations + recomputeDisplayedActivations(); +} + +async function recomputeDisplayedActivations() { + // If no components are enabled or component activations not available, use cluster-level + if (enabledComponents.size === 0 || !componentActivations || Object.keys(componentActivations).length === 0) { + // Just redisplay with cluster-level activations (default) + displaySamples(); + return; + } + + // If all components are enabled, use cluster-level activations (faster) + if (enabledComponents.size === clusterData.components.length) { + displaySamples(); + return; + } + + // Recompute activations based on enabled components + displaySamples(); +} + +function combineComponentActivations(componentActsList, strategy) { + // componentActsList: array of activation arrays [n_ctx] + // Returns: combined activation array [n_ctx] + + if (componentActsList.length === 0) { + return null; + } + + if (componentActsList.length === 1) { + return componentActsList[0]; + } + + const n_ctx = componentActsList[0].length; + const combined = new Array(n_ctx).fill(0); + + if (strategy === 'max') { + for (let i = 0; i < n_ctx; i++) { + let maxVal = componentActsList[0][i]; + for (let j = 1; j < componentActsList.length; j++) { + if (componentActsList[j][i] > maxVal) { + 
maxVal = componentActsList[j][i]; + } + } + combined[i] = maxVal; + } + } else if (strategy === 'sum') { + for (let i = 0; i < n_ctx; i++) { + let sum = 0; + for (let j = 0; j < componentActsList.length; j++) { + sum += componentActsList[j][i]; + } + combined[i] = sum; + } + } else if (strategy === 'mean') { + for (let i = 0; i < n_ctx; i++) { + let sum = 0; + for (let j = 0; j < componentActsList.length; j++) { + sum += componentActsList[j][i]; + } + combined[i] = sum / componentActsList.length; + } + } + + return combined; +} + +function setupModelViewHighlighting() { + // Get all model view cells + const modelViewCells = document.querySelectorAll('.modelview-module-cell'); + + // Get components table + const componentsTable = document.querySelector('#componentsTable'); + if (!componentsTable) return; + + modelViewCells.forEach(cell => { + cell.addEventListener('mouseenter', (e) => { + const moduleName = e.target.dataset.module; + if (!moduleName) return; + + // Find and highlight all rows in the components table that match this module + const tableRows = componentsTable.querySelectorAll('.tablejs-data-row'); + tableRows.forEach(row => { + const cells = row.querySelectorAll('td'); + if (cells.length > 1) { + const moduleCell = cells[1]; // Second column is module name (first is checkbox) + if (moduleCell && moduleCell.textContent === moduleName) { + row.style.backgroundColor = '#fff3cd'; // Light yellow highlight + } + } + }); + }); + + cell.addEventListener('mouseleave', () => { + // Remove highlighting from all rows + const tableRows = componentsTable.querySelectorAll('.tablejs-data-row'); + tableRows.forEach(row => { + row.style.backgroundColor = ''; + }); + }); + }); +} + +function displaySamples() { + const tbody = document.getElementById('samplesTableBody'); + if (!tbody) { + const msg = 'Fatal error: samplesTableBody element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + tbody.innerHTML = ''; + + // Get the main 
criterion samples (max_activation) + const criterionKey = Object.keys(clusterData.criterion_samples)[0]; + if (!criterionKey) { + tbody.innerHTML = 'No samples available'; + return; + } + + const sampleHashes = clusterData.criterion_samples[criterionKey]; + const samplesToShow = Math.min(CONFIG.clusterPage.maxSamplesPerCluster, sampleHashes.length); + + // Check if we need to use component-level activations + const useComponentActivations = componentActivations && + Object.keys(componentActivations).length > 0 && + enabledComponents.size < clusterData.components.length; + + for (let i = 0; i < samplesToShow; i++) { + const textHash = sampleHashes[i]; + const textSample = textSamples[textHash]; + + if (!textSample) { + console.warn(`Text sample not found for hash: ${textHash}`); + continue; + } + + let activationsData; + + if (useComponentActivations) { + // Compute combined activations from enabled components + const componentActsList = []; + + for (const comp of clusterData.components) { + if (enabledComponents.has(comp.label) && componentActivations[comp.label]) { + const compData = componentActivations[comp.label]; + // Find the activation for this text sample + const hashIdx = compData.activation_sample_hashes.indexOf(`${currentClusterHash}:${comp.label}:${textHash}`); + if (hashIdx !== -1) { + const activationIdx = compData.activation_indices[hashIdx]; + if (activationIdx !== undefined && activationsArray) { + const compActivations = activationsArray.get(activationIdx); + componentActsList.push(Array.from(compActivations.data)); + } + } + } + } + + if (componentActsList.length > 0) { + activationsData = combineComponentActivations(componentActsList, combinationStrategy); + } + } + + // Fall back to cluster-level activations if component activations not available + if (!activationsData) { + const fullHash = `${currentClusterHash}:${textHash}`; + const activationIdx = activationsMap[fullHash]; + + if (activationIdx !== undefined && activationsArray) { + const 
activations = activationsArray.get(activationIdx); + activationsData = Array.from(activations.data); + } + } + + let tokenViz; + if (activationsData) { + // Find max position + const maxPosition = activationsData.indexOf(Math.max(...activationsData)); + + // Use the proper token visualization with coloring and tooltips + tokenViz = createTokenVisualizationWithTooltip( + textSample.tokens, + activationsData, + maxPosition + ); + } else { + // Fallback to simple visualization if no activations + console.warn(`No activations found for sample ${i}`); + tokenViz = createSimpleTokenViz(textSample.tokens); + } + + const tr = document.createElement('tr'); + tr.innerHTML = ` + ${i + 1} + + `; + + // Add token visualization to last cell + tr.lastElementChild.appendChild(tokenViz); + + tbody.appendChild(tr); + } + + if (sampleHashes.length > CONFIG.clusterPage.maxSamplesPerCluster) { + const tr = document.createElement('tr'); + tr.innerHTML = ` + ... and ${sampleHashes.length - CONFIG.clusterPage.maxSamplesPerCluster} more samples + `; + tbody.appendChild(tr); + } +} + +function createSimpleTokenViz(tokens) { + const container = document.createElement('div'); + container.className = 'token-container'; + container.textContent = tokens.join(' '); + return container; +} + +// Initialize config and load data on page load +(async () => { + await initConfig(); + init(); +})(); \ No newline at end of file diff --git a/spd/clustering/ci_dt/js/cluster-selection.js b/spd/clustering/ci_dt/js/cluster-selection.js new file mode 100644 index 000000000..6a5ce1142 --- /dev/null +++ b/spd/clustering/ci_dt/js/cluster-selection.js @@ -0,0 +1,841 @@ +let clusterData = {}; +let modelInfo = {}; +let dataTable = null; +let explanations = {}; + +// Alpine.js data component for model info +const modelInfoData = { + data: {}, + hasData: false, + + async loadData() { + try { + const response = await fetch(CONFIG.getDataPath('modelInfo')); + this.data = await response.json(); + this.hasData = 
Object.keys(this.data).length > 0; + + // Also populate global modelInfo for DataTable renderers + modelInfo = this.data; + + console.log('Model info loaded:', this.hasData, Object.keys(this.data)); + } catch (error) { + console.error('Failed to load model info:', error); + this.hasData = false; + } + }, + + formatParameters(totalParams) { + if (!totalParams) return '-'; + if (totalParams >= 1000000) return (totalParams / 1000000).toFixed(1) + 'M'; + if (totalParams >= 1000) return (totalParams / 1000).toFixed(1) + 'K'; + return totalParams.toString(); + }, + + formatWandBLink(path) { + if (!path) return '-'; + + // Remove "wandb:" prefix if present + const cleanPath = path.replace(/^wandb:/, ''); + + // Convert to WandB URL + const url = `https://wandb.ai/${cleanPath}`; + + // Show shortened path in link text + const displayText = cleanPath.length > 60 + ? cleanPath.substring(0, 57) + '...' + : cleanPath; + + return `${displayText}`; + } +}; + +// Custom column renderers +const columnRenderers = { + modelView: function(value, row, col) { + const container = document.createElement('div'); + container.className = 'modelview-cell'; + + renderModelView(container, row.clusterHash, clusterData, modelInfo, CONFIG.visualization.colormap, CONFIG.visualization.modelViewCellSizeTable); + + return container; + }, + + modulesSummary: function(value, row, col) { + const modules = row.modules; + const container = document.createElement('div'); + container.className = 'module-summary'; + + if (modules.length === 1) { + const parts = modules[0].split('.'); + container.textContent = parts.length > 2 ? parts.slice(-2).join('.') : modules[0]; + } else if (modules.length <= 3) { + container.textContent = modules.map(m => { + const parts = m.split('.'); + return parts.length > 2 ? 
parts.slice(-2).join('.') : m; + }).join(', '); + } else { + container.textContent = `${modules.length} modules`; + } + + container.title = modules.join('\n'); + return container; + }, + + activationHistogram: function(value, row, col) { + const histData = row.stats.all_activations; + if (!histData) { + return 'No data'; + } + + const container = document.createElement('div'); + container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = row.stats.min_activation; + const max = row.stats.max_activation; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: '#4169E1', + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const mean = row.stats.mean_activation; + const median = calculateHistogramMedian(histData); + const n = row.stats.n_tokens; + + container.title = `All Activations Histogram (n=${n})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }, + + maxActivationDistribution: function(value, row, col) { + const histData = row.stats['max_activation-max-16']; + if (!histData) { + return 'No data'; + } + + const container = document.createElement('div'); + container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = histData.bin_edges[0]; + const max = 
histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: '#DC143C', + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const n = row.stats.n_samples; + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + + container.title = `Max Activation Distribution (n=${n} samples)\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }, + + clusterLink: function(value, row, col) { + return `View →`; + }, + + explanation: function(value, row, col) { + if (!value) { + return ''; + } + // Truncate long explanations + const maxLength = 60; + if (value.length > maxLength) { + const truncated = value.substring(0, maxLength) + '...'; + const span = document.createElement('span'); + span.textContent = truncated; + span.title = value; // Show full text on hover + return span; + } + return value; + }, + + tokenEntropy: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if (!tokenStats) { + return 'N/A'; + } + return tokenStats.entropy.toFixed(2); + }, + + tokenConcentration: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if (!tokenStats) { + return 'N/A'; + } + return (tokenStats.concentration_ratio * 100).toFixed(1) + '%'; + }, + + topToken: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if 
(!tokenStats || !tokenStats.top_tokens || tokenStats.top_tokens.length === 0) { + return 'N/A'; + } + + const container = document.createElement('div'); + container.style.fontFamily = 'monospace'; + container.style.fontSize = '11px'; + container.style.lineHeight = '1.4'; + + const topN = Math.min(5, tokenStats.top_tokens.length); + const maxPercentage = tokenStats.top_tokens.length > 0 + ? ((tokenStats.top_tokens[0].count / tokenStats.total_activations) * 100) + : 0; + + for (let i = 0; i < topN; i++) { + const token = tokenStats.top_tokens[i]; + const tokenDisplay = token.token.replace(/ /g, '·').replace(/\n/g, '↵'); + const percentageValue = ((token.count / tokenStats.total_activations) * 100); + const percentage = percentageValue.toFixed(1); + + // Color based on percentage (normalized by max percentage) + const normalizedPct = maxPercentage > 0 ? percentageValue / maxPercentage : 0; + const intensity = Math.floor((1 - normalizedPct) * 255); + const bgColor = `rgb(255, ${intensity}, ${intensity})`; + + const line = document.createElement('div'); + line.style.display = 'flex'; + line.style.justifyContent = 'space-between'; + line.style.gap = '8px'; + + const tokenSpan = document.createElement('span'); + tokenSpan.innerHTML = `${tokenDisplay}`; + tokenSpan.style.textAlign = 'left'; + + const pctSpan = document.createElement('span'); + pctSpan.textContent = `${percentage}%`; + pctSpan.style.textAlign = 'right'; + pctSpan.style.backgroundColor = bgColor; + pctSpan.style.padding = '2px 4px'; + pctSpan.style.borderRadius = '2px'; + + line.appendChild(tokenSpan); + line.appendChild(pctSpan); + container.appendChild(line); + } + + return container; + }, + + // Generic histogram renderer for any BinnedData stat + genericHistogram: function(statKey, color, title) { + return function(value, row, col) { + const histData = row.stats[statKey]; + if (!histData || !histData.bin_counts) { + return 'No data'; + } + + const container = document.createElement('div'); + 
container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + // Calculate statistics of underlying data + const min = histData.bin_edges[0]; + const max = histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: color, + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + const totalCount = histData.bin_counts.reduce((a, b) => a + b, 0); + + container.title = `${title} (n=${totalCount})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }; + } +}; + +// ============================================================================ +// Helper Functions for Filtering and Sorting +// ============================================================================ + +/** + * Create a filter function for module arrays that supports wildcards, multiple patterns, and negation + * @param {string} filterValue - The filter pattern (supports * wildcards, , for OR, & for AND, @ for all-match, ! 
for negation) + * @returns {Function|null} Filter function or null if invalid + */ +function createModuleFilter(filterValue) { + if (!filterValue || !filterValue.trim()) return null; + + // Split by comma for OR groups + const orGroups = filterValue.split(',').map(g => g.trim()).filter(g => g); + + // Parse each OR group (which may contain & for AND) + const parsedOrGroups = orGroups.map(group => { + // Split by & for AND conditions within this OR group + const andConditions = group.split('&').map(c => c.trim()).filter(c => c); + + return andConditions.map(condition => { + let mode = 'some'; // default: at least one module matches + let negate = false; + let pattern = condition.toLowerCase(); + + // Check for @ prefix (all modules must match) + if (pattern.startsWith('@')) { + mode = 'every'; + pattern = pattern.substring(1); + } + // Check for ! prefix (no modules can match) + else if (pattern.startsWith('!')) { + negate = true; + pattern = pattern.substring(1); + } + + const regex = pattern.includes('*') + ? new RegExp('^' + pattern.replace(/\*/g, '.*') + '$') + : null; + + return { mode, negate, pattern, regex }; + }); + }); + + return (cellValue) => { + // cellValue is the modules array + if (!Array.isArray(cellValue)) return false; + + // OR logic across groups + return parsedOrGroups.some(andGroup => { + // AND logic within group + return andGroup.every(condition => { + const matchFn = (module) => { + const moduleLower = module.toLowerCase(); + return condition.regex + ? condition.regex.test(moduleLower) + : moduleLower.includes(condition.pattern); + }; + + if (condition.mode === 'every') { + // ALL modules must match + const result = cellValue.every(matchFn); + return condition.negate ? !result : result; + } else { + // At least ONE module must match (or none if negated) + const result = cellValue.some(matchFn); + return condition.negate ? 
!result : result; + } + }); + }); + }; +} + +/** + * Sort function for module arrays + * Primary: number of modules (ascending) + * Secondary: alphabetically by first module name + * @param {Array} modules - Array of module names + * @returns {string} Sortable string representation + */ +function sortModules(modules) { + if (!Array.isArray(modules) || modules.length === 0) return ''; + + // Pad module count for proper numeric sorting, then add first module name + const count = modules.length.toString().padStart(5, '0'); + const firstName = modules[0].toLowerCase(); + return `${count}_${firstName}`; +} + +/** + * Parse extended histogram filter syntax (e.g., "mean>0.5", "max<10", "mean>0.5, max<10") + * @param {string} filterValue - The filter string (can be comma-separated for multiple conditions) + * @returns {Array|null} Array of parsed filters [{ statType, operator, value }] or null if plain numeric + */ +function parseHistogramFilter(filterValue) { + const trimmed = filterValue.trim(); + if (!trimmed) return null; + + // Split by comma to support multiple conditions + const conditions = trimmed.split(',').map(c => c.trim()); + const parsedConditions = []; + + for (const condition of conditions) { + // Match pattern: statType operator value (e.g., "mean>0.5", "median<=0.2") + const match = condition.match(/^(mean|median|max|min|range|sum)\s*(==|!=|>=|<=|>|<)\s*(-?\d+\.?\d*)$/i); + + if (match) { + parsedConditions.push({ + statType: match[1].toLowerCase(), + operator: match[2], + value: parseFloat(match[3]) + }); + } else { + // If any condition doesn't match, return null to use default filter + return null; + } + } + + // Return array of conditions, or null if none were found + return parsedConditions.length > 0 ? 
parsedConditions : null; +} + +/** + * Create a filter function for histogram columns with extended syntax + * Supports multiple comma-separated conditions (AND logic) + * @param {string} statKey - The statistics key + * @param {string} filterValue - The filter string (e.g., "mean>0.5, max<10") + * @returns {Function|null} Filter function or null to use default + */ +function createHistogramFilter(statKey, filterValue) { + const parsedConditions = parseHistogramFilter(filterValue); + + if (!parsedConditions) { + // Return null to let default numeric filter handle it + // Default will filter on the sort value (mean by default) + return null; + } + + return (cellValue, row) => { + // All conditions must be satisfied (AND logic) + for (const condition of parsedConditions) { + const { statType, operator, value } = condition; + const histData = row.stats[statKey]; + + if (!histData || !histData.bin_counts || !histData.bin_edges) return false; + + // Calculate the requested statistic + let statValue; + switch (statType) { + case 'mean': + // For all_activations, use precomputed mean + if (statKey === 'all_activations' && row.stats.mean_activation !== undefined) { + statValue = row.stats.mean_activation; + } else { + statValue = calculateHistogramMean(histData); + } + break; + case 'median': + statValue = calculateHistogramMedian(histData); + break; + case 'max': + statValue = histData.bin_edges[histData.bin_edges.length - 1]; + break; + case 'min': + statValue = histData.bin_edges[0]; + break; + case 'range': + statValue = histData.bin_edges[histData.bin_edges.length - 1] - histData.bin_edges[0]; + break; + case 'sum': + statValue = histData.bin_counts.reduce((a, b) => a + b, 0); + break; + default: + return false; + } + + if (statValue === null || statValue === undefined) return false; + + let conditionMet = false; + switch (operator) { + case '==': conditionMet = Math.abs(statValue - value) < 0.0001; break; + case '!=': conditionMet = Math.abs(statValue - value) >= 
0.0001; break; + case '>': conditionMet = statValue > value; break; + case '<': conditionMet = statValue < value; break; + case '>=': conditionMet = statValue >= value; break; + case '<=': conditionMet = statValue <= value; break; + default: conditionMet = false; + } + + // If any condition fails, return false + if (!conditionMet) return false; + } + + // All conditions passed + return true; + }; +} + +/** + * Get the top token string for sorting + * @param {object} value - Cell value (stats object) + * @param {object} row - The data row + * @returns {string} The top token string for sorting + */ +function sortTopToken(value, row) { + const tokenStats = row.stats.token_activations; + if (!tokenStats || !tokenStats.top_tokens || tokenStats.top_tokens.length === 0) { + return ''; + } + return tokenStats.top_tokens[0].token.toLowerCase(); +} + +/** + * Create a filter function for top tokens + * @param {string} filterValue - The filter string + * @returns {Function|null} Filter function or null if invalid + */ +function createTopTokenFilter(filterValue) { + if (!filterValue || !filterValue.trim()) return null; + + const pattern = filterValue.toLowerCase().trim(); + + return (cellValue, row) => { + const tokenStats = row.stats.token_activations; + if (!tokenStats || !tokenStats.top_tokens) return false; + + // Search in top 10 tokens + const topN = Math.min(10, tokenStats.top_tokens.length); + for (let i = 0; i < topN; i++) { + const token = tokenStats.top_tokens[i].token.toLowerCase(); + if (token.includes(pattern)) { + return true; + } + } + return false; + }; +} + +/** + * Create a filter function for numeric comparisons with operators + * @param {string} filterValue - The filter string (e.g., ">2.5", "<=0.8") + * @param {Function} valueExtractor - Function to extract numeric value from cellValue + * @returns {Function|null} Filter function or null if invalid + */ +function createNumericFilter(filterValue, valueExtractor) { + if (!filterValue || !filterValue.trim()) 
return null; + + const trimmed = filterValue.trim(); + + // Match pattern: operator value (e.g., ">2.5", "<=0.8") + const match = trimmed.match(/^(==|!=|>=|<=|>|<)\s*(-?\d+\.?\d*)$/); + + if (!match) { + // Try plain number (defaults to ==) + const plainNum = parseFloat(trimmed); + if (!isNaN(plainNum)) { + return (cellValue, row) => { + const value = valueExtractor(cellValue); + if (value === null || value === undefined) return false; + return Math.abs(value - plainNum) < 0.0001; + }; + } + return null; + } + + const operator = match[1]; + const targetValue = parseFloat(match[2]); + + return (cellValue, row) => { + const value = valueExtractor(cellValue); + if (value === null || value === undefined) return false; + + switch (operator) { + case '==': return Math.abs(value - targetValue) < 0.0001; + case '!=': return Math.abs(value - targetValue) >= 0.0001; + case '>': return value > targetValue; + case '<': return value < targetValue; + case '>=': return value >= targetValue; + case '<=': return value <= targetValue; + default: return false; + } + }; +} + +function processClusterData() { + const tableData = []; + + for (const [clusterHash, cluster] of Object.entries(clusterData)) { + const modules = new Set(); + cluster.components.forEach(comp => { + modules.add(comp.module); + }); + + const stats = cluster.stats; + + // Extract cluster ID from hash (format: "runid-iteration-clusteridx") + const parts = clusterHash.split('-'); + const clusterId = parseInt(parts[parts.length - 1]); + + // Get explanation for this cluster + const explanationData = explanations[clusterHash]; + const explanation = explanationData ? 
explanationData.explanation : null; + + tableData.push({ + id: clusterId, + clusterHash: clusterHash, + componentCount: cluster.components.length, + modules: Array.from(modules), + stats: stats, + explanation: explanation + }); + } + + return tableData; +} + +async function loadData() { + // Load cluster data (model info is handled by Alpine.js) + const clusters = await loadJSONL(CONFIG.getDataPath('clusters'), 'cluster_hash'); + + clusterData = clusters; + + // Load explanations (non-critical, don't fail if missing) + explanations = await loadJSONL(CONFIG.getDataPath('explanations'), 'cluster_id').catch(() => ({})); + + const tableData = processClusterData(); + + // Discover histogram stats from first cluster + const firstCluster = Object.values(clusterData)[0]; + const histogramStats = []; + if (firstCluster && firstCluster.stats) { + for (const [key, value] of Object.entries(firstCluster.stats)) { + if (value && typeof value === 'object' && 'bin_counts' in value && 'bin_edges' in value) { + histogramStats.push(key); + } + } + } + + // Base columns + const columns = [ + { + key: 'id', + label: 'ID', + type: 'number', + width: '10px', + align: 'center' + }, + { + key: 'componentCount', + label: 'Comps', + type: 'number', + width: '10px', + align: 'right' + }, + { + key: 'modules', + label: 'Model View', + type: 'string', + width: '21px', + align: 'center', + renderer: columnRenderers.modelView, + sortFunction: (modules) => sortModules(modules), + filterFunction: (filterValue) => createModuleFilter(filterValue), + filterTooltip: 'Filter by module. Separate with , (OR) or & (AND). Use * for wildcards. Prefix @ for all-match, ! to exclude. 
Examples: *mlp*,*attn* (OR), *mlp*&*attn* (AND), @*proj* (all), !*o_proj* (exclude)' + }, + { + key: 'modules', + label: 'Modules', + type: 'string', + width: '10px', + renderer: columnRenderers.modulesSummary, + sortFunction: (modules) => sortModules(modules), + filterFunction: (filterValue) => createModuleFilter(filterValue), + filterTooltip: 'Filter by module. Separate with , (OR) or & (AND). Use * for wildcards. Prefix @ for all-match, ! to exclude. Examples: *mlp*,*attn* (OR), *mlp*&*attn* (AND), @*proj* (all), !*o_proj* (exclude)' + } + ]; + + // Add histogram columns dynamically + const statColors = { + 'all_activations': '#4169E1', + 'max_activation-max-16': '#DC143C', + 'max_activation-max-32': '#DC143C', + 'mean_activation-max-16': '#228B22', + 'median_activation-max-16': '#FF8C00', + 'min_activation-max-16': '#9370DB', + 'max_activation_position': '#FF6347' + }; + + histogramStats.forEach(statKey => { + const color = statColors[statKey] || '#808080'; + const label = statKey.replace(/-/g, ' ').replace(/_/g, ' ') + .split(' ') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' '); + + columns.push({ + id: 'histogram_' + statKey, + key: 'stats', + label: label, + type: 'number', + width: '200px', + align: 'center', + renderer: columnRenderers.genericHistogram(statKey, color, label), + sortFunction: (value, row) => { + const histData = row.stats[statKey]; + if (!histData || !histData.bin_counts || !histData.bin_edges) return -Infinity; + // For all_activations, use precomputed mean + if (statKey === 'all_activations' && row.stats.mean_activation !== undefined) { + return row.stats.mean_activation; + } + // Otherwise calculate mean from histogram + return calculateHistogramMean(histData); + }, + filterFunction: (filterValue) => createHistogramFilter(statKey, filterValue), + filterTooltip: 'Filter by statistics. Use: mean>0.5, median<0.2, max>=1.0, min>-0.1, range<5, sum>100. 
Combine with commas (e.g., mean>0.5, max<10)' + }); + }); + + // Token activation columns + columns.push({ + id: 'top_tokens', + key: 'stats', + label: 'Top Tokens', + type: 'string', + width: '150px', + align: 'left', + renderer: columnRenderers.topToken, + sortFunction: (value, row) => sortTopToken(value, row), + filterFunction: (filterValue) => createTopTokenFilter(filterValue), + filterTooltip: 'Search for tokens (case-insensitive substring match)' + }); + + columns.push({ + id: 'token_entropy', + key: 'stats', + label: 'Token Entropy', + type: 'number', + width: '60px', + align: 'right', + renderer: columnRenderers.tokenEntropy, + sortFunction: (value, row) => { + const tokenStats = row.stats.token_activations; + return tokenStats ? tokenStats.entropy : -Infinity; + }, + filterFunction: (filterValue) => createNumericFilter(filterValue, (stats) => { + const tokenStats = stats?.token_activations; + return tokenStats ? tokenStats.entropy : null; + }), + filterTooltip: 'Filter by entropy. Use operators: >, <, >=, <=, ==, != (e.g., >2.5)' + }); + + columns.push({ + id: 'token_concentration', + key: 'stats', + label: 'Token Conc.', + type: 'number', + width: '60px', + align: 'right', + renderer: columnRenderers.tokenConcentration, + sortFunction: (value, row) => { + const tokenStats = row.stats.token_activations; + return tokenStats ? tokenStats.concentration_ratio : -Infinity; + }, + filterFunction: (filterValue) => createNumericFilter(filterValue, (stats) => { + const tokenStats = stats?.token_activations; + return tokenStats ? tokenStats.concentration_ratio : null; + }), + filterTooltip: 'Filter by concentration (0-1). 
Use operators: >, <, >=, <=, ==, != (e.g., >0.5)' + }); + + // Explanation column + columns.push({ + key: 'explanation', + label: 'Explanation', + type: 'string', + width: '200px', + align: 'left', + renderer: columnRenderers.explanation, + filterTooltip: 'Filter by explanation text (case-insensitive substring match)' + }); + + // Actions column + columns.push({ + key: 'id', + label: 'Actions', + type: 'string', + width: '20px', + align: 'center', + renderer: columnRenderers.clusterLink, + filterable: false + }); + + const tableConfig = { + data: tableData, + columns: columns, + pageSize: CONFIG.indexPage.pageSize, + pageSizeOptions: CONFIG.indexPage.pageSizeOptions, + showFilters: CONFIG.indexPage.showFilters + }; + + dataTable = new DataTable('#clusterTableContainer', tableConfig); + + const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + loading.style.display = 'none'; +} + +document.addEventListener('DOMContentLoaded', async () => { + await initConfig(); + + // Check if Alpine.js loaded + if (typeof Alpine === 'undefined') { + const msg = 'Fatal error: Alpine.js failed to load. 
Check your internet connection or CDN.'; + NOTIF.error(msg, null, null); + console.error(msg); + } else { + // Manually trigger Alpine component's loadData now that CONFIG is ready + const modelInfoEl = document.getElementById('modelInfo'); + if (modelInfoEl && Alpine.$data(modelInfoEl)) { + Alpine.$data(modelInfoEl).loadData(); + } + } + + // Load cluster data and render table + loadData(); +}); diff --git a/spd/clustering/ci_dt/js/model-visualization.js b/spd/clustering/ci_dt/js/model-visualization.js new file mode 100644 index 000000000..f42e55922 --- /dev/null +++ b/spd/clustering/ci_dt/js/model-visualization.js @@ -0,0 +1,222 @@ +// Self-contained utilities for model visualization +// No global variables, all functions take necessary data as parameters + +function getClusterModuleStats(clusterId, clusterData) { + if (!clusterData || !clusterData[clusterId]) return {}; + + const cluster = clusterData[clusterId]; + const moduleStats = {}; + + // Count components per module for this specific cluster + cluster.components.forEach(comp => { + const module = comp.module; + if (!moduleStats[module]) { + moduleStats[module] = { + componentCount: 0, + components: [] + }; + } + moduleStats[module].componentCount++; + moduleStats[module].components.push(comp); + }); + + return moduleStats; +} + +function getModuleOrder(moduleName) { + if (moduleName.includes('q_proj')) return 0; + if (moduleName.includes('k_proj')) return 1; + if (moduleName.includes('v_proj')) return 2; + if (moduleName.includes('o_proj')) return 3; + if (moduleName.includes('gate_proj')) return 10; + if (moduleName.includes('up_proj')) return 11; + if (moduleName.includes('down_proj')) return 12; + return 999; +} + +function renderModelArchitecture(clusterId, clusterData, modelInfo, colormap = 'blues') { + if (!modelInfo || !modelInfo.module_list) { + throw new Error('Model info not loaded'); + } + + const moduleStats = clusterData && clusterData[clusterId] ? 
getClusterModuleStats(clusterId, clusterData) : {}; + const maxComponents = Math.max(...Object.values(moduleStats).map(s => s.componentCount), 1); + + // Group ALL modules from model_info by layer and type + const layerGroups = {}; + + modelInfo.module_list.forEach(moduleName => { + const parts = moduleName.split('.'); + let layerNum = -1; + let moduleType = 'other'; + + for (let i = 0; i < parts.length; i++) { + if (parts[i] === 'layers' && i + 1 < parts.length) { + layerNum = parseInt(parts[i + 1]); + } + } + + if (moduleName.includes('self_attn')) { + moduleType = 'attention'; + } else if (moduleName.includes('mlp')) { + moduleType = 'mlp'; + } + + if (!layerGroups[layerNum]) { + layerGroups[layerNum] = { attention: [], mlp: [], other: [] }; + } + + const count = moduleStats[moduleName] ? moduleStats[moduleName].componentCount : 0; + const components = moduleStats[moduleName] ? moduleStats[moduleName].components : []; + + layerGroups[layerNum][moduleType].push({ + name: moduleName, + count: count, + components: components + }); + }); + + // Sort modules within each group by desired order + Object.values(layerGroups).forEach(layer => { + layer.attention.sort((a, b) => getModuleOrder(a.name) - getModuleOrder(b.name)); + layer.mlp.sort((a, b) => getModuleOrder(a.name) - getModuleOrder(b.name)); + }); + + const sortedLayers = Object.keys(layerGroups).sort((a, b) => a - b); + const cellSize = 12; + + const moduleElements = []; + + sortedLayers.forEach(layerNum => { + const layer = layerGroups[layerNum]; + const layerElements = []; + + // Attention row (above MLP) + if (layer.attention.length > 0) { + const attentionRow = layer.attention.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: attentionRow }); + } + + // MLP row (below attention) + if 
(layer.mlp.length > 0) { + const mlpRow = layer.mlp.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: mlpRow }); + } + + // Other modules + if (layer.other.length > 0) { + const otherRow = layer.other.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: otherRow }); + } + + if (layerElements.length > 0) { + moduleElements.push({ type: 'layer', rows: layerElements }); + } + }); + + return { + elements: moduleElements, + maxComponents: maxComponents + }; +} + +function renderToHTML(architecture) { + let html = ''; + + architecture.elements.forEach(layer => { + html += '
'; + layer.rows.forEach(row => { + html += '
'; + row.cells.forEach(cell => { + html += `
`; + }); + html += '
'; + }); + html += '
'; + }); + + return html; +} + +// Consolidated tooltip setup - works for all model visualizations +function setupTooltips(containerElement) { + const tooltip = document.getElementById('tooltip'); + if (!tooltip) return; + + const cells = containerElement.querySelectorAll('.modelview-module-cell'); + + cells.forEach(cell => { + cell.addEventListener('mouseenter', (e) => { + const module = e.target.dataset.module; + const count = e.target.dataset.count; + const components = e.target.dataset.components; + + if (module) { + tooltip.textContent = `${module}\nComponents: ${count}\nIndices: ${components || 'none'}`; + tooltip.style.display = 'block'; + tooltip.style.left = (e.pageX + 10) + 'px'; + tooltip.style.top = (e.pageY + 10) + 'px'; + } + }); + + cell.addEventListener('mouseleave', () => { + tooltip.style.display = 'none'; + }); + + cell.addEventListener('mousemove', (e) => { + tooltip.style.left = (e.pageX + 10) + 'px'; + tooltip.style.top = (e.pageY + 10) + 'px'; + }); + }); +} + +// Consolidated render function - creates model visualization in a container +function renderModelView(containerElement, clusterHash, clusterData, modelInfo, colormap = 'blues', cellSize = null) { + if (!modelInfo || !modelInfo.module_list) { + containerElement.innerHTML = 'Model info loading...'; + return; + } + + if (!clusterData || !clusterData[clusterHash]) { + containerElement.innerHTML = 'Cluster data missing'; + return; + } + + try { + const architecture = renderModelArchitecture(clusterHash, clusterData, modelInfo, colormap); + const html = renderToHTML(architecture); + containerElement.innerHTML = html; + + // Apply cell size from config if provided + if (cellSize !== null) { + containerElement.style.setProperty('--modelview-cell-size', cellSize + 'px'); + } + + // Setup tooltips after a brief delay to ensure DOM is ready + setTimeout(() => setupTooltips(containerElement), 0); + } catch (error) { + console.error('Failed to render model visualization:', error); + 
containerElement.innerHTML = 'Model visualization error'; + } +} \ No newline at end of file diff --git a/spd/clustering/ci_dt/js/pkg/jszip.js b/spd/clustering/ci_dt/js/pkg/jszip.js new file mode 100644 index 000000000..60fbb41a6 --- /dev/null +++ b/spd/clustering/ci_dt/js/pkg/jszip.js @@ -0,0 +1,11577 @@ +/*! + +JSZip v3.10.1 - A JavaScript class for generating and reading zip files + + +(c) 2009-2016 Stuart Knightley +Dual licenced under the MIT license or GPLv3. See https://raw.github.com/Stuk/jszip/main/LICENSE.markdown. + +JSZip uses the library pako released under the MIT license : +https://github.com/nodeca/pako/blob/main/LICENSE +*/ + +(function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.JSZip = f()}})(function(){var define,module,exports;return (function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o> 2; + enc2 = ((chr1 & 3) << 4) | (chr2 >> 4); + enc3 = remainingBytes > 1 ? (((chr2 & 15) << 2) | (chr3 >> 6)) : 64; + enc4 = remainingBytes > 2 ? 
(chr3 & 63) : 64; + + output.push(_keyStr.charAt(enc1) + _keyStr.charAt(enc2) + _keyStr.charAt(enc3) + _keyStr.charAt(enc4)); + + } + + return output.join(""); +}; + +// public method for decoding +exports.decode = function(input) { + var chr1, chr2, chr3; + var enc1, enc2, enc3, enc4; + var i = 0, resultIndex = 0; + + var dataUrlPrefix = "data:"; + + if (input.substr(0, dataUrlPrefix.length) === dataUrlPrefix) { + // This is a common error: people give a data url + // (data:image/png;base64,iVBOR...) with a {base64: true} and + // wonders why things don't work. + // We can detect that the string input looks like a data url but we + // *can't* be sure it is one: removing everything up to the comma would + // be too dangerous. + throw new Error("Invalid base64 input, it looks like a data url."); + } + + input = input.replace(/[^A-Za-z0-9+/=]/g, ""); + + var totalLength = input.length * 3 / 4; + if(input.charAt(input.length - 1) === _keyStr.charAt(64)) { + totalLength--; + } + if(input.charAt(input.length - 2) === _keyStr.charAt(64)) { + totalLength--; + } + if (totalLength % 1 !== 0) { + // totalLength is not an integer, the length does not match a valid + // base64 content. 
That can happen if: + // - the input is not a base64 content + // - the input is *almost* a base64 content, with a extra chars at the + // beginning or at the end + // - the input uses a base64 variant (base64url for example) + throw new Error("Invalid base64 input, bad content length."); + } + var output; + if (support.uint8array) { + output = new Uint8Array(totalLength|0); + } else { + output = new Array(totalLength|0); + } + + while (i < input.length) { + + enc1 = _keyStr.indexOf(input.charAt(i++)); + enc2 = _keyStr.indexOf(input.charAt(i++)); + enc3 = _keyStr.indexOf(input.charAt(i++)); + enc4 = _keyStr.indexOf(input.charAt(i++)); + + chr1 = (enc1 << 2) | (enc2 >> 4); + chr2 = ((enc2 & 15) << 4) | (enc3 >> 2); + chr3 = ((enc3 & 3) << 6) | enc4; + + output[resultIndex++] = chr1; + + if (enc3 !== 64) { + output[resultIndex++] = chr2; + } + if (enc4 !== 64) { + output[resultIndex++] = chr3; + } + + } + + return output; +}; + +},{"./support":30,"./utils":32}],2:[function(require,module,exports){ +"use strict"; + +var external = require("./external"); +var DataWorker = require("./stream/DataWorker"); +var Crc32Probe = require("./stream/Crc32Probe"); +var DataLengthProbe = require("./stream/DataLengthProbe"); + +/** + * Represent a compressed object, with everything needed to decompress it. + * @constructor + * @param {number} compressedSize the size of the data compressed. + * @param {number} uncompressedSize the size of the data after decompression. + * @param {number} crc32 the crc32 of the decompressed file. + * @param {object} compression the type of compression, see lib/compressions.js. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the compressed data. 
+ */ +function CompressedObject(compressedSize, uncompressedSize, crc32, compression, data) { + this.compressedSize = compressedSize; + this.uncompressedSize = uncompressedSize; + this.crc32 = crc32; + this.compression = compression; + this.compressedContent = data; +} + +CompressedObject.prototype = { + /** + * Create a worker to get the uncompressed content. + * @return {GenericWorker} the worker. + */ + getContentWorker: function () { + var worker = new DataWorker(external.Promise.resolve(this.compressedContent)) + .pipe(this.compression.uncompressWorker()) + .pipe(new DataLengthProbe("data_length")); + + var that = this; + worker.on("end", function () { + if (this.streamInfo["data_length"] !== that.uncompressedSize) { + throw new Error("Bug : uncompressed data size mismatch"); + } + }); + return worker; + }, + /** + * Create a worker to get the compressed content. + * @return {GenericWorker} the worker. + */ + getCompressedWorker: function () { + return new DataWorker(external.Promise.resolve(this.compressedContent)) + .withStreamInfo("compressedSize", this.compressedSize) + .withStreamInfo("uncompressedSize", this.uncompressedSize) + .withStreamInfo("crc32", this.crc32) + .withStreamInfo("compression", this.compression) + ; + } +}; + +/** + * Chain the given worker with other workers to compress the content with the + * given compression. + * @param {GenericWorker} uncompressedWorker the worker to pipe. + * @param {Object} compression the compression object. + * @param {Object} compressionOptions the options to use when compressing. + * @return {GenericWorker} the new worker compressing the content. 
+ */ +CompressedObject.createWorkerFrom = function (uncompressedWorker, compression, compressionOptions) { + return uncompressedWorker + .pipe(new Crc32Probe()) + .pipe(new DataLengthProbe("uncompressedSize")) + .pipe(compression.compressWorker(compressionOptions)) + .pipe(new DataLengthProbe("compressedSize")) + .withStreamInfo("compression", compression); +}; + +module.exports = CompressedObject; + +},{"./external":6,"./stream/Crc32Probe":25,"./stream/DataLengthProbe":26,"./stream/DataWorker":27}],3:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./stream/GenericWorker"); + +exports.STORE = { + magic: "\x00\x00", + compressWorker : function () { + return new GenericWorker("STORE compression"); + }, + uncompressWorker : function () { + return new GenericWorker("STORE decompression"); + } +}; +exports.DEFLATE = require("./flate"); + +},{"./flate":7,"./stream/GenericWorker":28}],4:[function(require,module,exports){ +"use strict"; + +var utils = require("./utils"); + +/** + * The following functions come from pako, from pako/lib/zlib/crc32.js + * released under the MIT license, see pako https://github.com/nodeca/pako/ + */ + +// Use ordinary array, since untyped makes no boost here +function makeTable() { + var c, table = []; + + for(var n =0; n < 256; n++){ + c = n; + for(var k =0; k < 8; k++){ + c = ((c&1) ? (0xEDB88320 ^ (c >>> 1)) : (c >>> 1)); + } + table[n] = c; + } + + return table; +} + +// Create table on load. Just 255 signed longs. Not a problem. +var crcTable = makeTable(); + + +function crc32(crc, buf, len, pos) { + var t = crcTable, end = pos + len; + + crc = crc ^ (-1); + + for (var i = pos; i < end; i++ ) { + crc = (crc >>> 8) ^ t[(crc ^ buf[i]) & 0xFF]; + } + + return (crc ^ (-1)); // >>> 0; +} + +// That's all for the pako functions. + +/** + * Compute the crc32 of a string. + * This is almost the same as the function crc32, but for strings. 
Using the + * same function for the two use cases leads to horrible performances. + * @param {Number} crc the starting value of the crc. + * @param {String} str the string to use. + * @param {Number} len the length of the string. + * @param {Number} pos the starting position for the crc32 computation. + * @return {Number} the computed crc32. + */ +function crc32str(crc, str, len, pos) { + var t = crcTable, end = pos + len; + + crc = crc ^ (-1); + + for (var i = pos; i < end; i++ ) { + crc = (crc >>> 8) ^ t[(crc ^ str.charCodeAt(i)) & 0xFF]; + } + + return (crc ^ (-1)); // >>> 0; +} + +module.exports = function crc32wrapper(input, crc) { + if (typeof input === "undefined" || !input.length) { + return 0; + } + + var isArray = utils.getTypeOf(input) !== "string"; + + if(isArray) { + return crc32(crc|0, input, input.length, 0); + } else { + return crc32str(crc|0, input, input.length, 0); + } +}; + +},{"./utils":32}],5:[function(require,module,exports){ +"use strict"; +exports.base64 = false; +exports.binary = false; +exports.dir = false; +exports.createFolders = true; +exports.date = null; +exports.compression = null; +exports.compressionOptions = null; +exports.comment = null; +exports.unixPermissions = null; +exports.dosPermissions = null; + +},{}],6:[function(require,module,exports){ +"use strict"; + +// load the global object first: +// - it should be better integrated in the system (unhandledRejection in node) +// - the environment may have a custom Promise implementation (see zone.js) +var ES6Promise = null; +if (typeof Promise !== "undefined") { + ES6Promise = Promise; +} else { + ES6Promise = require("lie"); +} + +/** + * Let the user use/change some implementations. 
 */
module.exports = {
    Promise: ES6Promise
};

},{"lie":37}],7:[function(require,module,exports){
"use strict";
// Typed arrays are required by pako's fast path; fall back to plain arrays.
var USE_TYPEDARRAY = (typeof Uint8Array !== "undefined") && (typeof Uint16Array !== "undefined") && (typeof Uint32Array !== "undefined");

var pako = require("pako");
var utils = require("./utils");
var GenericWorker = require("./stream/GenericWorker");

var ARRAY_TYPE = USE_TYPEDARRAY ? "uint8array" : "array";

// 2-byte compression method field for DEFLATE in the zip headers.
exports.magic = "\x08\x00";

/**
 * Create a worker that uses pako to inflate/deflate.
 * @constructor
 * @param {String} action the name of the pako function to call : either "Deflate" or "Inflate".
 * @param {Object} options the options to use when (de)compressing.
 */
function FlateWorker(action, options) {
    GenericWorker.call(this, "FlateWorker/" + action);

    // Created lazily on the first chunk, see _createPako below.
    this._pako = null;
    this._pakoAction = action;
    this._pakoOptions = options;
    // the `meta` object from the last chunk received
    // this allow this worker to pass around metadata
    this.meta = {};
}

utils.inherits(FlateWorker, GenericWorker);

/**
 * @see GenericWorker.processChunk
 */
FlateWorker.prototype.processChunk = function (chunk) {
    this.meta = chunk.meta;
    if (this._pako === null) {
        this._createPako();
    }
    // `false` = not the last chunk; pako buffers until flush().
    this._pako.push(utils.transformTo(ARRAY_TYPE, chunk.data), false);
};

/**
 * @see GenericWorker.flush
 */
FlateWorker.prototype.flush = function () {
    GenericWorker.prototype.flush.call(this);
    // An empty stream still needs a pako instance to emit a valid
    // (de)compressed empty payload.
    if (this._pako === null) {
        this._createPako();
    }
    // Empty final push tells pako to finalize the stream.
    this._pako.push([], true);
};
/**
 * @see GenericWorker.cleanUp
 */
FlateWorker.prototype.cleanUp = function () {
    GenericWorker.prototype.cleanUp.call(this);
    this._pako = null;
};

/**
 * Create the _pako object.
 * TODO: lazy-loading this object isn't the best solution but it's the
 * quickest. The best solution is to lazy-load the worker list. See also the
 * issue #446.
 */
FlateWorker.prototype._createPako = function () {
    this._pako = new pako[this._pakoAction]({
        raw: true, // raw deflate stream: zip stores no zlib header/trailer
        level: this._pakoOptions.level || -1 // default compression
    });
    var self = this;
    // Forward each block produced by pako downstream, re-attaching the
    // metadata of the chunk currently being processed.
    this._pako.onData = function(data) {
        self.push({
            data : data,
            meta : self.meta
        });
    };
};

exports.compressWorker = function (compressionOptions) {
    return new FlateWorker("Deflate", compressionOptions);
};
exports.uncompressWorker = function () {
    return new FlateWorker("Inflate", {});
};

},{"./stream/GenericWorker":28,"./utils":32,"pako":38}],8:[function(require,module,exports){
"use strict";

var utils = require("../utils");
var GenericWorker = require("../stream/GenericWorker");
var utf8 = require("../utf8");
var crc32 = require("../crc32");
var signature = require("../signature");

/**
 * Serialize an integer as a little-endian binary string.
 * (Despite the historical name, the output is raw bytes encoded as a
 * string — one char per byte — not a hexadecimal text representation.)
 * @private
 * @param {number} dec the number to convert.
 * @param {number} bytes the number of bytes to generate.
 * @returns {string} the result.
 */
var decToHex = function(dec, bytes) {
    var hex = "", i;
    for (i = 0; i < bytes; i++) {
        hex += String.fromCharCode(dec & 0xff);
        dec = dec >>> 8;
    }
    return hex;
};

/**
 * Generate the UNIX part of the external file attributes.
 * @param {Object} unixPermissions the unix permissions or null.
 * @param {Boolean} isDir true if the entry is a directory, false otherwise.
 * @return {Number} a 32 bit integer.
 *
 * adapted from http://unix.stackexchange.com/questions/14705/the-zip-formats-external-file-attribute :
 *
 * TTTTsstrwxrwxrwx0000000000ADVSHR
 * ^^^^____________________________ file type, see zipinfo.c (UNX_*)
 *     ^^^_________________________ setuid, setgid, sticky
 *        ^^^^^^^^^________________ permissions
 *                 ^^^^^^^^^^______ not used ?
+ * ^^^^^^ DOS attribute bits : Archive, Directory, Volume label, System file, Hidden, Read only + */ +var generateUnixExternalFileAttr = function (unixPermissions, isDir) { + + var result = unixPermissions; + if (!unixPermissions) { + // I can't use octal values in strict mode, hence the hexa. + // 040775 => 0x41fd + // 0100664 => 0x81b4 + result = isDir ? 0x41fd : 0x81b4; + } + return (result & 0xFFFF) << 16; +}; + +/** + * Generate the DOS part of the external file attributes. + * @param {Object} dosPermissions the dos permissions or null. + * @param {Boolean} isDir true if the entry is a directory, false otherwise. + * @return {Number} a 32 bit integer. + * + * Bit 0 Read-Only + * Bit 1 Hidden + * Bit 2 System + * Bit 3 Volume Label + * Bit 4 Directory + * Bit 5 Archive + */ +var generateDosExternalFileAttr = function (dosPermissions) { + // the dir flag is already set for compatibility + return (dosPermissions || 0) & 0x3F; +}; + +/** + * Generate the various parts used in the construction of the final zip file. + * @param {Object} streamInfo the hash with information about the compressed file. + * @param {Boolean} streamedContent is the content streamed ? + * @param {Boolean} streamingEnded is the stream finished ? + * @param {number} offset the current offset from the start of the zip file. + * @param {String} platform let's pretend we are this platform (change platform dependents fields) + * @param {Function} encodeFileName the function to encode the file name / comment. + * @return {Object} the zip parts. 
 */
var generateZipParts = function(streamInfo, streamedContent, streamingEnded, offset, platform, encodeFileName) {
    var file = streamInfo["file"],
        compression = streamInfo["compression"],
        useCustomEncoding = encodeFileName !== utf8.utf8encode,
        encodedFileName = utils.transformTo("string", encodeFileName(file.name)),
        utfEncodedFileName = utils.transformTo("string", utf8.utf8encode(file.name)),
        comment = file.comment,
        encodedComment = utils.transformTo("string", encodeFileName(comment)),
        utfEncodedComment = utils.transformTo("string", utf8.utf8encode(comment)),
        // A length difference means the name/comment contains non-ASCII
        // chars and must be flagged as UTF-8 encoded.
        useUTF8ForFileName = utfEncodedFileName.length !== file.name.length,
        useUTF8ForComment = utfEncodedComment.length !== comment.length,
        dosTime,
        dosDate,
        extraFields = "",
        unicodePathExtraField = "",
        unicodeCommentExtraField = "",
        dir = file.dir,
        date = file.date;


    var dataInfo = {
        crc32 : 0,
        compressedSize : 0,
        uncompressedSize : 0
    };

    // if the content is streamed, the sizes/crc32 are only available AFTER
    // the end of the stream.
    if (!streamedContent || streamingEnded) {
        dataInfo.crc32 = streamInfo["crc32"];
        dataInfo.compressedSize = streamInfo["compressedSize"];
        dataInfo.uncompressedSize = streamInfo["uncompressedSize"];
    }

    var bitflag = 0;
    if (streamedContent) {
        // Bit 3: the sizes/crc32 are set to zero in the local header.
        // The correct values are put in the data descriptor immediately
        // following the compressed data.
        bitflag |= 0x0008;
    }
    if (!useCustomEncoding && (useUTF8ForFileName || useUTF8ForComment)) {
        // Bit 11: Language encoding flag (EFS).
        bitflag |= 0x0800;
    }


    var extFileAttr = 0;
    var versionMadeBy = 0;
    if (dir) {
        // dos or unix, we set the dos dir flag
        extFileAttr |= 0x00010;
    }
    if(platform === "UNIX") {
        versionMadeBy = 0x031E; // UNIX, version 3.0
        extFileAttr |= generateUnixExternalFileAttr(file.unixPermissions, dir);
    } else { // DOS or other, fallback to DOS
        versionMadeBy = 0x0014; // DOS, version 2.0
        extFileAttr |= generateDosExternalFileAttr(file.dosPermissions, dir);
    }

    // date
    // @see http://www.delorie.com/djgpp/doc/rbinter/it/52/13.html
    // @see http://www.delorie.com/djgpp/doc/rbinter/it/65/16.html
    // @see http://www.delorie.com/djgpp/doc/rbinter/it/66/16.html

    // DOS time: hours(5 bits) | minutes(6 bits) | seconds/2 (5 bits)
    dosTime = date.getUTCHours();
    dosTime = dosTime << 6;
    dosTime = dosTime | date.getUTCMinutes();
    dosTime = dosTime << 5;
    dosTime = dosTime | date.getUTCSeconds() / 2;

    // DOS date: years-since-1980(7 bits) | month 1-12 (4 bits) | day(5 bits)
    dosDate = date.getUTCFullYear() - 1980;
    dosDate = dosDate << 4;
    dosDate = dosDate | (date.getUTCMonth() + 1);
    dosDate = dosDate << 5;
    dosDate = dosDate | date.getUTCDate();

    if (useUTF8ForFileName) {
        // set the unicode path extra field. unzip needs at least one extra
        // field to correctly handle unicode path, so using the path is as good
        // as any other information. This could improve the situation with
        // other archive managers too.
        // This field is usually used without the utf8 flag, with a non
        // unicode path in the header (winrar, winzip). This helps (a bit)
        // with the messy Windows' default compressed folders feature but
        // breaks on p7zip which doesn't seek the unicode path extra field.
        // So for now, UTF-8 everywhere !
        unicodePathExtraField =
            // Version
            decToHex(1, 1) +
            // NameCRC32
            decToHex(crc32(encodedFileName), 4) +
            // UnicodeName
            utfEncodedFileName;

        extraFields +=
            // Info-ZIP Unicode Path Extra Field
            "\x75\x70" +
            // size
            decToHex(unicodePathExtraField.length, 2) +
            // content
            unicodePathExtraField;
    }

    if(useUTF8ForComment) {

        unicodeCommentExtraField =
            // Version
            decToHex(1, 1) +
            // CommentCRC32
            decToHex(crc32(encodedComment), 4) +
            // UnicodeName
            utfEncodedComment;

        extraFields +=
            // Info-ZIP Unicode Path Extra Field
            "\x75\x63" +
            // size
            decToHex(unicodeCommentExtraField.length, 2) +
            // content
            unicodeCommentExtraField;
    }

    // `header` is the part shared by the local file header and the central
    // directory record, starting at "version needed to extract".
    var header = "";

    // version needed to extract
    header += "\x0A\x00";
    // general purpose bit flag
    header += decToHex(bitflag, 2);
    // compression method
    header += compression.magic;
    // last mod file time
    header += decToHex(dosTime, 2);
    // last mod file date
    header += decToHex(dosDate, 2);
    // crc-32
    header += decToHex(dataInfo.crc32, 4);
    // compressed size
    header += decToHex(dataInfo.compressedSize, 4);
    // uncompressed size
    header += decToHex(dataInfo.uncompressedSize, 4);
    // file name length
    header += decToHex(encodedFileName.length, 2);
    // extra field length
    header += decToHex(extraFields.length, 2);


    var fileRecord = signature.LOCAL_FILE_HEADER + header + encodedFileName + extraFields;

    var dirRecord = signature.CENTRAL_FILE_HEADER +
        // version made by (00: DOS)
        decToHex(versionMadeBy, 2) +
        // file header (common to file and central directory)
        header +
        // file comment length
        decToHex(encodedComment.length, 2) +
        // disk number start
        "\x00\x00" +
        // internal file attributes TODO
        "\x00\x00" +
        // external file attributes
        decToHex(extFileAttr, 4) +
        // relative offset of local header
        decToHex(offset, 4) +
        // file name
        encodedFileName +
        // extra field
        extraFields +
        // file comment
        encodedComment;

    return {
        fileRecord: fileRecord,
        dirRecord: dirRecord
    };
};

/**
 * Generate the EOCD record.
 * @param {Number} entriesCount the number of entries in the zip file.
 * @param {Number} centralDirLength the length (in bytes) of the central dir.
 * @param {Number} localDirLength the length (in bytes) of the local dir.
 * @param {String} comment the zip file comment as a binary string.
 * @param {Function} encodeFileName the function to encode the comment.
 * @return {String} the EOCD record.
 */
var generateCentralDirectoryEnd = function (entriesCount, centralDirLength, localDirLength, comment, encodeFileName) {
    var dirEnd = "";
    var encodedComment = utils.transformTo("string", encodeFileName(comment));

    // end of central dir signature
    dirEnd = signature.CENTRAL_DIRECTORY_END +
        // number of this disk
        "\x00\x00" +
        // number of the disk with the start of the central directory
        "\x00\x00" +
        // total number of entries in the central directory on this disk
        decToHex(entriesCount, 2) +
        // total number of entries in the central directory
        decToHex(entriesCount, 2) +
        // size of the central directory   4 bytes
        decToHex(centralDirLength, 4) +
        // offset of start of central directory with respect to the starting disk number
        decToHex(localDirLength, 4) +
        // .ZIP file comment length
        decToHex(encodedComment.length, 2) +
        // .ZIP file comment
        encodedComment;

    return dirEnd;
};

/**
 * Generate data descriptors for a file entry.
 * @param {Object} streamInfo the hash generated by a worker, containing information
 * on the file entry.
 * @return {String} the data descriptors.
 */
var generateDataDescriptors = function (streamInfo) {
    var descriptor = "";
    descriptor = signature.DATA_DESCRIPTOR +
        // crc-32          4 bytes
        decToHex(streamInfo["crc32"], 4) +
        // compressed size 4 bytes
        decToHex(streamInfo["compressedSize"], 4) +
        // uncompressed size 4 bytes
        decToHex(streamInfo["uncompressedSize"], 4);

    return descriptor;
};


/**
 * A worker to concatenate other workers to create a zip file.
 * @param {Boolean} streamFiles `true` to stream the content of the files,
 * `false` to accumulate it.
 * @param {String} comment the comment to use.
 * @param {String} platform the platform to use, "UNIX" or "DOS".
 * @param {Function} encodeFileName the function to encode file names and comments.
 */
function ZipFileWorker(streamFiles, comment, platform, encodeFileName) {
    GenericWorker.call(this, "ZipFileWorker");
    // The number of bytes written so far. This doesn't count accumulated chunks.
    this.bytesWritten = 0;
    // The comment of the zip file
    this.zipComment = comment;
    // The platform "generating" the zip file.
    this.zipPlatform = platform;
    // the function to encode file names and comments.
    this.encodeFileName = encodeFileName;
    // Should we stream the content of the files ?
    this.streamFiles = streamFiles;
    // If `streamFiles` is false, we will need to accumulate the content of the
    // files to calculate sizes / crc32 (and write them *before* the content).
    // This boolean indicates if we are accumulating chunks (it will change a lot
    // during the lifetime of this worker).
    this.accumulate = false;
    // The buffer receiving chunks when accumulating content.
    this.contentBuffer = [];
    // The list of generated directory records.
    this.dirRecords = [];
    // The offset (in bytes) from the beginning of the zip file for the current source.
    this.currentSourceOffset = 0;
    // The total number of entries in this zip file.
    this.entriesCount = 0;
    // the name of the file currently being added, null when handling the end of the zip file.
    // Used for the emitted metadata.
    this.currentFile = null;



    // Upstream workers queued via registerPrevious, consumed one at a time.
    this._sources = [];
}
utils.inherits(ZipFileWorker, GenericWorker);

/**
 * @see GenericWorker.push
 */
ZipFileWorker.prototype.push = function (chunk) {

    var currentFilePercent = chunk.meta.percent || 0;
    var entriesCount = this.entriesCount;
    var remainingFiles = this._sources.length;

    if(this.accumulate) {
        // Buffer the chunk until the whole file is available (non-streamed mode).
        this.contentBuffer.push(chunk);
    } else {
        this.bytesWritten += chunk.data.length;

        // Overall percent = progress of the current file scaled into the
        // slice of already-finished entries.
        GenericWorker.prototype.push.call(this, {
            data : chunk.data,
            meta : {
                currentFile : this.currentFile,
                percent : entriesCount ? (currentFilePercent + 100 * (entriesCount - remainingFiles - 1)) / entriesCount : 100
            }
        });
    }
};

/**
 * The worker started a new source (an other worker).
 * @param {Object} streamInfo the streamInfo object from the new source.
 */
ZipFileWorker.prototype.openedSource = function (streamInfo) {
    this.currentSourceOffset = this.bytesWritten;
    this.currentFile = streamInfo["file"].name;

    var streamedContent = this.streamFiles && !streamInfo["file"].dir;

    // don't stream folders (because they don't have any content)
    if(streamedContent) {
        // Streamed: the local header is written now, with zeroed sizes/crc32
        // (a data descriptor follows the content, see closedSource).
        var record = generateZipParts(streamInfo, streamedContent, false, this.currentSourceOffset, this.zipPlatform, this.encodeFileName);
        this.push({
            data : record.fileRecord,
            meta : {percent:0}
        });
    } else {
        // we need to wait for the whole file before pushing anything
        this.accumulate = true;
    }
};

/**
 * The worker finished a source (an other worker).
 * @param {Object} streamInfo the streamInfo object from the finished source.
 */
ZipFileWorker.prototype.closedSource = function (streamInfo) {
    this.accumulate = false;
    var streamedContent = this.streamFiles && !streamInfo["file"].dir;
    // Sizes/crc32 are final now (streamingEnded=true), so the central
    // directory record generated here is always complete.
    var record = generateZipParts(streamInfo, streamedContent, true, this.currentSourceOffset, this.zipPlatform, this.encodeFileName);

    this.dirRecords.push(record.dirRecord);
    if(streamedContent) {
        // after the streamed file, we put data descriptors
        this.push({
            data : generateDataDescriptors(streamInfo),
            meta : {percent:100}
        });
    } else {
        // the content wasn't streamed, we need to push everything now
        // first the file record, then the content
        this.push({
            data : record.fileRecord,
            meta : {percent:0}
        });
        while(this.contentBuffer.length) {
            this.push(this.contentBuffer.shift());
        }
    }
    this.currentFile = null;
};

/**
 * @see GenericWorker.flush
 */
ZipFileWorker.prototype.flush = function () {

    // All local records are out; everything written from here on is the
    // central directory, followed by the EOCD record.
    var localDirLength = this.bytesWritten;
    for(var i = 0; i < this.dirRecords.length; i++) {
        this.push({
            data : this.dirRecords[i],
            meta : {percent:100}
        });
    }
    var centralDirLength = this.bytesWritten - localDirLength;

    var dirEnd = generateCentralDirectoryEnd(this.dirRecords.length, centralDirLength, localDirLength, this.zipComment, this.encodeFileName);

    this.push({
        data : dirEnd,
        meta : {percent:100}
    });
};

/**
 * Prepare the next source to be read.
 */
ZipFileWorker.prototype.prepareNextSource = function () {
    this.previous = this._sources.shift();
    this.openedSource(this.previous.streamInfo);
    // Propagate our own paused/running state to the new source.
    if (this.isPaused) {
        this.previous.pause();
    } else {
        this.previous.resume();
    }
};

/**
 * @see GenericWorker.registerPrevious
 */
ZipFileWorker.prototype.registerPrevious = function (previous) {
    // Unlike a plain GenericWorker, this worker accepts several upstream
    // sources and drains them sequentially (one zip entry per source).
    this._sources.push(previous);
    var self = this;

    previous.on("data", function (chunk) {
        self.processChunk(chunk);
    });
    previous.on("end", function () {
        self.closedSource(self.previous.streamInfo);
        if(self._sources.length) {
            self.prepareNextSource();
        } else {
            self.end();
        }
    });
    previous.on("error", function (e) {
        self.error(e);
    });
    return this;
};

/**
 * @see GenericWorker.resume
 */
ZipFileWorker.prototype.resume = function () {
    if(!GenericWorker.prototype.resume.call(this)) {
        return false;
    }

    if (!this.previous && this._sources.length) {
        // Nothing started yet: kick off the first source.
        this.prepareNextSource();
        return true;
    }
    if (!this.previous && !this._sources.length && !this.generatedError) {
        // No source at all: an empty zip, finish immediately.
        this.end();
        return true;
    }
};

/**
 * @see GenericWorker.error
 */
ZipFileWorker.prototype.error = function (e) {
    var sources = this._sources;
    if(!GenericWorker.prototype.error.call(this, e)) {
        return false;
    }
    // Propagate the error to every pending source.
    for(var i = 0; i < sources.length; i++) {
        try {
            sources[i].error(e);
        } catch(e) {
            // the `error` exploded, nothing to do
        }
    }
    return true;
};

/**
 * @see GenericWorker.lock
 */
ZipFileWorker.prototype.lock = function () {
    GenericWorker.prototype.lock.call(this);
    var sources = this._sources;
    for(var i = 0; i < sources.length; i++) {
        sources[i].lock();
    }
};

module.exports = ZipFileWorker;

},{"../crc32":4,"../signature":23,"../stream/GenericWorker":28,"../utf8":31,"../utils":32}],9:[function(require,module,exports){
"use strict";

var compressions = require("../compressions");
var ZipFileWorker = require("./ZipFileWorker");

/**
 * Find the
compression to use.
 * @param {String} fileCompression the compression defined at the file level, if any.
 * @param {String} zipCompression the compression defined at the load() level.
 * @return {Object} the compression object to use.
 */
var getCompression = function (fileCompression, zipCompression) {

    // The per-file setting wins over the archive-wide setting.
    var compressionName = fileCompression || zipCompression;
    var compression = compressions[compressionName];
    if (!compression) {
        throw new Error(compressionName + " is not a valid compression method !");
    }
    return compression;
};

/**
 * Create a worker to generate a zip file.
 * @param {JSZip} zip the JSZip instance at the right root level.
 * @param {Object} options to generate the zip file.
 * @param {String} comment the comment to use.
 */
exports.generateWorker = function (zip, options, comment) {

    var zipFileWorker = new ZipFileWorker(options.streamFiles, comment, options.platform, options.encodeFileName);
    var entriesCount = 0;
    try {

        // Pipe every entry's (compressed) content into the single
        // ZipFileWorker; it will drain them one at a time.
        zip.forEach(function (relativePath, file) {
            entriesCount++;
            var compression = getCompression(file.options.compression, options.compression);
            var compressionOptions = file.options.compressionOptions || options.compressionOptions || {};
            var dir = file.dir, date = file.date;

            file._compressWorker(compression, compressionOptions)
                .withStreamInfo("file", {
                    name : relativePath,
                    dir : dir,
                    date : date,
                    comment : file.comment || "",
                    unixPermissions : file.unixPermissions,
                    dosPermissions : file.dosPermissions
                })
                .pipe(zipFileWorker);
        });
        zipFileWorker.entriesCount = entriesCount;
    } catch (e) {
        // Surface synchronous setup failures through the worker's error channel.
        zipFileWorker.error(e);
    }

    return zipFileWorker;
};

},{"../compressions":3,"./ZipFileWorker":8}],10:[function(require,module,exports){
"use strict";

/**
 * Representation a of zip file in js
 * @constructor
 */
function JSZip() {
    // if this constructor is used without `new`, it adds `new` before itself:
    if(!(this instanceof JSZip)) {
        return new JSZip();
    }

    if(arguments.length) {
        throw new Error("The constructor with parameters has been removed in JSZip 3.0, please check the upgrade guide.");
    }

    // object containing the files :
    // {
    //   "folder/" : {...},
    //   "folder/data.txt" : {...}
    // }
    // NOTE: we use a null prototype because we do not
    // want filenames like "toString" coming from a zip file
    // to overwrite methods and attributes in a normal Object.
    this.files = Object.create(null);

    this.comment = null;

    // Where we are in the hierarchy
    this.root = "";
    // Shallow-copy all non-function own/inherited members into a new instance.
    this.clone = function() {
        var newObj = new JSZip();
        for (var i in this) {
            if (typeof this[i] !== "function") {
                newObj[i] = this[i];
            }
        }
        return newObj;
    };
}
JSZip.prototype = require("./object");
JSZip.prototype.loadAsync = require("./load");
JSZip.support = require("./support");
JSZip.defaults = require("./defaults");

// TODO find a better way to handle this version,
// a require('package.json').version doesn't work with webpack, see #327
JSZip.version = "3.10.1";

// Static convenience wrapper around the instance loadAsync.
JSZip.loadAsync = function (content, options) {
    return new JSZip().loadAsync(content, options);
};

JSZip.external = require("./external");
module.exports = JSZip;

},{"./defaults":5,"./external":6,"./load":11,"./object":15,"./support":30}],11:[function(require,module,exports){
"use strict";
var utils = require("./utils");
var external = require("./external");
var utf8 = require("./utf8");
var ZipEntries = require("./zipEntries");
var Crc32Probe = require("./stream/Crc32Probe");
var nodejsUtils = require("./nodejsUtils");

/**
 * Check the CRC32 of an entry.
 * @param {ZipEntry} zipEntry the zip entry to check.
 * @return {Promise} the result.
 */
function checkEntryCRC32(zipEntry) {
    return new external.Promise(function (resolve, reject) {
        // Re-decompress the entry through a Crc32Probe and compare the
        // computed checksum with the one recorded in the zip entry.
        var worker = zipEntry.decompressed.getContentWorker().pipe(new Crc32Probe());
        worker.on("error", function (e) {
            reject(e);
        })
            .on("end", function () {
                if (worker.streamInfo.crc32 !== zipEntry.decompressed.crc32) {
                    reject(new Error("Corrupted zip : CRC32 mismatch"));
                } else {
                    resolve();
                }
            })
            .resume();
    });
}

// Implementation of JSZip#loadAsync: parse `data` as a zip archive and
// populate this JSZip instance with its entries.
module.exports = function (data, options) {
    var zip = this;
    options = utils.extend(options || {}, {
        base64: false,
        checkCRC32: false,
        optimizedBinaryString: false,
        createFolders: false,
        decodeFileName: utf8.utf8decode
    });

    if (nodejsUtils.isNode && nodejsUtils.isStream(data)) {
        return external.Promise.reject(new Error("JSZip can't accept a stream when loading a zip file."));
    }

    return utils.prepareContent("the loaded zip file", data, true, options.optimizedBinaryString, options.base64)
        .then(function (data) {
            var zipEntries = new ZipEntries(options);
            zipEntries.load(data);
            return zipEntries;
        }).then(function checkCRC32(zipEntries) {
            // Optionally verify every entry's CRC32 before exposing the archive.
            var promises = [external.Promise.resolve(zipEntries)];
            var files = zipEntries.files;
            if (options.checkCRC32) {
                for (var i = 0; i < files.length; i++) {
                    promises.push(checkEntryCRC32(files[i]));
                }
            }
            return external.Promise.all(promises);
        }).then(function addFiles(results) {
            var zipEntries = results.shift();
            var files = zipEntries.files;
            for (var i = 0; i < files.length; i++) {
                var input = files[i];

                // `resolve` normalizes the path (guards against zip-slip style
                // "../" names); the raw name is kept for reference.
                var unsafeName = input.fileNameStr;
                var safeName = utils.resolve(input.fileNameStr);

                zip.file(safeName, input.decompressed, {
                    binary: true,
                    optimizedBinaryString: true,
                    date: input.date,
                    dir: input.dir,
                    comment: input.fileCommentStr.length ? input.fileCommentStr : null,
                    unixPermissions: input.unixPermissions,
                    dosPermissions: input.dosPermissions,
                    createFolders: options.createFolders
                });
                if (!input.dir) {
                    zip.file(safeName).unsafeOriginalName = unsafeName;
                }
            }
            if (zipEntries.zipComment.length) {
                zip.comment = zipEntries.zipComment;
            }

            return zip;
        });
};

},{"./external":6,"./nodejsUtils":14,"./stream/Crc32Probe":25,"./utf8":31,"./utils":32,"./zipEntries":33}],12:[function(require,module,exports){
"use strict";

var utils = require("../utils");
var GenericWorker = require("../stream/GenericWorker");

/**
 * A worker that use a nodejs stream as source.
 * @constructor
 * @param {String} filename the name of the file entry for this stream.
 * @param {Readable} stream the nodejs stream.
 */
function NodejsStreamInputAdapter(filename, stream) {
    GenericWorker.call(this, "Nodejs stream input adapter for " + filename);
    this._upstreamEnded = false;
    this._bindStream(stream);
}

utils.inherits(NodejsStreamInputAdapter, GenericWorker);

/**
 * Prepare the stream and bind the callbacks on it.
 * Do this ASAP on node 0.10 ! A lazy binding doesn't always work.
 * @param {Stream} stream the nodejs stream to use.
+ */ +NodejsStreamInputAdapter.prototype._bindStream = function (stream) { + var self = this; + this._stream = stream; + stream.pause(); + stream + .on("data", function (chunk) { + self.push({ + data: chunk, + meta : { + percent : 0 + } + }); + }) + .on("error", function (e) { + if(self.isPaused) { + this.generatedError = e; + } else { + self.error(e); + } + }) + .on("end", function () { + if(self.isPaused) { + self._upstreamEnded = true; + } else { + self.end(); + } + }); +}; +NodejsStreamInputAdapter.prototype.pause = function () { + if(!GenericWorker.prototype.pause.call(this)) { + return false; + } + this._stream.pause(); + return true; +}; +NodejsStreamInputAdapter.prototype.resume = function () { + if(!GenericWorker.prototype.resume.call(this)) { + return false; + } + + if(this._upstreamEnded) { + this.end(); + } else { + this._stream.resume(); + } + + return true; +}; + +module.exports = NodejsStreamInputAdapter; + +},{"../stream/GenericWorker":28,"../utils":32}],13:[function(require,module,exports){ +"use strict"; + +var Readable = require("readable-stream").Readable; + +var utils = require("../utils"); +utils.inherits(NodejsStreamOutputAdapter, Readable); + +/** +* A nodejs stream using a worker as source. +* @see the SourceWrapper in http://nodejs.org/api/stream.html +* @constructor +* @param {StreamHelper} helper the helper wrapping the worker +* @param {Object} options the nodejs stream options +* @param {Function} updateCb the update callback. 
*/
function NodejsStreamOutputAdapter(helper, options, updateCb) {
    Readable.call(this, options);
    this._helper = helper;

    var self = this;
    helper.on("data", function (data, meta) {
        // Respect Readable back-pressure: pause the worker when the
        // internal buffer is full (push() returned false).
        if (!self.push(data)) {
            self._helper.pause();
        }
        if(updateCb) {
            updateCb(meta);
        }
    })
        .on("error", function(e) {
            self.emit("error", e);
        })
        .on("end", function () {
            // null signals end-of-stream to the Readable machinery.
            self.push(null);
        });
}


NodejsStreamOutputAdapter.prototype._read = function() {
    // The consumer wants more data: restart the underlying worker.
    this._helper.resume();
};

module.exports = NodejsStreamOutputAdapter;

},{"../utils":32,"readable-stream":16}],14:[function(require,module,exports){
"use strict";

module.exports = {
    /**
     * True if this is running in Nodejs, will be undefined in a browser.
     * In a browser, browserify won't include this file and the whole module
     * will be resolved an empty object.
     */
    isNode : typeof Buffer !== "undefined",
    /**
     * Create a new nodejs Buffer from an existing content.
     * @param {Object} data the data to pass to the constructor.
     * @param {String} encoding the encoding to use.
     * @return {Buffer} a new Buffer.
     */
    newBufferFrom: function(data, encoding) {
        if (Buffer.from && Buffer.from !== Uint8Array.from) {
            return Buffer.from(data, encoding);
        } else {
            if (typeof data === "number") {
                // Safeguard for old Node.js versions. On newer versions,
                // Buffer.from(number) / Buffer(number, encoding) already throw.
                throw new Error("The \"data\" argument must not be a number");
            }
            // Deprecated constructor, kept for old Node.js versions only.
            return new Buffer(data, encoding);
        }
    },
    /**
     * Create a new nodejs Buffer with the specified size.
     * @param {Integer} size the size of the buffer.
     * @return {Buffer} a new Buffer.
     */
    allocBuffer: function (size) {
        if (Buffer.alloc) {
            return Buffer.alloc(size);
        } else {
            // Old Node.js: new Buffer(size) is uninitialized memory, zero it.
            var buf = new Buffer(size);
            buf.fill(0);
            return buf;
        }
    },
    /**
     * Find out if an object is a Buffer.
     * @param {Object} b the object to test.
+ * @return {Boolean} true if the object is a Buffer, false otherwise. + */ + isBuffer : function(b){ + return Buffer.isBuffer(b); + }, + + isStream : function (obj) { + return obj && + typeof obj.on === "function" && + typeof obj.pause === "function" && + typeof obj.resume === "function"; + } +}; + +},{}],15:[function(require,module,exports){ +"use strict"; +var utf8 = require("./utf8"); +var utils = require("./utils"); +var GenericWorker = require("./stream/GenericWorker"); +var StreamHelper = require("./stream/StreamHelper"); +var defaults = require("./defaults"); +var CompressedObject = require("./compressedObject"); +var ZipObject = require("./zipObject"); +var generate = require("./generate"); +var nodejsUtils = require("./nodejsUtils"); +var NodejsStreamInputAdapter = require("./nodejs/NodejsStreamInputAdapter"); + + +/** + * Add a file in the current folder. + * @private + * @param {string} name the name of the file + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the data of the file + * @param {Object} originalOptions the options of the file + * @return {Object} the new file. + */ +var fileAdd = function(name, data, originalOptions) { + // be sure sub folders exist + var dataType = utils.getTypeOf(data), + parent; + + + /* + * Correct options. 
+ */ + + var o = utils.extend(originalOptions || {}, defaults); + o.date = o.date || new Date(); + if (o.compression !== null) { + o.compression = o.compression.toUpperCase(); + } + + if (typeof o.unixPermissions === "string") { + o.unixPermissions = parseInt(o.unixPermissions, 8); + } + + // UNX_IFDIR 0040000 see zipinfo.c + if (o.unixPermissions && (o.unixPermissions & 0x4000)) { + o.dir = true; + } + // Bit 4 Directory + if (o.dosPermissions && (o.dosPermissions & 0x0010)) { + o.dir = true; + } + + if (o.dir) { + name = forceTrailingSlash(name); + } + if (o.createFolders && (parent = parentFolder(name))) { + folderAdd.call(this, parent, true); + } + + var isUnicodeString = dataType === "string" && o.binary === false && o.base64 === false; + if (!originalOptions || typeof originalOptions.binary === "undefined") { + o.binary = !isUnicodeString; + } + + + var isCompressedEmpty = (data instanceof CompressedObject) && data.uncompressedSize === 0; + + if (isCompressedEmpty || o.dir || !data || data.length === 0) { + o.base64 = false; + o.binary = true; + data = ""; + o.compression = "STORE"; + dataType = "string"; + } + + /* + * Convert content to fit. + */ + + var zipObjectContent = null; + if (data instanceof CompressedObject || data instanceof GenericWorker) { + zipObjectContent = data; + } else if (nodejsUtils.isNode && nodejsUtils.isStream(data)) { + zipObjectContent = new NodejsStreamInputAdapter(name, data); + } else { + zipObjectContent = utils.prepareContent(name, data, o.binary, o.optimizedBinaryString, o.base64); + } + + var object = new ZipObject(name, zipObjectContent, o); + this.files[name] = object; + /* + TODO: we can't throw an exception because we have async promises + (we can have a promise of a Date() for example) but returning a + promise is useless because file(name, data) returns the JSZip + object for chaining. Should we break that to allow the user + to catch the error ? 
    return external.Promise.resolve(zipObjectContent)
        .then(function () {
            return object;
        });
    */
};

/**
 * Find the parent folder of the path.
 * @private
 * @param {string} path the path to use
 * @return {string} the parent folder, or "" if the entry sits at the root
 */
var parentFolder = function (path) {
    // a trailing slash marks a folder entry: drop it before looking for the parent
    if (path.slice(-1) === "/") {
        path = path.substring(0, path.length - 1);
    }
    var lastSlash = path.lastIndexOf("/");
    // "> 0" (not ">= 0") so a leading-slash name like "/foo" has no parent
    return (lastSlash > 0) ? path.substring(0, lastSlash) : "";
};

/**
 * Returns the path with a slash at the end.
 * @private
 * @param {String} path the path to check.
 * @return {String} the path with a trailing slash.
 */
var forceTrailingSlash = function(path) {
    // Check the name ends with a /
    if (path.slice(-1) !== "/") {
        path += "/"; // IE doesn't like substr(-1)
    }
    return path;
};

/**
 * Add a (sub) folder in the current folder.
 * Expects to be invoked with a JSZip-like `this` (uses `this.files`).
 * @private
 * @param {string} name the folder's name
 * @param {boolean=} [createFolders] If true, automatically create sub
 *  folders. Defaults to false.
 * @return {Object} the new folder.
 */
var folderAdd = function(name, createFolders) {
    createFolders = (typeof createFolders !== "undefined") ? createFolders : defaults.createFolders;

    name = forceTrailingSlash(name);

    // Does this folder already exist?
+ if (!this.files[name]) { + fileAdd.call(this, name, null, { + dir: true, + createFolders: createFolders + }); + } + return this.files[name]; +}; + +/** +* Cross-window, cross-Node-context regular expression detection +* @param {Object} object Anything +* @return {Boolean} true if the object is a regular expression, +* false otherwise +*/ +function isRegExp(object) { + return Object.prototype.toString.call(object) === "[object RegExp]"; +} + +// return the actual prototype of JSZip +var out = { + /** + * @see loadAsync + */ + load: function() { + throw new Error("This method has been removed in JSZip 3.0, please check the upgrade guide."); + }, + + + /** + * Call a callback function for each entry at this folder level. + * @param {Function} cb the callback function: + * function (relativePath, file) {...} + * It takes 2 arguments : the relative path and the file. + */ + forEach: function(cb) { + var filename, relativePath, file; + // ignore warning about unwanted properties because this.files is a null prototype object + /* eslint-disable-next-line guard-for-in */ + for (filename in this.files) { + file = this.files[filename]; + relativePath = filename.slice(this.root.length, filename.length); + if (relativePath && filename.slice(0, this.root.length) === this.root) { // the file is in the current root + cb(relativePath, file); // TODO reverse the parameters ? need to be clean AND consistent with the filter search fn... + } + } + }, + + /** + * Filter nested files/folders with the specified function. + * @param {Function} search the predicate to use : + * function (relativePath, file) {...} + * It takes 2 arguments : the relative path and the file. + * @return {Array} An array of matching elements. 
+ */ + filter: function(search) { + var result = []; + this.forEach(function (relativePath, entry) { + if (search(relativePath, entry)) { // the file matches the function + result.push(entry); + } + + }); + return result; + }, + + /** + * Add a file to the zip file, or search a file. + * @param {string|RegExp} name The name of the file to add (if data is defined), + * the name of the file to find (if no data) or a regex to match files. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data The file data, either raw or base64 encoded + * @param {Object} o File options + * @return {JSZip|Object|Array} this JSZip object (when adding a file), + * a file (when searching by string) or an array of files (when searching by regex). + */ + file: function(name, data, o) { + if (arguments.length === 1) { + if (isRegExp(name)) { + var regexp = name; + return this.filter(function(relativePath, file) { + return !file.dir && regexp.test(relativePath); + }); + } + else { // text + var obj = this.files[this.root + name]; + if (obj && !obj.dir) { + return obj; + } else { + return null; + } + } + } + else { // more than one argument : we have data ! + name = this.root + name; + fileAdd.call(this, name, data, o); + } + return this; + }, + + /** + * Add a directory to the zip file, or search. + * @param {String|RegExp} arg The name of the directory to add, or a regex to search folders. + * @return {JSZip} an object with the new directory as the root, or an array containing matching folders. 
+ */ + folder: function(arg) { + if (!arg) { + return this; + } + + if (isRegExp(arg)) { + return this.filter(function(relativePath, file) { + return file.dir && arg.test(relativePath); + }); + } + + // else, name is a new folder + var name = this.root + arg; + var newFolder = folderAdd.call(this, name); + + // Allow chaining by returning a new object with this folder as the root + var ret = this.clone(); + ret.root = newFolder.name; + return ret; + }, + + /** + * Delete a file, or a directory and all sub-files, from the zip + * @param {string} name the name of the file to delete + * @return {JSZip} this JSZip object + */ + remove: function(name) { + name = this.root + name; + var file = this.files[name]; + if (!file) { + // Look for any folders + if (name.slice(-1) !== "/") { + name += "/"; + } + file = this.files[name]; + } + + if (file && !file.dir) { + // file + delete this.files[name]; + } else { + // maybe a folder, delete recursively + var kids = this.filter(function(relativePath, file) { + return file.name.slice(0, name.length) === name; + }); + for (var i = 0; i < kids.length; i++) { + delete this.files[kids[i].name]; + } + } + + return this; + }, + + /** + * @deprecated This method has been removed in JSZip 3.0, please check the upgrade guide. + */ + generate: function() { + throw new Error("This method has been removed in JSZip 3.0, please check the upgrade guide."); + }, + + /** + * Generate the complete zip file as an internal stream. + * @param {Object} options the options to generate the zip file : + * - compression, "STORE" by default. + * - type, "base64" by default. Values are : string, base64, uint8array, arraybuffer, blob. + * @return {StreamHelper} the streamed zip file. 
+ */ + generateInternalStream: function(options) { + var worker, opts = {}; + try { + opts = utils.extend(options || {}, { + streamFiles: false, + compression: "STORE", + compressionOptions : null, + type: "", + platform: "DOS", + comment: null, + mimeType: "application/zip", + encodeFileName: utf8.utf8encode + }); + + opts.type = opts.type.toLowerCase(); + opts.compression = opts.compression.toUpperCase(); + + // "binarystring" is preferred but the internals use "string". + if(opts.type === "binarystring") { + opts.type = "string"; + } + + if (!opts.type) { + throw new Error("No output type specified."); + } + + utils.checkSupport(opts.type); + + // accept nodejs `process.platform` + if( + opts.platform === "darwin" || + opts.platform === "freebsd" || + opts.platform === "linux" || + opts.platform === "sunos" + ) { + opts.platform = "UNIX"; + } + if (opts.platform === "win32") { + opts.platform = "DOS"; + } + + var comment = opts.comment || this.comment || ""; + worker = generate.generateWorker(this, opts, comment); + } catch (e) { + worker = new GenericWorker("error"); + worker.error(e); + } + return new StreamHelper(worker, opts.type || "string", opts.mimeType); + }, + /** + * Generate the complete zip file asynchronously. + * @see generateInternalStream + */ + generateAsync: function(options, onUpdate) { + return this.generateInternalStream(options).accumulate(onUpdate); + }, + /** + * Generate the complete zip file asynchronously. 
+ * @see generateInternalStream + */ + generateNodeStream: function(options, onUpdate) { + options = options || {}; + if (!options.type) { + options.type = "nodebuffer"; + } + return this.generateInternalStream(options).toNodejsStream(onUpdate); + } +}; +module.exports = out; + +},{"./compressedObject":2,"./defaults":5,"./generate":9,"./nodejs/NodejsStreamInputAdapter":12,"./nodejsUtils":14,"./stream/GenericWorker":28,"./stream/StreamHelper":29,"./utf8":31,"./utils":32,"./zipObject":35}],16:[function(require,module,exports){ +"use strict"; +/* + * This file is used by module bundlers (browserify/webpack/etc) when + * including a stream implementation. We use "readable-stream" to get a + * consistent behavior between nodejs versions but bundlers often have a shim + * for "stream". Using this shim greatly improve the compatibility and greatly + * reduce the final size of the bundle (only one stream implementation, not + * two). + */ +module.exports = require("stream"); + +},{"stream":undefined}],17:[function(require,module,exports){ +"use strict"; +var DataReader = require("./DataReader"); +var utils = require("../utils"); + +function ArrayReader(data) { + DataReader.call(this, data); + for(var i = 0; i < this.data.length; i++) { + data[i] = data[i] & 0xFF; + } +} +utils.inherits(ArrayReader, DataReader); +/** + * @see DataReader.byteAt + */ +ArrayReader.prototype.byteAt = function(i) { + return this.data[this.zero + i]; +}; +/** + * @see DataReader.lastIndexOfSignature + */ +ArrayReader.prototype.lastIndexOfSignature = function(sig) { + var sig0 = sig.charCodeAt(0), + sig1 = sig.charCodeAt(1), + sig2 = sig.charCodeAt(2), + sig3 = sig.charCodeAt(3); + for (var i = this.length - 4; i >= 0; --i) { + if (this.data[i] === sig0 && this.data[i + 1] === sig1 && this.data[i + 2] === sig2 && this.data[i + 3] === sig3) { + return i - this.zero; + } + } + + return -1; +}; +/** + * @see DataReader.readAndCheckSignature + */ +ArrayReader.prototype.readAndCheckSignature = function 
(sig) {
    var sig0 = sig.charCodeAt(0),
        sig1 = sig.charCodeAt(1),
        sig2 = sig.charCodeAt(2),
        sig3 = sig.charCodeAt(3),
        data = this.readData(4);
    // compare the 4 expected signature bytes against the 4 bytes just read
    return sig0 === data[0] && sig1 === data[1] && sig2 === data[2] && sig3 === data[3];
};
/**
 * @see DataReader.readData
 */
ArrayReader.prototype.readData = function(size) {
    this.checkOffset(size);
    if(size === 0) {
        return [];
    }
    var result = this.data.slice(this.zero + this.index, this.zero + this.index + size);
    this.index += size;
    return result;
};
module.exports = ArrayReader;

},{"../utils":32,"./DataReader":18}],18:[function(require,module,exports){
"use strict";
var utils = require("../utils");

/**
 * Base class of the data readers: holds the data, a moving read `index`
 * and a `zero` offset marking the start of the readable window.
 * @constructor
 * @param {String|Array|Uint8Array|Buffer} data the data to read (exact type: see each implementation).
 */
function DataReader(data) {
    this.data = data; // type : see implementation
    this.length = data.length;
    this.index = 0;
    this.zero = 0;
}
DataReader.prototype = {
    /**
     * Check that the offset will not go too far.
     * @param {number} offset the additional offset to check.
     * @throws {Error} an Error if the offset is out of bounds.
     */
    checkOffset: function(offset) {
        this.checkIndex(this.index + offset);
    },
    /**
     * Check that the specified index will not be too far.
     * @param {number} newIndex the index to check.
     * @throws {Error} an Error if the index is out of bounds.
     */
    checkIndex: function(newIndex) {
        if (this.length < this.zero + newIndex || newIndex < 0) {
            throw new Error("End of data reached (data length = " + this.length + ", asked index = " + (newIndex) + "). Corrupted zip ?");
        }
    },
    /**
     * Change the index.
     * @param {number} newIndex The new index.
     * @throws {Error} if the new index is out of the data.
     */
    setIndex: function(newIndex) {
        this.checkIndex(newIndex);
        this.index = newIndex;
    },
    /**
     * Skip the next n bytes.
     * @param {number} n the number of bytes to skip.
     * @throws {Error} if the new index is out of the data.
     */
    skip: function(n) {
        this.setIndex(this.index + n);
    },
    /**
     * Get the byte at the specified index.
+ * @param {number} i the index to use. + * @return {number} a byte. + */ + byteAt: function() { + // see implementations + }, + /** + * Get the next number with a given byte size. + * @param {number} size the number of bytes to read. + * @return {number} the corresponding number. + */ + readInt: function(size) { + var result = 0, + i; + this.checkOffset(size); + for (i = this.index + size - 1; i >= this.index; i--) { + result = (result << 8) + this.byteAt(i); + } + this.index += size; + return result; + }, + /** + * Get the next string with a given byte size. + * @param {number} size the number of bytes to read. + * @return {string} the corresponding string. + */ + readString: function(size) { + return utils.transformTo("string", this.readData(size)); + }, + /** + * Get raw data without conversion, bytes. + * @param {number} size the number of bytes to read. + * @return {Object} the raw data, implementation specific. + */ + readData: function() { + // see implementations + }, + /** + * Find the last occurrence of a zip signature (4 bytes). + * @param {string} sig the signature to find. + * @return {number} the index of the last occurrence, -1 if not found. + */ + lastIndexOfSignature: function() { + // see implementations + }, + /** + * Read the signature (4 bytes) at the current position and compare it with sig. + * @param {string} sig the expected signature + * @return {boolean} true if the signature matches, false otherwise. + */ + readAndCheckSignature: function() { + // see implementations + }, + /** + * Get the next date. + * @return {Date} the date. 
     */
    readDate: function() {
        var dostime = this.readInt(4);
        // MS-DOS packed date/time: bit-fields with 2-second resolution
        return new Date(Date.UTC(
            ((dostime >> 25) & 0x7f) + 1980, // year
            ((dostime >> 21) & 0x0f) - 1, // month
            (dostime >> 16) & 0x1f, // day
            (dostime >> 11) & 0x1f, // hour
            (dostime >> 5) & 0x3f, // minute
            (dostime & 0x1f) << 1)); // second (stored divided by 2)
    }
};
module.exports = DataReader;

},{"../utils":32}],19:[function(require,module,exports){
"use strict";
var Uint8ArrayReader = require("./Uint8ArrayReader");
var utils = require("../utils");

/**
 * Reader for nodejs Buffers.
 * @constructor
 * @param {Buffer} data the data to read.
 */
function NodeBufferReader(data) {
    Uint8ArrayReader.call(this, data);
}
utils.inherits(NodeBufferReader, Uint8ArrayReader);

/**
 * @see DataReader.readData
 */
NodeBufferReader.prototype.readData = function(size) {
    this.checkOffset(size);
    var result = this.data.slice(this.zero + this.index, this.zero + this.index + size);
    this.index += size;
    return result;
};
module.exports = NodeBufferReader;

},{"../utils":32,"./Uint8ArrayReader":21}],20:[function(require,module,exports){
"use strict";
var DataReader = require("./DataReader");
var utils = require("../utils");

/**
 * Reader for plain ("binary") strings, one char per byte.
 * @constructor
 * @param {String} data the data to read.
 */
function StringReader(data) {
    DataReader.call(this, data);
}
utils.inherits(StringReader, DataReader);
/**
 * @see DataReader.byteAt
 */
StringReader.prototype.byteAt = function(i) {
    return this.data.charCodeAt(this.zero + i);
};
/**
 * @see DataReader.lastIndexOfSignature
 */
StringReader.prototype.lastIndexOfSignature = function(sig) {
    // NOTE(review): when sig is absent, lastIndexOf returns -1 so this yields
    // (-1 - this.zero), not -1; callers appear to test only "< 0" — confirm.
    return this.data.lastIndexOf(sig) - this.zero;
};
/**
 * @see DataReader.readAndCheckSignature
 */
StringReader.prototype.readAndCheckSignature = function (sig) {
    var data = this.readData(4);
    return sig === data;
};
/**
 * @see DataReader.readData
 */
StringReader.prototype.readData = function(size) {
    this.checkOffset(size);
    // this will work because the constructor applied the "& 0xff" mask.
    // NOTE(review): that mask lives in ArrayReader's constructor, not here —
    // comment looks copied; for strings slice() is safe regardless. Confirm.
+ var result = this.data.slice(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = StringReader; + +},{"../utils":32,"./DataReader":18}],21:[function(require,module,exports){ +"use strict"; +var ArrayReader = require("./ArrayReader"); +var utils = require("../utils"); + +function Uint8ArrayReader(data) { + ArrayReader.call(this, data); +} +utils.inherits(Uint8ArrayReader, ArrayReader); +/** + * @see DataReader.readData + */ +Uint8ArrayReader.prototype.readData = function(size) { + this.checkOffset(size); + if(size === 0) { + // in IE10, when using subarray(idx, idx), we get the array [0x00] instead of []. + return new Uint8Array(0); + } + var result = this.data.subarray(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = Uint8ArrayReader; + +},{"../utils":32,"./ArrayReader":17}],22:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var support = require("../support"); +var ArrayReader = require("./ArrayReader"); +var StringReader = require("./StringReader"); +var NodeBufferReader = require("./NodeBufferReader"); +var Uint8ArrayReader = require("./Uint8ArrayReader"); + +/** + * Create a reader adapted to the data. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the data to read. + * @return {DataReader} the data reader. 
+ */ +module.exports = function (data) { + var type = utils.getTypeOf(data); + utils.checkSupport(type); + if (type === "string" && !support.uint8array) { + return new StringReader(data); + } + if (type === "nodebuffer") { + return new NodeBufferReader(data); + } + if (support.uint8array) { + return new Uint8ArrayReader(utils.transformTo("uint8array", data)); + } + return new ArrayReader(utils.transformTo("array", data)); +}; + +},{"../support":30,"../utils":32,"./ArrayReader":17,"./NodeBufferReader":19,"./StringReader":20,"./Uint8ArrayReader":21}],23:[function(require,module,exports){ +"use strict"; +exports.LOCAL_FILE_HEADER = "PK\x03\x04"; +exports.CENTRAL_FILE_HEADER = "PK\x01\x02"; +exports.CENTRAL_DIRECTORY_END = "PK\x05\x06"; +exports.ZIP64_CENTRAL_DIRECTORY_LOCATOR = "PK\x06\x07"; +exports.ZIP64_CENTRAL_DIRECTORY_END = "PK\x06\x06"; +exports.DATA_DESCRIPTOR = "PK\x07\x08"; + +},{}],24:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./GenericWorker"); +var utils = require("../utils"); + +/** + * A worker which convert chunks to a specified type. + * @constructor + * @param {String} destType the destination type. + */ +function ConvertWorker(destType) { + GenericWorker.call(this, "ConvertWorker to " + destType); + this.destType = destType; +} +utils.inherits(ConvertWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +ConvertWorker.prototype.processChunk = function (chunk) { + this.push({ + data : utils.transformTo(this.destType, chunk.data), + meta : chunk.meta + }); +}; +module.exports = ConvertWorker; + +},{"../utils":32,"./GenericWorker":28}],25:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./GenericWorker"); +var crc32 = require("../crc32"); +var utils = require("../utils"); + +/** + * A worker which calculate the crc32 of the data flowing through. 
+ * @constructor + */ +function Crc32Probe() { + GenericWorker.call(this, "Crc32Probe"); + this.withStreamInfo("crc32", 0); +} +utils.inherits(Crc32Probe, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +Crc32Probe.prototype.processChunk = function (chunk) { + this.streamInfo.crc32 = crc32(chunk.data, this.streamInfo.crc32 || 0); + this.push(chunk); +}; +module.exports = Crc32Probe; + +},{"../crc32":4,"../utils":32,"./GenericWorker":28}],26:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("./GenericWorker"); + +/** + * A worker which calculate the total length of the data flowing through. + * @constructor + * @param {String} propName the name used to expose the length + */ +function DataLengthProbe(propName) { + GenericWorker.call(this, "DataLengthProbe for " + propName); + this.propName = propName; + this.withStreamInfo(propName, 0); +} +utils.inherits(DataLengthProbe, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +DataLengthProbe.prototype.processChunk = function (chunk) { + if(chunk) { + var length = this.streamInfo[this.propName] || 0; + this.streamInfo[this.propName] = length + chunk.data.length; + } + GenericWorker.prototype.processChunk.call(this, chunk); +}; +module.exports = DataLengthProbe; + + +},{"../utils":32,"./GenericWorker":28}],27:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("./GenericWorker"); + +// the size of the generated chunks +// TODO expose this as a public variable +var DEFAULT_BLOCK_SIZE = 16 * 1024; + +/** + * A worker that reads a content and emits chunks. 
+ * @constructor + * @param {Promise} dataP the promise of the data to split + */ +function DataWorker(dataP) { + GenericWorker.call(this, "DataWorker"); + var self = this; + this.dataIsReady = false; + this.index = 0; + this.max = 0; + this.data = null; + this.type = ""; + + this._tickScheduled = false; + + dataP.then(function (data) { + self.dataIsReady = true; + self.data = data; + self.max = data && data.length || 0; + self.type = utils.getTypeOf(data); + if(!self.isPaused) { + self._tickAndRepeat(); + } + }, function (e) { + self.error(e); + }); +} + +utils.inherits(DataWorker, GenericWorker); + +/** + * @see GenericWorker.cleanUp + */ +DataWorker.prototype.cleanUp = function () { + GenericWorker.prototype.cleanUp.call(this); + this.data = null; +}; + +/** + * @see GenericWorker.resume + */ +DataWorker.prototype.resume = function () { + if(!GenericWorker.prototype.resume.call(this)) { + return false; + } + + if (!this._tickScheduled && this.dataIsReady) { + this._tickScheduled = true; + utils.delay(this._tickAndRepeat, [], this); + } + return true; +}; + +/** + * Trigger a tick a schedule an other call to this function. + */ +DataWorker.prototype._tickAndRepeat = function() { + this._tickScheduled = false; + if(this.isPaused || this.isFinished) { + return; + } + this._tick(); + if(!this.isFinished) { + utils.delay(this._tickAndRepeat, [], this); + this._tickScheduled = true; + } +}; + +/** + * Read and push a chunk. 
+ */ +DataWorker.prototype._tick = function() { + + if(this.isPaused || this.isFinished) { + return false; + } + + var size = DEFAULT_BLOCK_SIZE; + var data = null, nextIndex = Math.min(this.max, this.index + size); + if (this.index >= this.max) { + // EOF + return this.end(); + } else { + switch(this.type) { + case "string": + data = this.data.substring(this.index, nextIndex); + break; + case "uint8array": + data = this.data.subarray(this.index, nextIndex); + break; + case "array": + case "nodebuffer": + data = this.data.slice(this.index, nextIndex); + break; + } + this.index = nextIndex; + return this.push({ + data : data, + meta : { + percent : this.max ? this.index / this.max * 100 : 0 + } + }); + } +}; + +module.exports = DataWorker; + +},{"../utils":32,"./GenericWorker":28}],28:[function(require,module,exports){ +"use strict"; + +/** + * A worker that does nothing but passing chunks to the next one. This is like + * a nodejs stream but with some differences. On the good side : + * - it works on IE 6-9 without any issue / polyfill + * - it weights less than the full dependencies bundled with browserify + * - it forwards errors (no need to declare an error handler EVERYWHERE) + * + * A chunk is an object with 2 attributes : `meta` and `data`. The former is an + * object containing anything (`percent` for example), see each worker for more + * details. The latter is the real data (String, Uint8Array, etc). 
+ * + * @constructor + * @param {String} name the name of the stream (mainly used for debugging purposes) + */ +function GenericWorker(name) { + // the name of the worker + this.name = name || "default"; + // an object containing metadata about the workers chain + this.streamInfo = {}; + // an error which happened when the worker was paused + this.generatedError = null; + // an object containing metadata to be merged by this worker into the general metadata + this.extraStreamInfo = {}; + // true if the stream is paused (and should not do anything), false otherwise + this.isPaused = true; + // true if the stream is finished (and should not do anything), false otherwise + this.isFinished = false; + // true if the stream is locked to prevent further structure updates (pipe), false otherwise + this.isLocked = false; + // the event listeners + this._listeners = { + "data":[], + "end":[], + "error":[] + }; + // the previous worker, if any + this.previous = null; +} + +GenericWorker.prototype = { + /** + * Push a chunk to the next workers. + * @param {Object} chunk the chunk to push + */ + push : function (chunk) { + this.emit("data", chunk); + }, + /** + * End the stream. + * @return {Boolean} true if this call ended the worker, false otherwise. + */ + end : function () { + if (this.isFinished) { + return false; + } + + this.flush(); + try { + this.emit("end"); + this.cleanUp(); + this.isFinished = true; + } catch (e) { + this.emit("error", e); + } + return true; + }, + /** + * End the stream with an error. + * @param {Error} e the error which caused the premature end. + * @return {Boolean} true if this call ended the worker with an error, false otherwise. 
+ */ + error : function (e) { + if (this.isFinished) { + return false; + } + + if(this.isPaused) { + this.generatedError = e; + } else { + this.isFinished = true; + + this.emit("error", e); + + // in the workers chain exploded in the middle of the chain, + // the error event will go downward but we also need to notify + // workers upward that there has been an error. + if(this.previous) { + this.previous.error(e); + } + + this.cleanUp(); + } + return true; + }, + /** + * Add a callback on an event. + * @param {String} name the name of the event (data, end, error) + * @param {Function} listener the function to call when the event is triggered + * @return {GenericWorker} the current object for chainability + */ + on : function (name, listener) { + this._listeners[name].push(listener); + return this; + }, + /** + * Clean any references when a worker is ending. + */ + cleanUp : function () { + this.streamInfo = this.generatedError = this.extraStreamInfo = null; + this._listeners = []; + }, + /** + * Trigger an event. This will call registered callback with the provided arg. + * @param {String} name the name of the event (data, end, error) + * @param {Object} arg the argument to call the callback with. + */ + emit : function (name, arg) { + if (this._listeners[name]) { + for(var i = 0; i < this._listeners[name].length; i++) { + this._listeners[name][i].call(this, arg); + } + } + }, + /** + * Chain a worker with an other. + * @param {Worker} next the worker receiving events from the current one. + * @return {worker} the next worker for chainability + */ + pipe : function (next) { + return next.registerPrevious(this); + }, + /** + * Same as `pipe` in the other direction. + * Using an API with `pipe(next)` is very easy. + * Implementing the API with the point of view of the next one registering + * a source is easier, see the ZipFileWorker. 
 * @param {Worker} previous the previous worker, sending events to this one
 * @return {Worker} the current worker for chainability
 */
registerPrevious : function (previous) {
    if (this.isLocked) {
        throw new Error("The stream '" + this + "' has already been used.");
    }

    // sharing the streamInfo...
    this.streamInfo = previous.streamInfo;
    // ... and adding our own bits
    this.mergeStreamInfo();
    this.previous = previous;
    var self = this;
    // forward data/end/error from the upstream worker into this one
    previous.on("data", function (chunk) {
        self.processChunk(chunk);
    });
    previous.on("end", function () {
        self.end();
    });
    previous.on("error", function (e) {
        self.error(e);
    });
    return this;
},
/**
 * Pause the stream so it doesn't send events anymore.
 * Pausing propagates upstream so the whole chain stops producing.
 * @return {Boolean} true if this call paused the worker, false otherwise.
 */
pause : function () {
    if(this.isPaused || this.isFinished) {
        return false;
    }
    this.isPaused = true;

    if(this.previous) {
        this.previous.pause();
    }
    return true;
},
/**
 * Resume a paused stream.
 * If an error was captured while paused, it is replayed now and the
 * resume is reported as failed.
 * @return {Boolean} true if this call resumed the worker, false otherwise.
 */
resume : function () {
    if(!this.isPaused || this.isFinished) {
        return false;
    }
    this.isPaused = false;

    // if true, the worker tried to resume but failed
    var withError = false;
    if(this.generatedError) {
        this.error(this.generatedError);
        withError = true;
    }
    if(this.previous) {
        this.previous.resume();
    }

    return !withError;
},
/**
 * Flush any remaining bytes as the stream is ending.
 * No-op by default; overridden by workers that buffer data.
 */
flush : function () {},
/**
 * Process a chunk. This is usually the method overridden.
 * Default behavior: pass the chunk through unchanged.
 * @param {Object} chunk the chunk to process.
 */
processChunk : function(chunk) {
    this.push(chunk);
},
/**
 * Add a key/value to be added in the workers chain streamInfo once activated.
+ * @param {String} key the key to use + * @param {Object} value the associated value + * @return {Worker} the current worker for chainability + */ + withStreamInfo : function (key, value) { + this.extraStreamInfo[key] = value; + this.mergeStreamInfo(); + return this; + }, + /** + * Merge this worker's streamInfo into the chain's streamInfo. + */ + mergeStreamInfo : function () { + for(var key in this.extraStreamInfo) { + if (!Object.prototype.hasOwnProperty.call(this.extraStreamInfo, key)) { + continue; + } + this.streamInfo[key] = this.extraStreamInfo[key]; + } + }, + + /** + * Lock the stream to prevent further updates on the workers chain. + * After calling this method, all calls to pipe will fail. + */ + lock: function () { + if (this.isLocked) { + throw new Error("The stream '" + this + "' has already been used."); + } + this.isLocked = true; + if (this.previous) { + this.previous.lock(); + } + }, + + /** + * + * Pretty print the workers chain. + */ + toString : function () { + var me = "Worker " + this.name; + if (this.previous) { + return this.previous + " -> " + me; + } else { + return me; + } + } +}; + +module.exports = GenericWorker; + +},{}],29:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var ConvertWorker = require("./ConvertWorker"); +var GenericWorker = require("./GenericWorker"); +var base64 = require("../base64"); +var support = require("../support"); +var external = require("../external"); + +var NodejsStreamOutputAdapter = null; +if (support.nodestream) { + try { + NodejsStreamOutputAdapter = require("../nodejs/NodejsStreamOutputAdapter"); + } catch(e) { + // ignore + } +} + +/** + * Apply the final transformation of the data. If the user wants a Blob for + * example, it's easier to work with an U8intArray and finally do the + * ArrayBuffer/Blob conversion. 
+ * @param {String} type the name of the final type + * @param {String|Uint8Array|Buffer} content the content to transform + * @param {String} mimeType the mime type of the content, if applicable. + * @return {String|Uint8Array|ArrayBuffer|Buffer|Blob} the content in the right format. + */ +function transformZipOutput(type, content, mimeType) { + switch(type) { + case "blob" : + return utils.newBlob(utils.transformTo("arraybuffer", content), mimeType); + case "base64" : + return base64.encode(content); + default : + return utils.transformTo(type, content); + } +} + +/** + * Concatenate an array of data of the given type. + * @param {String} type the type of the data in the given array. + * @param {Array} dataArray the array containing the data chunks to concatenate + * @return {String|Uint8Array|Buffer} the concatenated data + * @throws Error if the asked type is unsupported + */ +function concat (type, dataArray) { + var i, index = 0, res = null, totalLength = 0; + for(i = 0; i < dataArray.length; i++) { + totalLength += dataArray[i].length; + } + switch(type) { + case "string": + return dataArray.join(""); + case "array": + return Array.prototype.concat.apply([], dataArray); + case "uint8array": + res = new Uint8Array(totalLength); + for(i = 0; i < dataArray.length; i++) { + res.set(dataArray[i], index); + index += dataArray[i].length; + } + return res; + case "nodebuffer": + return Buffer.concat(dataArray); + default: + throw new Error("concat : unsupported type '" + type + "'"); + } +} + +/** + * Listen a StreamHelper, accumulate its content and concatenate it into a + * complete block. + * @param {StreamHelper} helper the helper to use. + * @param {Function} updateCallback a callback called on each update. Called + * with one arg : + * - the metadata linked to the update received. + * @return Promise the promise for the accumulation. 
+ */ +function accumulate(helper, updateCallback) { + return new external.Promise(function (resolve, reject){ + var dataArray = []; + var chunkType = helper._internalType, + resultType = helper._outputType, + mimeType = helper._mimeType; + helper + .on("data", function (data, meta) { + dataArray.push(data); + if(updateCallback) { + updateCallback(meta); + } + }) + .on("error", function(err) { + dataArray = []; + reject(err); + }) + .on("end", function (){ + try { + var result = transformZipOutput(resultType, concat(chunkType, dataArray), mimeType); + resolve(result); + } catch (e) { + reject(e); + } + dataArray = []; + }) + .resume(); + }); +} + +/** + * An helper to easily use workers outside of JSZip. + * @constructor + * @param {Worker} worker the worker to wrap + * @param {String} outputType the type of data expected by the use + * @param {String} mimeType the mime type of the content, if applicable. + */ +function StreamHelper(worker, outputType, mimeType) { + var internalType = outputType; + switch(outputType) { + case "blob": + case "arraybuffer": + internalType = "uint8array"; + break; + case "base64": + internalType = "string"; + break; + } + + try { + // the type used internally + this._internalType = internalType; + // the type used to output results + this._outputType = outputType; + // the mime type + this._mimeType = mimeType; + utils.checkSupport(internalType); + this._worker = worker.pipe(new ConvertWorker(internalType)); + // the last workers can be rewired without issues but we need to + // prevent any updates on previous workers. + worker.lock(); + } catch(e) { + this._worker = new GenericWorker("error"); + this._worker.error(e); + } +} + +StreamHelper.prototype = { + /** + * Listen a StreamHelper, accumulate its content and concatenate it into a + * complete block. + * @param {Function} updateCb the update callback. + * @return Promise the promise for the accumulation. 
+ */ + accumulate : function (updateCb) { + return accumulate(this, updateCb); + }, + /** + * Add a listener on an event triggered on a stream. + * @param {String} evt the name of the event + * @param {Function} fn the listener + * @return {StreamHelper} the current helper. + */ + on : function (evt, fn) { + var self = this; + + if(evt === "data") { + this._worker.on(evt, function (chunk) { + fn.call(self, chunk.data, chunk.meta); + }); + } else { + this._worker.on(evt, function () { + utils.delay(fn, arguments, self); + }); + } + return this; + }, + /** + * Resume the flow of chunks. + * @return {StreamHelper} the current helper. + */ + resume : function () { + utils.delay(this._worker.resume, [], this._worker); + return this; + }, + /** + * Pause the flow of chunks. + * @return {StreamHelper} the current helper. + */ + pause : function () { + this._worker.pause(); + return this; + }, + /** + * Return a nodejs stream for this helper. + * @param {Function} updateCb the update callback. + * @return {NodejsStreamOutputAdapter} the nodejs stream. + */ + toNodejsStream : function (updateCb) { + utils.checkSupport("nodestream"); + if (this._outputType !== "nodebuffer") { + // an object stream containing blob/arraybuffer/uint8array/string + // is strange and I don't know if it would be useful. + // I you find this comment and have a good usecase, please open a + // bug report ! 
+ throw new Error(this._outputType + " is not supported by this method"); + } + + return new NodejsStreamOutputAdapter(this, { + objectMode : this._outputType !== "nodebuffer" + }, updateCb); + } +}; + + +module.exports = StreamHelper; + +},{"../base64":1,"../external":6,"../nodejs/NodejsStreamOutputAdapter":13,"../support":30,"../utils":32,"./ConvertWorker":24,"./GenericWorker":28}],30:[function(require,module,exports){ +"use strict"; + +exports.base64 = true; +exports.array = true; +exports.string = true; +exports.arraybuffer = typeof ArrayBuffer !== "undefined" && typeof Uint8Array !== "undefined"; +exports.nodebuffer = typeof Buffer !== "undefined"; +// contains true if JSZip can read/generate Uint8Array, false otherwise. +exports.uint8array = typeof Uint8Array !== "undefined"; + +if (typeof ArrayBuffer === "undefined") { + exports.blob = false; +} +else { + var buffer = new ArrayBuffer(0); + try { + exports.blob = new Blob([buffer], { + type: "application/zip" + }).size === 0; + } + catch (e) { + try { + var Builder = self.BlobBuilder || self.WebKitBlobBuilder || self.MozBlobBuilder || self.MSBlobBuilder; + var builder = new Builder(); + builder.append(buffer); + exports.blob = builder.getBlob("application/zip").size === 0; + } + catch (e) { + exports.blob = false; + } + } +} + +try { + exports.nodestream = !!require("readable-stream").Readable; +} catch(e) { + exports.nodestream = false; +} + +},{"readable-stream":16}],31:[function(require,module,exports){ +"use strict"; + +var utils = require("./utils"); +var support = require("./support"); +var nodejsUtils = require("./nodejsUtils"); +var GenericWorker = require("./stream/GenericWorker"); + +/** + * The following functions come from pako, from pako/lib/utils/strings + * released under the MIT license, see pako https://github.com/nodeca/pako/ + */ + +// Table with utf8 lengths (calculated by first byte of sequence) +// Note, that 5 & 6-byte values and some 4-byte values can not be represented in JS, +// 
because max possible codepoint is 0x10ffff +var _utf8len = new Array(256); +for (var i=0; i<256; i++) { + _utf8len[i] = (i >= 252 ? 6 : i >= 248 ? 5 : i >= 240 ? 4 : i >= 224 ? 3 : i >= 192 ? 2 : 1); +} +_utf8len[254]=_utf8len[254]=1; // Invalid sequence start + +// convert string to array (typed, when possible) +var string2buf = function (str) { + var buf, c, c2, m_pos, i, str_len = str.length, buf_len = 0; + + // count binary size + for (m_pos = 0; m_pos < str_len; m_pos++) { + c = str.charCodeAt(m_pos); + if ((c & 0xfc00) === 0xd800 && (m_pos+1 < str_len)) { + c2 = str.charCodeAt(m_pos+1); + if ((c2 & 0xfc00) === 0xdc00) { + c = 0x10000 + ((c - 0xd800) << 10) + (c2 - 0xdc00); + m_pos++; + } + } + buf_len += c < 0x80 ? 1 : c < 0x800 ? 2 : c < 0x10000 ? 3 : 4; + } + + // allocate buffer + if (support.uint8array) { + buf = new Uint8Array(buf_len); + } else { + buf = new Array(buf_len); + } + + // convert + for (i=0, m_pos = 0; i < buf_len; m_pos++) { + c = str.charCodeAt(m_pos); + if ((c & 0xfc00) === 0xd800 && (m_pos+1 < str_len)) { + c2 = str.charCodeAt(m_pos+1); + if ((c2 & 0xfc00) === 0xdc00) { + c = 0x10000 + ((c - 0xd800) << 10) + (c2 - 0xdc00); + m_pos++; + } + } + if (c < 0x80) { + /* one byte */ + buf[i++] = c; + } else if (c < 0x800) { + /* two bytes */ + buf[i++] = 0xC0 | (c >>> 6); + buf[i++] = 0x80 | (c & 0x3f); + } else if (c < 0x10000) { + /* three bytes */ + buf[i++] = 0xE0 | (c >>> 12); + buf[i++] = 0x80 | (c >>> 6 & 0x3f); + buf[i++] = 0x80 | (c & 0x3f); + } else { + /* four bytes */ + buf[i++] = 0xf0 | (c >>> 18); + buf[i++] = 0x80 | (c >>> 12 & 0x3f); + buf[i++] = 0x80 | (c >>> 6 & 0x3f); + buf[i++] = 0x80 | (c & 0x3f); + } + } + + return buf; +}; + +// Calculate max possible position in utf8 buffer, +// that will not break sequence. If that's not possible +// - (very small limits) return max size as is. 
+// +// buf[] - utf8 bytes array +// max - length limit (mandatory); +var utf8border = function(buf, max) { + var pos; + + max = max || buf.length; + if (max > buf.length) { max = buf.length; } + + // go back from last position, until start of sequence found + pos = max-1; + while (pos >= 0 && (buf[pos] & 0xC0) === 0x80) { pos--; } + + // Fuckup - very small and broken sequence, + // return max, because we should return something anyway. + if (pos < 0) { return max; } + + // If we came to start of buffer - that means vuffer is too small, + // return max too. + if (pos === 0) { return max; } + + return (pos + _utf8len[buf[pos]] > max) ? pos : max; +}; + +// convert array to string +var buf2string = function (buf) { + var i, out, c, c_len; + var len = buf.length; + + // Reserve max possible length (2 words per char) + // NB: by unknown reasons, Array is significantly faster for + // String.fromCharCode.apply than Uint16Array. + var utf16buf = new Array(len*2); + + for (out=0, i=0; i 4) { utf16buf[out++] = 0xfffd; i += c_len-1; continue; } + + // apply mask on first byte + c &= c_len === 2 ? 0x1f : c_len === 3 ? 0x0f : 0x07; + // join the rest + while (c_len > 1 && i < len) { + c = (c << 6) | (buf[i++] & 0x3f); + c_len--; + } + + // terminated by end of string? + if (c_len > 1) { utf16buf[out++] = 0xfffd; continue; } + + if (c < 0x10000) { + utf16buf[out++] = c; + } else { + c -= 0x10000; + utf16buf[out++] = 0xd800 | ((c >> 10) & 0x3ff); + utf16buf[out++] = 0xdc00 | (c & 0x3ff); + } + } + + // shrinkBuf(utf16buf, out) + if (utf16buf.length !== out) { + if(utf16buf.subarray) { + utf16buf = utf16buf.subarray(0, out); + } else { + utf16buf.length = out; + } + } + + // return String.fromCharCode.apply(null, utf16buf); + return utils.applyFromCharCode(utf16buf); +}; + + +// That's all for the pako functions. + + +/** + * Transform a javascript string into an array (typed if possible) of bytes, + * UTF-8 encoded. 
+ * @param {String} str the string to encode + * @return {Array|Uint8Array|Buffer} the UTF-8 encoded string. + */ +exports.utf8encode = function utf8encode(str) { + if (support.nodebuffer) { + return nodejsUtils.newBufferFrom(str, "utf-8"); + } + + return string2buf(str); +}; + + +/** + * Transform a bytes array (or a representation) representing an UTF-8 encoded + * string into a javascript string. + * @param {Array|Uint8Array|Buffer} buf the data de decode + * @return {String} the decoded string. + */ +exports.utf8decode = function utf8decode(buf) { + if (support.nodebuffer) { + return utils.transformTo("nodebuffer", buf).toString("utf-8"); + } + + buf = utils.transformTo(support.uint8array ? "uint8array" : "array", buf); + + return buf2string(buf); +}; + +/** + * A worker to decode utf8 encoded binary chunks into string chunks. + * @constructor + */ +function Utf8DecodeWorker() { + GenericWorker.call(this, "utf-8 decode"); + // the last bytes if a chunk didn't end with a complete codepoint. + this.leftOver = null; +} +utils.inherits(Utf8DecodeWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +Utf8DecodeWorker.prototype.processChunk = function (chunk) { + + var data = utils.transformTo(support.uint8array ? 
"uint8array" : "array", chunk.data); + + // 1st step, re-use what's left of the previous chunk + if (this.leftOver && this.leftOver.length) { + if(support.uint8array) { + var previousData = data; + data = new Uint8Array(previousData.length + this.leftOver.length); + data.set(this.leftOver, 0); + data.set(previousData, this.leftOver.length); + } else { + data = this.leftOver.concat(data); + } + this.leftOver = null; + } + + var nextBoundary = utf8border(data); + var usableData = data; + if (nextBoundary !== data.length) { + if (support.uint8array) { + usableData = data.subarray(0, nextBoundary); + this.leftOver = data.subarray(nextBoundary, data.length); + } else { + usableData = data.slice(0, nextBoundary); + this.leftOver = data.slice(nextBoundary, data.length); + } + } + + this.push({ + data : exports.utf8decode(usableData), + meta : chunk.meta + }); +}; + +/** + * @see GenericWorker.flush + */ +Utf8DecodeWorker.prototype.flush = function () { + if(this.leftOver && this.leftOver.length) { + this.push({ + data : exports.utf8decode(this.leftOver), + meta : {} + }); + this.leftOver = null; + } +}; +exports.Utf8DecodeWorker = Utf8DecodeWorker; + +/** + * A worker to endcode string chunks into utf8 encoded binary chunks. 
+ * @constructor + */ +function Utf8EncodeWorker() { + GenericWorker.call(this, "utf-8 encode"); +} +utils.inherits(Utf8EncodeWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +Utf8EncodeWorker.prototype.processChunk = function (chunk) { + this.push({ + data : exports.utf8encode(chunk.data), + meta : chunk.meta + }); +}; +exports.Utf8EncodeWorker = Utf8EncodeWorker; + +},{"./nodejsUtils":14,"./stream/GenericWorker":28,"./support":30,"./utils":32}],32:[function(require,module,exports){ +"use strict"; + +var support = require("./support"); +var base64 = require("./base64"); +var nodejsUtils = require("./nodejsUtils"); +var external = require("./external"); +require("setimmediate"); + + +/** + * Convert a string that pass as a "binary string": it should represent a byte + * array but may have > 255 char codes. Be sure to take only the first byte + * and returns the byte array. + * @param {String} str the string to transform. + * @return {Array|Uint8Array} the string in a binary format. + */ +function string2binary(str) { + var result = null; + if (support.uint8array) { + result = new Uint8Array(str.length); + } else { + result = new Array(str.length); + } + return stringToArrayLike(str, result); +} + +/** + * Create a new blob with the given content and the given type. + * @param {String|ArrayBuffer} part the content to put in the blob. DO NOT use + * an Uint8Array because the stock browser of android 4 won't accept it (it + * will be silently converted to a string, "[object Uint8Array]"). + * + * Use only ONE part to build the blob to avoid a memory leak in IE11 / Edge: + * when a large amount of Array is used to create the Blob, the amount of + * memory consumed is nearly 100 times the original data amount. + * + * @param {String} type the mime type of the blob. + * @return {Blob} the created blob. 
+ */ +exports.newBlob = function(part, type) { + exports.checkSupport("blob"); + + try { + // Blob constructor + return new Blob([part], { + type: type + }); + } + catch (e) { + + try { + // deprecated, browser only, old way + var Builder = self.BlobBuilder || self.WebKitBlobBuilder || self.MozBlobBuilder || self.MSBlobBuilder; + var builder = new Builder(); + builder.append(part); + return builder.getBlob(type); + } + catch (e) { + + // well, fuck ?! + throw new Error("Bug : can't construct the Blob."); + } + } + + +}; +/** + * The identity function. + * @param {Object} input the input. + * @return {Object} the same input. + */ +function identity(input) { + return input; +} + +/** + * Fill in an array with a string. + * @param {String} str the string to use. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to fill in (will be mutated). + * @return {Array|ArrayBuffer|Uint8Array|Buffer} the updated array. + */ +function stringToArrayLike(str, array) { + for (var i = 0; i < str.length; ++i) { + array[i] = str.charCodeAt(i) & 0xFF; + } + return array; +} + +/** + * An helper for the function arrayLikeToString. + * This contains static information and functions that + * can be optimized by the browser JIT compiler. + */ +var arrayToStringHelper = { + /** + * Transform an array of int into a string, chunk by chunk. + * See the performances notes on arrayLikeToString. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @param {String} type the type of the array. + * @param {Integer} chunk the chunk size. + * @return {String} the resulting string. + * @throws Error if the chunk is too big for the stack. 
+ */ + stringifyByChunk: function(array, type, chunk) { + var result = [], k = 0, len = array.length; + // shortcut + if (len <= chunk) { + return String.fromCharCode.apply(null, array); + } + while (k < len) { + if (type === "array" || type === "nodebuffer") { + result.push(String.fromCharCode.apply(null, array.slice(k, Math.min(k + chunk, len)))); + } + else { + result.push(String.fromCharCode.apply(null, array.subarray(k, Math.min(k + chunk, len)))); + } + k += chunk; + } + return result.join(""); + }, + /** + * Call String.fromCharCode on every item in the array. + * This is the naive implementation, which generate A LOT of intermediate string. + * This should be used when everything else fail. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @return {String} the result. + */ + stringifyByChar: function(array){ + var resultStr = ""; + for(var i = 0; i < array.length; i++) { + resultStr += String.fromCharCode(array[i]); + } + return resultStr; + }, + applyCanBeUsed : { + /** + * true if the browser accepts to use String.fromCharCode on Uint8Array + */ + uint8array : (function () { + try { + return support.uint8array && String.fromCharCode.apply(null, new Uint8Array(1)).length === 1; + } catch (e) { + return false; + } + })(), + /** + * true if the browser accepts to use String.fromCharCode on nodejs Buffer. + */ + nodebuffer : (function () { + try { + return support.nodebuffer && String.fromCharCode.apply(null, nodejsUtils.allocBuffer(1)).length === 1; + } catch (e) { + return false; + } + })() + } +}; + +/** + * Transform an array-like object to a string. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @return {String} the result. 
+ */ +function arrayLikeToString(array) { + // Performances notes : + // -------------------- + // String.fromCharCode.apply(null, array) is the fastest, see + // see http://jsperf.com/converting-a-uint8array-to-a-string/2 + // but the stack is limited (and we can get huge arrays !). + // + // result += String.fromCharCode(array[i]); generate too many strings ! + // + // This code is inspired by http://jsperf.com/arraybuffer-to-string-apply-performance/2 + // TODO : we now have workers that split the work. Do we still need that ? + var chunk = 65536, + type = exports.getTypeOf(array), + canUseApply = true; + if (type === "uint8array") { + canUseApply = arrayToStringHelper.applyCanBeUsed.uint8array; + } else if (type === "nodebuffer") { + canUseApply = arrayToStringHelper.applyCanBeUsed.nodebuffer; + } + + if (canUseApply) { + while (chunk > 1) { + try { + return arrayToStringHelper.stringifyByChunk(array, type, chunk); + } catch (e) { + chunk = Math.floor(chunk / 2); + } + } + } + + // no apply or chunk error : slow and painful algorithm + // default browser on android 4.* + return arrayToStringHelper.stringifyByChar(array); +} + +exports.applyFromCharCode = arrayLikeToString; + + +/** + * Copy the data from an array-like to an other array-like. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} arrayFrom the origin array. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} arrayTo the destination array which will be mutated. + * @return {Array|ArrayBuffer|Uint8Array|Buffer} the updated destination array. + */ +function arrayLikeToArrayLike(arrayFrom, arrayTo) { + for (var i = 0; i < arrayFrom.length; i++) { + arrayTo[i] = arrayFrom[i]; + } + return arrayTo; +} + +// a matrix containing functions to transform everything into everything. +var transform = {}; + +// string to ? 
+transform["string"] = { + "string": identity, + "array": function(input) { + return stringToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return transform["string"]["uint8array"](input).buffer; + }, + "uint8array": function(input) { + return stringToArrayLike(input, new Uint8Array(input.length)); + }, + "nodebuffer": function(input) { + return stringToArrayLike(input, nodejsUtils.allocBuffer(input.length)); + } +}; + +// array to ? +transform["array"] = { + "string": arrayLikeToString, + "array": identity, + "arraybuffer": function(input) { + return (new Uint8Array(input)).buffer; + }, + "uint8array": function(input) { + return new Uint8Array(input); + }, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(input); + } +}; + +// arraybuffer to ? +transform["arraybuffer"] = { + "string": function(input) { + return arrayLikeToString(new Uint8Array(input)); + }, + "array": function(input) { + return arrayLikeToArrayLike(new Uint8Array(input), new Array(input.byteLength)); + }, + "arraybuffer": identity, + "uint8array": function(input) { + return new Uint8Array(input); + }, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(new Uint8Array(input)); + } +}; + +// uint8array to ? +transform["uint8array"] = { + "string": arrayLikeToString, + "array": function(input) { + return arrayLikeToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return input.buffer; + }, + "uint8array": identity, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(input); + } +}; + +// nodebuffer to ? 
+transform["nodebuffer"] = { + "string": arrayLikeToString, + "array": function(input) { + return arrayLikeToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return transform["nodebuffer"]["uint8array"](input).buffer; + }, + "uint8array": function(input) { + return arrayLikeToArrayLike(input, new Uint8Array(input.length)); + }, + "nodebuffer": identity +}; + +/** + * Transform an input into any type. + * The supported output type are : string, array, uint8array, arraybuffer, nodebuffer. + * If no output type is specified, the unmodified input will be returned. + * @param {String} outputType the output type. + * @param {String|Array|ArrayBuffer|Uint8Array|Buffer} input the input to convert. + * @throws {Error} an Error if the browser doesn't support the requested output type. + */ +exports.transformTo = function(outputType, input) { + if (!input) { + // undefined, null, etc + // an empty string won't harm. + input = ""; + } + if (!outputType) { + return input; + } + exports.checkSupport(outputType); + var inputType = exports.getTypeOf(input); + var result = transform[inputType][outputType](input); + return result; +}; + +/** + * Resolve all relative path components, "." and "..", in a path. If these relative components + * traverse above the root then the resulting path will only contain the final path component. + * + * All empty components, e.g. "//", are removed. + * @param {string} path A path with / or \ separators + * @returns {string} The path with all relative path components resolved. + */ +exports.resolve = function(path) { + var parts = path.split("/"); + var result = []; + for (var index = 0; index < parts.length; index++) { + var part = parts[index]; + // Allow the first and last component to be empty for trailing slashes. + if (part === "." 
|| (part === "" && index !== 0 && index !== parts.length - 1)) { + continue; + } else if (part === "..") { + result.pop(); + } else { + result.push(part); + } + } + return result.join("/"); +}; + +/** + * Return the type of the input. + * The type will be in a format valid for JSZip.utils.transformTo : string, array, uint8array, arraybuffer. + * @param {Object} input the input to identify. + * @return {String} the (lowercase) type of the input. + */ +exports.getTypeOf = function(input) { + if (typeof input === "string") { + return "string"; + } + if (Object.prototype.toString.call(input) === "[object Array]") { + return "array"; + } + if (support.nodebuffer && nodejsUtils.isBuffer(input)) { + return "nodebuffer"; + } + if (support.uint8array && input instanceof Uint8Array) { + return "uint8array"; + } + if (support.arraybuffer && input instanceof ArrayBuffer) { + return "arraybuffer"; + } +}; + +/** + * Throw an exception if the type is not supported. + * @param {String} type the type to check. + * @throws {Error} an Error if the browser doesn't support the requested type. + */ +exports.checkSupport = function(type) { + var supported = support[type.toLowerCase()]; + if (!supported) { + throw new Error(type + " is not supported by this platform"); + } +}; + +exports.MAX_VALUE_16BITS = 65535; +exports.MAX_VALUE_32BITS = -1; // well, "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" is parsed as -1 + +/** + * Prettify a string read as binary. + * @param {string} str the string to prettify. + * @return {string} a pretty string. + */ +exports.pretty = function(str) { + var res = "", + code, i; + for (i = 0; i < (str || "").length; i++) { + code = str.charCodeAt(i); + res += "\\x" + (code < 16 ? "0" : "") + code.toString(16).toUpperCase(); + } + return res; +}; + +/** + * Defer the call of a function. + * @param {Function} callback the function to call asynchronously. + * @param {Array} args the arguments to give to the callback. 
+ */ +exports.delay = function(callback, args, self) { + setImmediate(function () { + callback.apply(self || null, args || []); + }); +}; + +/** + * Extends a prototype with an other, without calling a constructor with + * side effects. Inspired by nodejs' `utils.inherits` + * @param {Function} ctor the constructor to augment + * @param {Function} superCtor the parent constructor to use + */ +exports.inherits = function (ctor, superCtor) { + var Obj = function() {}; + Obj.prototype = superCtor.prototype; + ctor.prototype = new Obj(); +}; + +/** + * Merge the objects passed as parameters into a new one. + * @private + * @param {...Object} var_args All objects to merge. + * @return {Object} a new object with the data of the others. + */ +exports.extend = function() { + var result = {}, i, attr; + for (i = 0; i < arguments.length; i++) { // arguments is not enumerable in some browsers + for (attr in arguments[i]) { + if (Object.prototype.hasOwnProperty.call(arguments[i], attr) && typeof result[attr] === "undefined") { + result[attr] = arguments[i][attr]; + } + } + } + return result; +}; + +/** + * Transform arbitrary content into a Promise. + * @param {String} name a name for the content being processed. + * @param {Object} inputData the content to process. + * @param {Boolean} isBinary true if the content is not an unicode string + * @param {Boolean} isOptimizedBinaryString true if the string content only has one byte per character. + * @param {Boolean} isBase64 true if the string content is encoded with base64. + * @return {Promise} a promise in a format usable by JSZip. + */ +exports.prepareContent = function(name, inputData, isBinary, isOptimizedBinaryString, isBase64) { + + // if inputData is already a promise, this flatten it. 
+ var promise = external.Promise.resolve(inputData).then(function(data) { + + + var isBlob = support.blob && (data instanceof Blob || ["[object File]", "[object Blob]"].indexOf(Object.prototype.toString.call(data)) !== -1); + + if (isBlob && typeof FileReader !== "undefined") { + return new external.Promise(function (resolve, reject) { + var reader = new FileReader(); + + reader.onload = function(e) { + resolve(e.target.result); + }; + reader.onerror = function(e) { + reject(e.target.error); + }; + reader.readAsArrayBuffer(data); + }); + } else { + return data; + } + }); + + return promise.then(function(data) { + var dataType = exports.getTypeOf(data); + + if (!dataType) { + return external.Promise.reject( + new Error("Can't read the data of '" + name + "'. Is it " + + "in a supported JavaScript type (String, Blob, ArrayBuffer, etc) ?") + ); + } + // special case : it's way easier to work with Uint8Array than with ArrayBuffer + if (dataType === "arraybuffer") { + data = exports.transformTo("uint8array", data); + } else if (dataType === "string") { + if (isBase64) { + data = base64.decode(data); + } + else if (isBinary) { + // optimizedBinaryString === true means that the file has already been filtered with a 0xFF mask + if (isOptimizedBinaryString !== true) { + // this is a string, not in a base64 format. + // Be sure that this is a correct "binary string" + data = string2binary(data); + } + } + } + return data; + }); +}; + +},{"./base64":1,"./external":6,"./nodejsUtils":14,"./support":30,"setimmediate":54}],33:[function(require,module,exports){ +"use strict"; +var readerFor = require("./reader/readerFor"); +var utils = require("./utils"); +var sig = require("./signature"); +var ZipEntry = require("./zipEntry"); +var support = require("./support"); +// class ZipEntries {{{ +/** + * All the entries in the zip file. + * @constructor + * @param {Object} loadOptions Options for loading the stream. 
+ */ +function ZipEntries(loadOptions) { + this.files = []; + this.loadOptions = loadOptions; +} +ZipEntries.prototype = { + /** + * Check that the reader is on the specified signature. + * @param {string} expectedSignature the expected signature. + * @throws {Error} if it is an other signature. + */ + checkSignature: function(expectedSignature) { + if (!this.reader.readAndCheckSignature(expectedSignature)) { + this.reader.index -= 4; + var signature = this.reader.readString(4); + throw new Error("Corrupted zip or bug: unexpected signature " + "(" + utils.pretty(signature) + ", expected " + utils.pretty(expectedSignature) + ")"); + } + }, + /** + * Check if the given signature is at the given index. + * @param {number} askedIndex the index to check. + * @param {string} expectedSignature the signature to expect. + * @return {boolean} true if the signature is here, false otherwise. + */ + isSignature: function(askedIndex, expectedSignature) { + var currentIndex = this.reader.index; + this.reader.setIndex(askedIndex); + var signature = this.reader.readString(4); + var result = signature === expectedSignature; + this.reader.setIndex(currentIndex); + return result; + }, + /** + * Read the end of the central directory. + */ + readBlockEndOfCentral: function() { + this.diskNumber = this.reader.readInt(2); + this.diskWithCentralDirStart = this.reader.readInt(2); + this.centralDirRecordsOnThisDisk = this.reader.readInt(2); + this.centralDirRecords = this.reader.readInt(2); + this.centralDirSize = this.reader.readInt(4); + this.centralDirOffset = this.reader.readInt(4); + + this.zipCommentLength = this.reader.readInt(2); + // warning : the encoding depends of the system locale + // On a linux machine with LANG=en_US.utf8, this field is utf8 encoded. + // On a windows machine, this field is encoded with the localized windows code page. + var zipComment = this.reader.readData(this.zipCommentLength); + var decodeParamType = support.uint8array ? 
"uint8array" : "array"; + // To get consistent behavior with the generation part, we will assume that + // this is utf8 encoded unless specified otherwise. + var decodeContent = utils.transformTo(decodeParamType, zipComment); + this.zipComment = this.loadOptions.decodeFileName(decodeContent); + }, + /** + * Read the end of the Zip 64 central directory. + * Not merged with the method readEndOfCentral : + * The end of central can coexist with its Zip64 brother, + * I don't want to read the wrong number of bytes ! + */ + readBlockZip64EndOfCentral: function() { + this.zip64EndOfCentralSize = this.reader.readInt(8); + this.reader.skip(4); + // this.versionMadeBy = this.reader.readString(2); + // this.versionNeeded = this.reader.readInt(2); + this.diskNumber = this.reader.readInt(4); + this.diskWithCentralDirStart = this.reader.readInt(4); + this.centralDirRecordsOnThisDisk = this.reader.readInt(8); + this.centralDirRecords = this.reader.readInt(8); + this.centralDirSize = this.reader.readInt(8); + this.centralDirOffset = this.reader.readInt(8); + + this.zip64ExtensibleData = {}; + var extraDataSize = this.zip64EndOfCentralSize - 44, + index = 0, + extraFieldId, + extraFieldLength, + extraFieldValue; + while (index < extraDataSize) { + extraFieldId = this.reader.readInt(2); + extraFieldLength = this.reader.readInt(4); + extraFieldValue = this.reader.readData(extraFieldLength); + this.zip64ExtensibleData[extraFieldId] = { + id: extraFieldId, + length: extraFieldLength, + value: extraFieldValue + }; + } + }, + /** + * Read the end of the Zip 64 central directory locator. + */ + readBlockZip64EndOfCentralLocator: function() { + this.diskWithZip64CentralDirStart = this.reader.readInt(4); + this.relativeOffsetEndOfZip64CentralDir = this.reader.readInt(8); + this.disksCount = this.reader.readInt(4); + if (this.disksCount > 1) { + throw new Error("Multi-volumes zip are not supported"); + } + }, + /** + * Read the local files, based on the offset read in the central part. 
+ */ + readLocalFiles: function() { + var i, file; + for (i = 0; i < this.files.length; i++) { + file = this.files[i]; + this.reader.setIndex(file.localHeaderOffset); + this.checkSignature(sig.LOCAL_FILE_HEADER); + file.readLocalPart(this.reader); + file.handleUTF8(); + file.processAttributes(); + } + }, + /** + * Read the central directory. + */ + readCentralDir: function() { + var file; + + this.reader.setIndex(this.centralDirOffset); + while (this.reader.readAndCheckSignature(sig.CENTRAL_FILE_HEADER)) { + file = new ZipEntry({ + zip64: this.zip64 + }, this.loadOptions); + file.readCentralPart(this.reader); + this.files.push(file); + } + + if (this.centralDirRecords !== this.files.length) { + if (this.centralDirRecords !== 0 && this.files.length === 0) { + // We expected some records but couldn't find ANY. + // This is really suspicious, as if something went wrong. + throw new Error("Corrupted zip or bug: expected " + this.centralDirRecords + " records in central dir, got " + this.files.length); + } else { + // We found some records but not all. + // Something is wrong but we got something for the user: no error here. + // console.warn("expected", this.centralDirRecords, "records in central dir, got", this.files.length); + } + } + }, + /** + * Read the end of central directory. + */ + readEndOfCentral: function() { + var offset = this.reader.lastIndexOfSignature(sig.CENTRAL_DIRECTORY_END); + if (offset < 0) { + // Check if the content is a truncated zip or complete garbage. + // A "LOCAL_FILE_HEADER" is not required at the beginning (auto + // extractible zip for example) but it can give a good hint. + // If an ajax request was used without responseType, we will also + // get unreadable data. + var isGarbage = !this.isSignature(0, sig.LOCAL_FILE_HEADER); + + if (isGarbage) { + throw new Error("Can't find end of central directory : is this a zip file ? 
" + + "If it is, see https://stuk.github.io/jszip/documentation/howto/read_zip.html"); + } else { + throw new Error("Corrupted zip: can't find end of central directory"); + } + + } + this.reader.setIndex(offset); + var endOfCentralDirOffset = offset; + this.checkSignature(sig.CENTRAL_DIRECTORY_END); + this.readBlockEndOfCentral(); + + + /* extract from the zip spec : + 4) If one of the fields in the end of central directory + record is too small to hold required data, the field + should be set to -1 (0xFFFF or 0xFFFFFFFF) and the + ZIP64 format record should be created. + 5) The end of central directory record and the + Zip64 end of central directory locator record must + reside on the same disk when splitting or spanning + an archive. + */ + if (this.diskNumber === utils.MAX_VALUE_16BITS || this.diskWithCentralDirStart === utils.MAX_VALUE_16BITS || this.centralDirRecordsOnThisDisk === utils.MAX_VALUE_16BITS || this.centralDirRecords === utils.MAX_VALUE_16BITS || this.centralDirSize === utils.MAX_VALUE_32BITS || this.centralDirOffset === utils.MAX_VALUE_32BITS) { + this.zip64 = true; + + /* + Warning : the zip64 extension is supported, but ONLY if the 64bits integer read from + the zip file can fit into a 32bits integer. This cannot be solved : JavaScript represents + all numbers as 64-bit double precision IEEE 754 floating point numbers. + So, we have 53bits for integers and bitwise operations treat everything as 32bits. 
+ see https://developer.mozilla.org/en-US/docs/JavaScript/Reference/Operators/Bitwise_Operators + and http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-262.pdf section 8.5 + */ + + // should look for a zip64 EOCD locator + offset = this.reader.lastIndexOfSignature(sig.ZIP64_CENTRAL_DIRECTORY_LOCATOR); + if (offset < 0) { + throw new Error("Corrupted zip: can't find the ZIP64 end of central directory locator"); + } + this.reader.setIndex(offset); + this.checkSignature(sig.ZIP64_CENTRAL_DIRECTORY_LOCATOR); + this.readBlockZip64EndOfCentralLocator(); + + // now the zip64 EOCD record + if (!this.isSignature(this.relativeOffsetEndOfZip64CentralDir, sig.ZIP64_CENTRAL_DIRECTORY_END)) { + // console.warn("ZIP64 end of central directory not where expected."); + this.relativeOffsetEndOfZip64CentralDir = this.reader.lastIndexOfSignature(sig.ZIP64_CENTRAL_DIRECTORY_END); + if (this.relativeOffsetEndOfZip64CentralDir < 0) { + throw new Error("Corrupted zip: can't find the ZIP64 end of central directory"); + } + } + this.reader.setIndex(this.relativeOffsetEndOfZip64CentralDir); + this.checkSignature(sig.ZIP64_CENTRAL_DIRECTORY_END); + this.readBlockZip64EndOfCentral(); + } + + var expectedEndOfCentralDirOffset = this.centralDirOffset + this.centralDirSize; + if (this.zip64) { + expectedEndOfCentralDirOffset += 20; // end of central dir 64 locator + expectedEndOfCentralDirOffset += 12 /* should not include the leading 12 bytes */ + this.zip64EndOfCentralSize; + } + + var extraBytes = endOfCentralDirOffset - expectedEndOfCentralDirOffset; + + if (extraBytes > 0) { + // console.warn(extraBytes, "extra bytes at beginning or within zipfile"); + if (this.isSignature(endOfCentralDirOffset, sig.CENTRAL_FILE_HEADER)) { + // The offsets seem wrong, but we have something at the specified offset. + // So… we keep it. 
+ } else { + // the offset is wrong, update the "zero" of the reader + // this happens if data has been prepended (crx files for example) + this.reader.zero = extraBytes; + } + } else if (extraBytes < 0) { + throw new Error("Corrupted zip: missing " + Math.abs(extraBytes) + " bytes."); + } + }, + prepareReader: function(data) { + this.reader = readerFor(data); + }, + /** + * Read a zip file and create ZipEntries. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the binary string representing a zip file. + */ + load: function(data) { + this.prepareReader(data); + this.readEndOfCentral(); + this.readCentralDir(); + this.readLocalFiles(); + } +}; +// }}} end of ZipEntries +module.exports = ZipEntries; + +},{"./reader/readerFor":22,"./signature":23,"./support":30,"./utils":32,"./zipEntry":34}],34:[function(require,module,exports){ +"use strict"; +var readerFor = require("./reader/readerFor"); +var utils = require("./utils"); +var CompressedObject = require("./compressedObject"); +var crc32fn = require("./crc32"); +var utf8 = require("./utf8"); +var compressions = require("./compressions"); +var support = require("./support"); + +var MADE_BY_DOS = 0x00; +var MADE_BY_UNIX = 0x03; + +/** + * Find a compression registered in JSZip. + * @param {string} compressionMethod the method magic to find. + * @return {Object|null} the JSZip compression object, null if none found. + */ +var findCompression = function(compressionMethod) { + for (var method in compressions) { + if (!Object.prototype.hasOwnProperty.call(compressions, method)) { + continue; + } + if (compressions[method].magic === compressionMethod) { + return compressions[method]; + } + } + return null; +}; + +// class ZipEntry {{{ +/** + * An entry in the zip file. + * @constructor + * @param {Object} options Options of the current file. + * @param {Object} loadOptions Options for loading the stream. 
+ */ +function ZipEntry(options, loadOptions) { + this.options = options; + this.loadOptions = loadOptions; +} +ZipEntry.prototype = { + /** + * say if the file is encrypted. + * @return {boolean} true if the file is encrypted, false otherwise. + */ + isEncrypted: function() { + // bit 1 is set + return (this.bitFlag & 0x0001) === 0x0001; + }, + /** + * say if the file has utf-8 filename/comment. + * @return {boolean} true if the filename/comment is in utf-8, false otherwise. + */ + useUTF8: function() { + // bit 11 is set + return (this.bitFlag & 0x0800) === 0x0800; + }, + /** + * Read the local part of a zip file and add the info in this object. + * @param {DataReader} reader the reader to use. + */ + readLocalPart: function(reader) { + var compression, localExtraFieldsLength; + + // we already know everything from the central dir ! + // If the central dir data are false, we are doomed. + // On the bright side, the local part is scary : zip64, data descriptors, both, etc. + // The less data we get here, the more reliable this should be. + // Let's skip the whole header and dash to the data ! + reader.skip(22); + // in some zip created on windows, the filename stored in the central dir contains \ instead of /. + // Strangely, the filename here is OK. + // I would love to treat these zip files as corrupted (see http://www.info-zip.org/FAQ.html#backslashes + // or APPNOTE#4.4.17.1, "All slashes MUST be forward slashes '/'") but there are a lot of bad zip generators... + // Search "unzip mismatching "local" filename continuing with "central" filename version" on + // the internet. + // + // I think I see the logic here : the central directory is used to display + // content and the local directory is used to extract the files. Mixing / and \ + // may be used to display \ to windows users and use / when extracting the files. 
+ // Unfortunately, this lead also to some issues : http://seclists.org/fulldisclosure/2009/Sep/394 + this.fileNameLength = reader.readInt(2); + localExtraFieldsLength = reader.readInt(2); // can't be sure this will be the same as the central dir + // the fileName is stored as binary data, the handleUTF8 method will take care of the encoding. + this.fileName = reader.readData(this.fileNameLength); + reader.skip(localExtraFieldsLength); + + if (this.compressedSize === -1 || this.uncompressedSize === -1) { + throw new Error("Bug or corrupted zip : didn't get enough information from the central directory " + "(compressedSize === -1 || uncompressedSize === -1)"); + } + + compression = findCompression(this.compressionMethod); + if (compression === null) { // no compression found + throw new Error("Corrupted zip : compression " + utils.pretty(this.compressionMethod) + " unknown (inner file : " + utils.transformTo("string", this.fileName) + ")"); + } + this.decompressed = new CompressedObject(this.compressedSize, this.uncompressedSize, this.crc32, compression, reader.readData(this.compressedSize)); + }, + + /** + * Read the central part of a zip file and add the info in this object. + * @param {DataReader} reader the reader to use. 
+ */ + readCentralPart: function(reader) { + this.versionMadeBy = reader.readInt(2); + reader.skip(2); + // this.versionNeeded = reader.readInt(2); + this.bitFlag = reader.readInt(2); + this.compressionMethod = reader.readString(2); + this.date = reader.readDate(); + this.crc32 = reader.readInt(4); + this.compressedSize = reader.readInt(4); + this.uncompressedSize = reader.readInt(4); + var fileNameLength = reader.readInt(2); + this.extraFieldsLength = reader.readInt(2); + this.fileCommentLength = reader.readInt(2); + this.diskNumberStart = reader.readInt(2); + this.internalFileAttributes = reader.readInt(2); + this.externalFileAttributes = reader.readInt(4); + this.localHeaderOffset = reader.readInt(4); + + if (this.isEncrypted()) { + throw new Error("Encrypted zip are not supported"); + } + + // will be read in the local part, see the comments there + reader.skip(fileNameLength); + this.readExtraFields(reader); + this.parseZIP64ExtraField(reader); + this.fileComment = reader.readData(this.fileCommentLength); + }, + + /** + * Parse the external file attributes and get the unix/dos permissions. + */ + processAttributes: function () { + this.unixPermissions = null; + this.dosPermissions = null; + var madeBy = this.versionMadeBy >> 8; + + // Check if we have the DOS directory flag set. + // We look for it in the DOS and UNIX permissions + // but some unknown platform could set it as a compatibility flag. + this.dir = this.externalFileAttributes & 0x0010 ? 
true : false; + + if(madeBy === MADE_BY_DOS) { + // first 6 bits (0 to 5) + this.dosPermissions = this.externalFileAttributes & 0x3F; + } + + if(madeBy === MADE_BY_UNIX) { + this.unixPermissions = (this.externalFileAttributes >> 16) & 0xFFFF; + // the octal permissions are in (this.unixPermissions & 0x01FF).toString(8); + } + + // fail safe : if the name ends with a / it probably means a folder + if (!this.dir && this.fileNameStr.slice(-1) === "/") { + this.dir = true; + } + }, + + /** + * Parse the ZIP64 extra field and merge the info in the current ZipEntry. + * @param {DataReader} reader the reader to use. + */ + parseZIP64ExtraField: function() { + if (!this.extraFields[0x0001]) { + return; + } + + // should be something, preparing the extra reader + var extraReader = readerFor(this.extraFields[0x0001].value); + + // I really hope that these 64bits integer can fit in 32 bits integer, because js + // won't let us have more. + if (this.uncompressedSize === utils.MAX_VALUE_32BITS) { + this.uncompressedSize = extraReader.readInt(8); + } + if (this.compressedSize === utils.MAX_VALUE_32BITS) { + this.compressedSize = extraReader.readInt(8); + } + if (this.localHeaderOffset === utils.MAX_VALUE_32BITS) { + this.localHeaderOffset = extraReader.readInt(8); + } + if (this.diskNumberStart === utils.MAX_VALUE_32BITS) { + this.diskNumberStart = extraReader.readInt(4); + } + }, + /** + * Read the central part of a zip file and add the info in this object. + * @param {DataReader} reader the reader to use. 
+ */ + readExtraFields: function(reader) { + var end = reader.index + this.extraFieldsLength, + extraFieldId, + extraFieldLength, + extraFieldValue; + + if (!this.extraFields) { + this.extraFields = {}; + } + + while (reader.index + 4 < end) { + extraFieldId = reader.readInt(2); + extraFieldLength = reader.readInt(2); + extraFieldValue = reader.readData(extraFieldLength); + + this.extraFields[extraFieldId] = { + id: extraFieldId, + length: extraFieldLength, + value: extraFieldValue + }; + } + + reader.setIndex(end); + }, + /** + * Apply an UTF8 transformation if needed. + */ + handleUTF8: function() { + var decodeParamType = support.uint8array ? "uint8array" : "array"; + if (this.useUTF8()) { + this.fileNameStr = utf8.utf8decode(this.fileName); + this.fileCommentStr = utf8.utf8decode(this.fileComment); + } else { + var upath = this.findExtraFieldUnicodePath(); + if (upath !== null) { + this.fileNameStr = upath; + } else { + // ASCII text or unsupported code page + var fileNameByteArray = utils.transformTo(decodeParamType, this.fileName); + this.fileNameStr = this.loadOptions.decodeFileName(fileNameByteArray); + } + + var ucomment = this.findExtraFieldUnicodeComment(); + if (ucomment !== null) { + this.fileCommentStr = ucomment; + } else { + // ASCII text or unsupported code page + var commentByteArray = utils.transformTo(decodeParamType, this.fileComment); + this.fileCommentStr = this.loadOptions.decodeFileName(commentByteArray); + } + } + }, + + /** + * Find the unicode path declared in the extra field, if any. + * @return {String} the unicode path, null otherwise. + */ + findExtraFieldUnicodePath: function() { + var upathField = this.extraFields[0x7075]; + if (upathField) { + var extraReader = readerFor(upathField.value); + + // wrong version + if (extraReader.readInt(1) !== 1) { + return null; + } + + // the crc of the filename changed, this field is out of date. 
+ if (crc32fn(this.fileName) !== extraReader.readInt(4)) { + return null; + } + + return utf8.utf8decode(extraReader.readData(upathField.length - 5)); + } + return null; + }, + + /** + * Find the unicode comment declared in the extra field, if any. + * @return {String} the unicode comment, null otherwise. + */ + findExtraFieldUnicodeComment: function() { + var ucommentField = this.extraFields[0x6375]; + if (ucommentField) { + var extraReader = readerFor(ucommentField.value); + + // wrong version + if (extraReader.readInt(1) !== 1) { + return null; + } + + // the crc of the comment changed, this field is out of date. + if (crc32fn(this.fileComment) !== extraReader.readInt(4)) { + return null; + } + + return utf8.utf8decode(extraReader.readData(ucommentField.length - 5)); + } + return null; + } +}; +module.exports = ZipEntry; + +},{"./compressedObject":2,"./compressions":3,"./crc32":4,"./reader/readerFor":22,"./support":30,"./utf8":31,"./utils":32}],35:[function(require,module,exports){ +"use strict"; + +var StreamHelper = require("./stream/StreamHelper"); +var DataWorker = require("./stream/DataWorker"); +var utf8 = require("./utf8"); +var CompressedObject = require("./compressedObject"); +var GenericWorker = require("./stream/GenericWorker"); + +/** + * A simple object representing a file in the zip file. 
+ * @constructor + * @param {string} name the name of the file + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the data + * @param {Object} options the options of the file + */ +var ZipObject = function(name, data, options) { + this.name = name; + this.dir = options.dir; + this.date = options.date; + this.comment = options.comment; + this.unixPermissions = options.unixPermissions; + this.dosPermissions = options.dosPermissions; + + this._data = data; + this._dataBinary = options.binary; + // keep only the compression + this.options = { + compression : options.compression, + compressionOptions : options.compressionOptions + }; +}; + +ZipObject.prototype = { + /** + * Create an internal stream for the content of this object. + * @param {String} type the type of each chunk. + * @return StreamHelper the stream. + */ + internalStream: function (type) { + var result = null, outputType = "string"; + try { + if (!type) { + throw new Error("No output type specified."); + } + outputType = type.toLowerCase(); + var askUnicodeString = outputType === "string" || outputType === "text"; + if (outputType === "binarystring" || outputType === "text") { + outputType = "string"; + } + result = this._decompressWorker(); + + var isUnicodeString = !this._dataBinary; + + if (isUnicodeString && !askUnicodeString) { + result = result.pipe(new utf8.Utf8EncodeWorker()); + } + if (!isUnicodeString && askUnicodeString) { + result = result.pipe(new utf8.Utf8DecodeWorker()); + } + } catch (e) { + result = new GenericWorker("error"); + result.error(e); + } + + return new StreamHelper(result, outputType, ""); + }, + + /** + * Prepare the content in the asked type. + * @param {String} type the type of the result. + * @param {Function} onUpdate a function to call on each internal update. + * @return Promise the promise of the result. + */ + async: function (type, onUpdate) { + return this.internalStream(type).accumulate(onUpdate); + }, + + /** + * Prepare the content as a nodejs stream. 
+ * @param {String} type the type of each chunk. + * @param {Function} onUpdate a function to call on each internal update. + * @return Stream the stream. + */ + nodeStream: function (type, onUpdate) { + return this.internalStream(type || "nodebuffer").toNodejsStream(onUpdate); + }, + + /** + * Return a worker for the compressed content. + * @private + * @param {Object} compression the compression object to use. + * @param {Object} compressionOptions the options to use when compressing. + * @return Worker the worker. + */ + _compressWorker: function (compression, compressionOptions) { + if ( + this._data instanceof CompressedObject && + this._data.compression.magic === compression.magic + ) { + return this._data.getCompressedWorker(); + } else { + var result = this._decompressWorker(); + if(!this._dataBinary) { + result = result.pipe(new utf8.Utf8EncodeWorker()); + } + return CompressedObject.createWorkerFrom(result, compression, compressionOptions); + } + }, + /** + * Return a worker for the decompressed content. + * @private + * @return Worker the worker. 
+ */ + _decompressWorker : function () { + if (this._data instanceof CompressedObject) { + return this._data.getContentWorker(); + } else if (this._data instanceof GenericWorker) { + return this._data; + } else { + return new DataWorker(this._data); + } + } +}; + +var removedMethods = ["asText", "asBinary", "asNodeBuffer", "asUint8Array", "asArrayBuffer"]; +var removedFn = function () { + throw new Error("This method has been removed in JSZip 3.0, please check the upgrade guide."); +}; + +for(var i = 0; i < removedMethods.length; i++) { + ZipObject.prototype[removedMethods[i]] = removedFn; +} +module.exports = ZipObject; + +},{"./compressedObject":2,"./stream/DataWorker":27,"./stream/GenericWorker":28,"./stream/StreamHelper":29,"./utf8":31}],36:[function(require,module,exports){ +(function (global){ +'use strict'; +var Mutation = global.MutationObserver || global.WebKitMutationObserver; + +var scheduleDrain; + +{ + if (Mutation) { + var called = 0; + var observer = new Mutation(nextTick); + var element = global.document.createTextNode(''); + observer.observe(element, { + characterData: true + }); + scheduleDrain = function () { + element.data = (called = ++called % 2); + }; + } else if (!global.setImmediate && typeof global.MessageChannel !== 'undefined') { + var channel = new global.MessageChannel(); + channel.port1.onmessage = nextTick; + scheduleDrain = function () { + channel.port2.postMessage(0); + }; + } else if ('document' in global && 'onreadystatechange' in global.document.createElement('script')) { + scheduleDrain = function () { + + // Create a + + +
+
+

Select Tree

+ + +
+
+

Tree Metrics

+
+
+ +
+

Decision Tree Structure

+
+
+ +
+

Activation Distributions

+ +
+ +
+

Example Samples

+
+
+
+ + + + + + + +``` + +**`js/viewer.js`:** ```javascript // Main viewer logic -async function loadTree(layer, targetIdx) { - const data = await fetch(`data/tree_${layer}_${targetIdx}.json`); - - displayMetadata(data.metadata); - displayHistograms(data.activations); // Use sparklines.js - displayTokens(data.tokens); // Use token-display.js (modified) - displayTree(data.tree); // NEW tree-display.js - displayInputFeatures(data.input_features); +let treeIndex = []; +let currentTree = null; + +async function init() { + // Load tree index + const response = await fetch('data/index.json'); + treeIndex = await response.json(); + + // Populate layer selector + const layers = [...new Set(treeIndex.map(t => t.layer))]; + const layerSelect = document.getElementById('layer-select'); + layers.forEach(layer => { + const option = document.createElement('option'); + option.value = layer; + option.text = `Layer ${layer}`; + layerSelect.appendChild(option); + }); + + // Event listeners + layerSelect.addEventListener('change', onLayerChange); + document.getElementById('target-select').addEventListener('change', onTargetChange); + + // Load first tree + if (treeIndex.length > 0) { + await loadTree(treeIndex[0].layer, treeIndex[0].target); + } } -``` -**Key functions to implement:** -- `displayHistograms()` - Call sparkbars() with true (blue) and predicted (red) -- `displayTokens()` - Modified token-display.js to overlay true (blue) + pred (red) -- `displayTree()` - Render tree structure (text or D3) +function onLayerChange() { + const layer = parseInt(document.getElementById('layer-select').value); + const trees = treeIndex.filter(t => t.layer === layer); -### Modified `token-display.js` + const targetSelect = document.getElementById('target-select'); + targetSelect.innerHTML = ''; + trees.forEach(tree => { + const option = document.createElement('option'); + option.value = tree.target; + option.text = `Target ${tree.target} (AP=${tree.ap.toFixed(3)})`; + 
targetSelect.appendChild(option); + }); -Need new function: -```javascript -function createDualActivationVisualization(tokens, trueActs, predProbs) { - // Overlay blue (true) and red (predicted) - // Perfect prediction = purple - // False negative = blue only - // False positive = red only + if (trees.length > 0) { + loadTree(layer, trees[0].target); + } +} + +async function loadTree(layer, target) { + const response = await fetch(`data/tree_${layer}_${target}.json`); + currentTree = await response.json(); + + displayMetrics(currentTree.metadata); + displayHistograms(currentTree.histograms); + displayTree(currentTree.tree); + displayTokenSamples(currentTree.samples); +} + +function displayMetrics(metadata) { + const m = metadata.metrics; + const cm = m.confusion_matrix; + + const html = ` + + + + + + + + + + +
AP:${m.ap.toFixed(3)}
Accuracy:${m.accuracy.toFixed(3)}
Balanced Acc:${m.balanced_accuracy.toFixed(3)}
Prevalence:${m.prevalence.toFixed(4)}
Confusion Matrix:
TP:${cm.TP}
TN:${cm.TN}
FP:${cm.FP}
FN:${cm.FN}
+ `; + document.getElementById('metrics').innerHTML = html; +} + +function displayHistograms(histograms) { + // Use sparklines.js to render dual histograms + const canvas = document.getElementById('hist-canvas'); + const ctx = canvas.getContext('2d'); + + // Draw true activations (blue) and predicted (red) overlaid + drawHistogram(ctx, histograms.true_activations, 'blue', 0); + drawHistogram(ctx, histograms.predicted_probabilities, 'red', 0); +} + +function displayTree(treeData) { + // Use tree-display.js to render D3 tree + renderDecisionTree('tree-svg', treeData); +} + +function displayTokenSamples(samples) { + const container = document.getElementById('samples-container'); + container.innerHTML = ''; + + samples.forEach(sample => { + const div = document.createElement('div'); + div.className = `sample sample-${sample.category}`; + div.innerHTML = ` +

${sample.category} (confidence: ${sample.confidence.toFixed(3)})

+
${renderTokens(sample)}
+ `; + container.appendChild(div); + }); +} + +function renderTokens(sample) { + // Create dual-color token visualization + // Blue background = true activation, Red = predicted + return sample.tokens.map((token, i) => { + const trueVal = sample.true_activations[i]; + const predVal = sample.predicted_probabilities[i]; + + // Dual gradient or side-by-side bars + return ` + ${token} + `; + }).join(' '); } + +// Initialize on load +init(); ``` -## JSON Schema Example +**`js/tree-display.js`:** +```javascript +function renderDecisionTree(containerId, treeData) { + const container = document.getElementById(containerId); + container.innerHTML = ''; -```json -{ - "metadata": { - "layer_index": 1, - "target_idx": 5, - "module_key": "blocks.0.mlp.W_gate", - "metrics": {"ap": 0.85, "accuracy": 0.92} - }, - "tree": { - "structure": { - "children_left": [1, -1, 3, ...], - "feature": [7, -2, 12, ...], - "feature_names": ["blocks.0.attn.W_Q:3", ...] - } - }, - "activations": { - "true": { - "histogram": {"bins": [...], "counts": [...]} - }, - "predicted": { - "histogram": {"bins": [...], "counts": [...]} + // Simple text-based tree for now + // Can upgrade to D3.js interactive tree later + + const textTree = buildTextTree(treeData.structure, treeData.feature_names); + const pre = document.createElement('pre'); + pre.textContent = textTree; + container.appendChild(pre); +} + +function buildTextTree(structure, featureNames, nodeIdx = 0, depth = 0) { + const indent = ' '.repeat(depth); + + if (structure.children_left[nodeIdx] === -1) { + // Leaf node + const value = structure.value[nodeIdx]; + const prediction = value[1] > value[0] ? 'ACTIVE' : 'INACTIVE'; + return `${indent}→ ${prediction} (${value[0]}/${value[1]})\n`; } - }, - "tokens": { - "data": [ - { - "tokens": ["The", "cat", "sat", ...], - "true_activations": [0.0, 0.0, 0.8, ...], - "predicted_probabilities": [0.05, 0.1, 0.85, ...] 
- } - ] - }, - "input_features": { - "blocks.0.attn.W_Q": [3, 17, 42], // Which components used - "blocks.0.mlp.W_in": [5, 12] - } + + // Internal node + const feature = structure.feature[nodeIdx]; + const threshold = structure.threshold[nodeIdx]; + const featureName = featureNames[feature]; + + let result = `${indent}${featureName} <= ${threshold}?\n`; + result += buildTextTree(structure, featureNames, structure.children_left[nodeIdx], depth + 1); + result += `${indent}else:\n`; + result += buildTextTree(structure, featureNames, structure.children_right[nodeIdx], depth + 1); + + return result; } ``` -## Open Questions +--- + +## Implementation Checklist + +### Phase 1: Static Plot Improvements +- [ ] Update `plot_layer_metrics()`: scatter with jitter instead of bars +- [ ] Add LaTeX titles to all metrics plots (TP/FP/TN/FN formulas) +- [ ] Update AP vs prevalence: log scale, no edges, color by depth +- [ ] Add AP vs prevalence heatmap to `plot_tree_statistics()` +- [ ] Implement `greedy_sort()` helper function +- [ ] Create `plot_activations_unsorted()` with layer boundaries +- [ ] Create `plot_activations_sorted()` with diff plot +- [ ] Create `plot_covariance_unsorted()` with layer boundaries +- [ ] Create `plot_covariance_sorted()` +- [ ] Update all plot titles with LaTeX and newlines +- [ ] Test with existing `run.py` workflow + +### Phase 2: Data Export +- [ ] Create `spd/clustering/ci_dt/export.py` +- [ ] Implement `export_tree_json()` +- [ ] Implement `export_all_trees()` +- [ ] Implement `select_token_samples()` with stratified sampling +- [ ] Implement `serialize_tree()`, `compute_tree_metrics()`, etc. 
+- [ ] Add export call to `run.py` +- [ ] Test JSON output schema + +### Phase 3: Interactive Viewer +- [ ] Create `ci_dt_vis/` directory structure +- [ ] Implement `index.html` layout +- [ ] Implement `viewer.js` tree selection and loading +- [ ] Implement `tree-display.js` text rendering (D3 optional) +- [ ] Implement `token-display.js` dual-color visualization +- [ ] Implement histogram rendering (reuse or adapt sparklines.js) +- [ ] Add CSS styling +- [ ] Test end-to-end workflow + +### Phase 4: Documentation +- [ ] Update `run.py` docstrings +- [ ] Add README in `ci_dt_vis/` explaining viewer usage +- [ ] Document JSON schema +- [ ] Add example screenshots + +--- + +## Open Questions / Design Decisions + +1. **Token samples per tree:** 8 total (2 per category) seems reasonable. Too many? +2. **Histogram bins:** 50 bins for activations, 20 for probabilities? +3. **D3.js tree or text?** Start with text, add D3 if needed +4. **Component sorting:** Should we also show a version with components sorted by layer, then by similarity within layer? +5. **File size:** Each tree JSON might be 50-200KB. With 1000s of trees, total size could be 50-200MB. Acceptable? +6. **Continuous activations for tokens:** Currently we only have binary. Need to save continuous pre-threshold values? + +--- + +## Success Metrics + +**Static Plots:** +- Plots are immediately interpretable without prior knowledge +- Titles explain abbreviations and formulas +- Layer boundaries visible in unsorted plots +- Sorting reveals structure (coactivation patterns) +- Diff plot clearly shows FP/FN errors -1. **How many token samples to show?** 5? 10? 50? -2. **Sample selection strategy?** Random vs stratified vs worst-predictions? -3. **Tree display?** Text vs D3.js interactive? -4. **Single big JSON vs many small files?** (Currently: many small files) -5. **Need continuous activations or just binary?** (Binary sufficient for token viz?) -6. **Histogram bins?** How many, linear or log scale? 
+**Interactive Viewer:** +- Can load and view any tree in <1 second +- Token examples clearly show where component activates +- Confusion matrix category examples are informative +- Tree structure is readable +- Histograms show activation distributions clearly From 12cac519b46ebd5d241f17e154f10245c16890ad Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 15:51:17 +0100 Subject: [PATCH 49/77] wip --- spd/clustering/ci_dt/__init__.py | 4 - spd/clustering/ci_dt/plot.py | 286 +++++++++++++++++++++++++++---- spd/clustering/ci_dt/run.py | 60 ++++++- 3 files changed, 309 insertions(+), 41 deletions(-) diff --git a/spd/clustering/ci_dt/__init__.py b/spd/clustering/ci_dt/__init__.py index 7dde8ab07..2097aa211 100644 --- a/spd/clustering/ci_dt/__init__.py +++ b/spd/clustering/ci_dt/__init__.py @@ -14,8 +14,6 @@ train_trees, ) from spd.clustering.ci_dt.plot import ( - plot_activations, - plot_covariance, plot_layer_metrics, plot_selected_trees, ) @@ -35,8 +33,6 @@ "proba_for_layer", "get_estimator_for", # Plot - "plot_activations", - "plot_covariance", "plot_layer_metrics", "plot_selected_trees", ] diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 672e1b394..7518ac563 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -4,41 +4,267 @@ import matplotlib.pyplot as plt import numpy as np -from jaxtyping import Float, Int +from jaxtyping import Bool, Float, Int from sklearn.tree import plot_tree from spd.clustering.ci_dt.core import LayerModel, get_estimator_for -def plot_activations(layers_true: list[np.ndarray], layers_pred: list[np.ndarray]) -> None: - """Show true and predicted activations as heatmaps.""" - A_true: np.ndarray = np.concatenate(layers_true, axis=1) - A_pred: np.ndarray = np.concatenate([layers_pred[0]] + layers_pred[1:], axis=1) - fig1 = plt.figure(figsize=(10, 6)) - ax1 = fig1.add_subplot(2, 1, 1) - ax1.set_title("Activations (True)") - ax1.imshow(A_true, aspect="auto", 
interpolation="nearest") - ax1.set_xlabel("components (all layers concatenated)") - ax1.set_ylabel("samples") - ax2 = fig1.add_subplot(2, 1, 2) - ax2.set_title("Activations (Predicted)") - ax2.imshow(A_pred, aspect="auto", interpolation="nearest") - ax2.set_xlabel("components (all layers concatenated)") - ax2.set_ylabel("samples") - fig1.tight_layout() - - -def plot_covariance(layers_true: list[np.ndarray]) -> None: - """Plot covariance between all components across layers.""" - A: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) - C: np.ndarray = np.cov(A, rowvar=False) - fig2 = plt.figure(figsize=(6, 6)) - ax = fig2.add_subplot(1, 1, 1) - ax.set_title("Covariance of components (all layers)") - ax.imshow(C, aspect="auto", interpolation="nearest") - ax.set_xlabel("component index") - ax.set_ylabel("component index") - fig2.tight_layout() +def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: + """Greedy ordering by cosine similarity. + + Starts from the most central item (highest average similarity to all others) + and greedily adds the nearest neighbor at each step. 
+ + Args: + A: 2D array to sort + axis: 0 to sort rows (samples), 1 to sort columns (components) + + Returns: + Array of indices in sorted order + """ + # Transpose if sorting columns + if axis == 1: + A = A.T + + # Compute cosine similarity + norms: Float[np.ndarray, "n 1"] = np.linalg.norm(A, axis=1, keepdims=True) + norms = np.where(norms > 1e-8, norms, 1.0) # Avoid division by zero + A_normalized: Float[np.ndarray, "n d"] = A / norms + similarity: Float[np.ndarray, "n n"] = A_normalized @ A_normalized.T + + # Start from most central item (highest average similarity) + n: int = similarity.shape[0] + avg_sim: Float[np.ndarray, "n"] = similarity.mean(axis=1) + start_idx: int = int(np.argmax(avg_sim)) + + # Greedy ordering: always add nearest unvisited neighbor + ordered: list[int] = [start_idx] + remaining: set[int] = set(range(n)) + remaining.remove(start_idx) + current: int = start_idx + + while remaining: + # Find unvisited item with highest similarity to current + best_sim: float = -1.0 + best_idx: int = -1 + for idx in remaining: + sim: float = float(similarity[current, idx]) + if sim > best_sim: + best_sim = sim + best_idx = idx + + ordered.append(best_idx) + remaining.remove(best_idx) + current = best_idx + + return np.array(ordered, dtype=np.int64) + + +def add_component_labeling( + ax: plt.Axes, component_labels: list[str], axis: str = "x" +) -> None: + """Add component labeling using major/minor ticks to show module boundaries. 
+ + Args: + ax: Matplotlib axis to modify + component_labels: List of component labels in format "module:index" + axis: Which axis to label ('x' or 'y') + """ + if not component_labels: + return + + # Extract module information + module_changes: list[int] = [] + current_module: str = component_labels[0].split(":")[0] + module_labels: list[str] = [] + + for i, label in enumerate(component_labels): + module: str = label.split(":")[0] + if module != current_module: + module_changes.append(i) + module_labels.append(current_module) + current_module = module + module_labels.append(current_module) + + # Set up major and minor ticks + # Minor ticks: every 10 components + minor_ticks: list[int] = list(range(0, len(component_labels), 10)) + + # Major ticks: module boundaries (start of each module) + major_ticks: list[int] = [0] + module_changes + major_labels_final: list[str] = module_labels + + if axis == "x": + ax.set_xticks(minor_ticks, minor=True) + ax.set_xticks(major_ticks) + ax.set_xticklabels(major_labels_final, rotation=45, ha="right") + ax.set_xlim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="x", which="minor", length=2, width=0.5) + ax.tick_params(axis="x", which="major", length=6, width=1.5) + for x in major_ticks: + ax.axvline(x - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) + else: + ax.set_yticks(minor_ticks, minor=True) + ax.set_yticks(major_ticks) + ax.set_yticklabels(major_labels_final) + ax.set_ylim(-0.5, len(component_labels) - 0.5) + # Style the ticks + ax.tick_params(axis="y", which="minor", length=2, width=0.5) + ax.tick_params(axis="y", which="major", length=6, width=1.5) + for y in major_ticks: + ax.axhline(y - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5) + + +def plot_activations( + layers_true: list[np.ndarray], + layers_pred: list[np.ndarray], + module_keys: list[str], + activation_threshold: float, + sample_order: np.ndarray | None = None, +) -> None: + """Plot true and predicted 
activations with optional sorting and diff. + + Args: + layers_true: List of boolean activation arrays per layer + layers_pred: List of predicted activation arrays per layer + module_keys: List of module names (e.g., ["blocks.0.attn.W_Q", ...]) + activation_threshold: Threshold used for binary conversion + sample_order: Optional array of sample indices for sorting. If None, plots unsorted. + """ + A_true: Float[np.ndarray, "n_samples n_components"] = np.concatenate( + layers_true, axis=1 + ).astype(float) + A_pred: Float[np.ndarray, "n_samples n_components"] = np.concatenate( + layers_pred, axis=1 + ).astype(float) + + # Apply sample ordering if provided + if sample_order is not None: + A_true = A_true[sample_order, :] + A_pred = A_pred[sample_order, :] + sorted_label: str = " (Sorted by Sample Similarity)" + xlabel: str = "Sample index (sorted)" + else: + sorted_label = "" + xlabel = "Sample index" + + # Create component labels for unsorted plots + component_labels: list[str] | None = None + if sample_order is None: + component_labels = [] + for module_key, layer in zip(module_keys, layers_true, strict=True): + n_components: int = layer.shape[1] + component_labels.extend([f"{module_key}:{i}" for i in range(n_components)]) + + # Determine number of subplots + n_plots: int = 3 if sample_order is not None else 2 + fig, axes = plt.subplots(n_plots, 1, figsize=(12, 6 * n_plots)) + if n_plots == 2: + ax1, ax2 = axes + else: + ax1, ax2, ax3 = axes + + # Plot true activations + ax1.imshow(A_true.T, aspect="auto", interpolation="nearest", cmap="Blues") + ax1.set_title( + rf"True Binary Activations{sorted_label}" + "\n" + r"$A_{ij} = \mathbb{1}[\text{activation}_{ij} > \theta]$, " + rf"$\theta = {activation_threshold}$" + ) + ax1.set_xlabel(xlabel) + ax1.set_ylabel("Component index") + if component_labels is not None: + add_component_labeling(ax1, component_labels, axis="y") + + # Plot predicted activations + ax2.imshow(A_pred.T, aspect="auto", interpolation="nearest", 
cmap="Reds") + ax2.set_title( + rf"Predicted Binary Activations{sorted_label}" + "\n" + r"$\hat{A}_{ij} = \mathbb{1}[P(A_{ij}=1) > 0.5]$" + ) + ax2.set_xlabel(xlabel) + ax2.set_ylabel("Component index") + if component_labels is not None: + add_component_labeling(ax2, component_labels, axis="y") + + # Add diff plot if sorted + if sample_order is not None: + A_diff: Float[np.ndarray, "n_samples n_components"] = A_pred - A_true + im3 = ax3.imshow( + A_diff.T, aspect="auto", interpolation="nearest", cmap="RdBu_r", vmin=-1, vmax=1 + ) + ax3.set_title( + r"Prediction Errors (Predicted - True)" + "\n" + r"Red = FP ($\hat{A}=1, A=0$), Blue = FN ($\hat{A}=0, A=1$), White = Correct" + ) + ax3.set_xlabel(xlabel) + ax3.set_ylabel("Component index") + plt.colorbar(im3, ax=ax3, label="Error") + + fig.tight_layout() + + +def plot_covariance( + layers_true: list[np.ndarray], + module_keys: list[str], + component_order: np.ndarray | None = None, +) -> None: + """Plot covariance matrix with optional component ordering. + + Args: + layers_true: List of boolean activation arrays per layer + module_keys: List of module names for labeling + component_order: Optional array of component indices for sorting. If None, plots unsorted. 
+ """ + A: Float[np.ndarray, "n_samples n_components"] = np.concatenate( + layers_true, axis=1 + ).astype(float) + + # Apply component ordering if provided + if component_order is not None: + A = A[:, component_order] + sorted_label: str = " (Sorted by Component Similarity)" + xlabel: str = "Component index (sorted)" + ylabel: str = "Component index (sorted)" + else: + sorted_label = "" + xlabel = "Component index" + ylabel = "Component index" + + # Compute covariance + C: Float[np.ndarray, "n_components n_components"] = np.cov(A, rowvar=False) + + # Center colormap on 0 + vmax: float = float(np.abs(C).max()) + vmin: float = -vmax + + # Create component labels for unsorted plots + component_labels: list[str] | None = None + if component_order is None: + component_labels = [] + for module_key, layer in zip(module_keys, layers_true, strict=True): + n_components: int = layer.shape[1] + component_labels.extend([f"{module_key}:{i}" for i in range(n_components)]) + + fig, ax = plt.subplots(figsize=(10, 10)) + im = ax.imshow(C, aspect="auto", interpolation="nearest", cmap="RdBu_r", vmin=vmin, vmax=vmax) + ax.set_title( + rf"Component Covariance Matrix{sorted_label}" + "\n" + r"$\text{Cov}(i,j) = \mathbb{E}[(A_i - \mu_i)(A_j - \mu_j)]$" + "\n" + r"where $A_i$ is binary activation of component $i$" + ) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + # Add layer boundaries for unsorted + if component_labels is not None: + add_component_labeling(ax, component_labels, axis="x") + add_component_labeling(ax, component_labels, axis="y") + + plt.colorbar(im, ax=ax, label="Covariance") + fig.tight_layout() def plot_layer_metrics(per_layer_stats: list[dict[str, Any]]) -> None: diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 20692d09e..6d4427e17 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -16,6 +16,7 @@ convert_to_boolean_layers, ) from spd.clustering.ci_dt.plot import ( + greedy_sort, plot_activations, 
plot_covariance, plot_layer_metrics, @@ -36,8 +37,8 @@ # ----------------------- configuration ----------------------- config = CIDTConfig( - batch_size=50, # 50 ~~ 16GB VRAM - n_batches=4, + batch_size=8, # 50 ~~ 16GB VRAM + n_batches=2, activation_threshold=0.01, max_depth=8, random_state=42, @@ -133,19 +134,64 @@ plot_tree_statistics(models, per_layer_stats) print("Tree statistics plots generated.") +# %% +# ----------------------- compute orderings ----------------------- +# Generate sample ordering once for use in multiple plots + +# Get module keys for labeling +module_keys: list[str] = list(component_acts_concat.keys()) + +# Concatenate true activations for ordering +A_true_concat: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) + +# Compute sample ordering by similarity +sample_order: np.ndarray = greedy_sort(A_true_concat, axis=0) +print(f"Computed sample ordering ({len(sample_order)} samples)") + # %% # ----------------------- plot: activations ----------------------- -# Simple heatmaps of true vs predicted activations +# Heatmaps of true vs predicted activations (unsorted and sorted) -plot_activations(layers_true, layers_pred) -print("Activation plots generated.") +# Unsorted version with layer boundaries +plot_activations( + layers_true=layers_true, + layers_pred=layers_pred, + module_keys=module_keys, + activation_threshold=config.activation_threshold, + sample_order=None, +) +print("Activation plots (unsorted) generated.") + +# Sorted version with diff plot +plot_activations( + layers_true=layers_true, + layers_pred=layers_pred, + module_keys=module_keys, + activation_threshold=config.activation_threshold, + sample_order=sample_order, +) +print("Activation plots (sorted by samples) generated.") # %% # ----------------------- plot: covariance ----------------------- # Covariance matrix - can be slow with many components -plot_covariance(layers_true) -print("Covariance plot generated.") +# Unsorted version with layer boundaries 
+plot_covariance( + layers_true=layers_true, + module_keys=module_keys, + component_order=None, +) +print("Covariance plot (unsorted) generated.") + +# Sorted version by component similarity +component_order: np.ndarray = greedy_sort(A_true_concat, axis=1) +plot_covariance( + layers_true=layers_true, + module_keys=module_keys, + component_order=component_order, +) +print("Covariance plot (sorted by components) generated.") # %% # ----------------------- generate feature names ----------------------- From 0661120ce514b940e7db27de08da0d575a66d4ef Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 16:41:15 +0100 Subject: [PATCH 50/77] wip --- spd/clustering/ci_dt/pipeline.py | 20 +- spd/clustering/ci_dt/plot.py | 302 +++++++++++++++++++++++++++---- spd/clustering/ci_dt/run.py | 114 ++++++------ 3 files changed, 335 insertions(+), 101 deletions(-) diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 2f231c53a..5467dad91 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -80,6 +80,7 @@ def compute_activations_multibatch( def convert_to_boolean_layers( component_acts: dict[str, Tensor], activation_threshold: float, + verbose: bool = False, ) -> list[Bool[np.ndarray, "n_samples n_components"]]: """Convert activations to boolean, filter constant (always dead/alive) components. 
@@ -120,15 +121,16 @@ def convert_to_boolean_layers( ] layers_true.append(module_acts_varying) - n_varying: int = module_acts_varying.shape[1] - n_total: int = module_acts_bool.shape[1] - print( - f" {module_key:30s} {n_varying:5d} varying, {n_always_dead:5d} dead, {n_always_alive:5d} const, {n_total:5d} total", - flush=True, - ) - dbg_tensor(module_acts_np) - dbg_tensor(module_acts_bool) - dbg_tensor(module_acts_varying) + if verbose: + n_varying: int = module_acts_varying.shape[1] + n_total: int = module_acts_bool.shape[1] + print( + f" {module_key:30s} {n_varying:5d} varying, {n_always_dead:5d} dead, {n_always_alive:5d} const, {n_total:5d} total", + flush=True, + ) + dbg_tensor(module_acts_np) + dbg_tensor(module_acts_bool) + dbg_tensor(module_acts_varying) print(f"\nCreated {len(layers_true)} layers for decision tree training") diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 7518ac563..f9483d201 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -267,51 +267,217 @@ def plot_covariance( fig.tight_layout() -def plot_layer_metrics(per_layer_stats: list[dict[str, Any]]) -> None: - """Plot summary metrics per layer and per-target AP vs prevalence.""" +def plot_layer_metrics( + per_layer_stats: list[dict[str, Any]], + models: list[LayerModel], + module_keys: list[str], + component_acts: dict[str, np.ndarray], + activation_threshold: float, +) -> None: + """Plot distributions of metrics per layer with scatter plots and jitter. 
+ + Args: + per_layer_stats: List of dicts with metrics per layer + models: List of trained LayerModel objects (needed for tree depths) + module_keys: List of module names for x-axis labels + component_acts: Dict of continuous activations per module + activation_threshold: Threshold used for binary conversion + """ L: int = len(per_layer_stats) - mean_ap: np.ndarray = np.array([d["mean_ap"] for d in per_layer_stats]) - mean_acc: np.ndarray = np.array([d["mean_acc"] for d in per_layer_stats]) - mean_bacc: np.ndarray = np.array([d["mean_bacc"] for d in per_layer_stats]) - - # bar: mean AP, ACC, BACC per layer (three separate figures to respect one-plot rule) - fig3 = plt.figure(figsize=(8, 3)) - ax3 = fig3.add_subplot(1, 1, 1) - ax3.set_title("Mean Average Precision per layer") - ax3.bar(np.arange(1, L + 1), mean_ap) - ax3.set_xlabel("layer index (target)") - ax3.set_ylabel("mean AP") + + # Prepare data: all values per layer with jitter for visualization + np.random.seed(42) # Reproducible jitter + jitter_amount: float = 0.15 + + # AP per layer + fig1, ax1 = plt.subplots(figsize=(10, 5)) + for layer_idx, stats in enumerate(per_layer_stats): + ap_values: np.ndarray = stats["ap"] + # Remove NaN values + ap_valid: np.ndarray = ap_values[~np.isnan(ap_values)] + if len(ap_valid) > 0: + # Add horizontal jitter + x_positions: np.ndarray = np.ones(len(ap_valid)) * (layer_idx + 1) + x_jittered: np.ndarray = x_positions + np.random.uniform( + -jitter_amount, jitter_amount, len(ap_valid) + ) + ax1.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors='none') + # Add mean line + ax1.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + [stats["mean_ap"], stats["mean_ap"]], + 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + + ax1.set_title( + r"Average Precision per Target Component" + "\n" + r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where " + r"$P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, " + r"$R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" + ) + 
ax1.set_xlabel("Target Module") + ax1.set_ylabel("Average Precision") + ax1.set_xticks(np.arange(1, L + 1)) + # Only use module keys that correspond to target layers (skip input layer) + ax1.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax1.set_ylim(-0.05, 1.05) + ax1.grid(True, alpha=0.3, axis='y') + ax1.legend() + fig1.tight_layout() + + # Accuracy per layer + fig2, ax2 = plt.subplots(figsize=(10, 5)) + for layer_idx, stats in enumerate(per_layer_stats): + acc_values: np.ndarray = stats["acc"] + acc_valid: np.ndarray = acc_values[~np.isnan(acc_values)] + if len(acc_valid) > 0: + x_positions = np.ones(len(acc_valid)) * (layer_idx + 1) + x_jittered = x_positions + np.random.uniform( + -jitter_amount, jitter_amount, len(acc_valid) + ) + ax2.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors='none') + ax2.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + [stats["mean_acc"], stats["mean_acc"]], + 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + + ax2.set_title( + r"Accuracy per Target Component" + "\n" + r"$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}$" + ) + ax2.set_xlabel("Target Module") + ax2.set_ylabel("Accuracy") + ax2.set_xticks(np.arange(1, L + 1)) + ax2.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax2.set_ylim(-0.05, 1.05) + ax2.grid(True, alpha=0.3, axis='y') + ax2.legend() + fig2.tight_layout() + + # Balanced Accuracy per layer + fig3, ax3 = plt.subplots(figsize=(10, 5)) + for layer_idx, stats in enumerate(per_layer_stats): + bacc_values: np.ndarray = stats["bacc"] + bacc_valid: np.ndarray = bacc_values[~np.isnan(bacc_values)] + if len(bacc_valid) > 0: + x_positions = np.ones(len(bacc_valid)) * (layer_idx + 1) + x_jittered = x_positions + np.random.uniform( + -jitter_amount, jitter_amount, len(bacc_valid) + ) + ax3.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors='none') + ax3.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 
0.3], + [stats["mean_bacc"], stats["mean_bacc"]], + 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + + ax3.set_title( + r"Balanced Accuracy per Target Component" + "\n" + r"$\text{Balanced Acc} = \frac{1}{2}\left(\frac{\text{TP}}{\text{TP}+\text{FN}} + \frac{\text{TN}}{\text{TN}+\text{FP}}\right)$" + ) + ax3.set_xlabel("Target Module") + ax3.set_ylabel("Balanced Accuracy") + ax3.set_xticks(np.arange(1, L + 1)) + ax3.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax3.set_ylim(-0.05, 1.05) + ax3.grid(True, alpha=0.3, axis='y') + ax3.legend() fig3.tight_layout() - fig4 = plt.figure(figsize=(8, 3)) - ax4 = fig4.add_subplot(1, 1, 1) - ax4.set_title("Mean Accuracy per layer") - ax4.bar(np.arange(1, L + 1), mean_acc) - ax4.set_xlabel("layer index (target)") - ax4.set_ylabel("mean accuracy") + # AP vs prevalence scatter with tree depth coloring + fig4, ax4 = plt.subplots(figsize=(8, 6)) + + prevalence_list: list[float] = [] + ap_list: list[float] = [] + depth_list: list[int] = [] + + for layer_idx, (stats, model) in enumerate(zip(per_layer_stats, models, strict=True)): + for target_idx, (prev, ap) in enumerate(zip(stats["prev"], stats["ap"], strict=True)): + if not np.isnan(ap): + prevalence_list.append(prev) + ap_list.append(ap) + # Get tree depth for this target + estimator = model.model.estimators_[target_idx] + depth_list.append(int(estimator.tree_.max_depth)) + + prevalence_arr: np.ndarray = np.array(prevalence_list) + ap_arr: np.ndarray = np.array(ap_list) + depth_arr: np.ndarray = np.array(depth_list) + + scatter = ax4.scatter( + prevalence_arr, + ap_arr, + c=depth_arr, + cmap="viridis", + alpha=0.6, + s=30, + edgecolors='none', + ) + + ax4.set_title( + r"Average Precision vs Component Prevalence" + "\n" + r"Prevalence = $\frac{n_{\text{active samples}}}{n_{\text{total samples}}}$, colored by tree depth" + ) + ax4.set_xlabel("Prevalence (log scale)") + ax4.set_ylabel("Average Precision") + ax4.set_xscale("log") + ax4.set_ylim(-0.05, 1.05) 
+ ax4.grid(True, alpha=0.3) + + cbar = plt.colorbar(scatter, ax=ax4) + cbar.set_label("Tree Depth") + fig4.tight_layout() - fig5 = plt.figure(figsize=(8, 3)) - ax5 = fig5.add_subplot(1, 1, 1) - ax5.set_title("Mean Balanced Accuracy per layer") - ax5.bar(np.arange(1, L + 1), mean_bacc) - ax5.set_xlabel("layer index (target)") - ax5.set_ylabel("mean balanced accuracy") - fig5.tight_layout() + # Component activity breakdown per module + fig5, ax5 = plt.subplots(figsize=(12, 6)) + + # Compute counts for each module + n_varying_list: list[int] = [] + n_always_dead_list: list[int] = [] + n_always_alive_list: list[int] = [] + + for module_key in module_keys: + acts: np.ndarray = component_acts[module_key] + # Convert to numpy if needed + if hasattr(acts, 'cpu'): + acts = acts.cpu().numpy() + # Convert to boolean + acts_bool: np.ndarray = (acts >= activation_threshold).astype(bool) + + # Count each category + always_dead: np.ndarray = ~acts_bool.any(axis=0) + always_alive: np.ndarray = acts_bool.all(axis=0) + varying: np.ndarray = ~(always_dead | always_alive) + + n_always_dead_list.append(int(always_dead.sum())) + n_always_alive_list.append(int(always_alive.sum())) + n_varying_list.append(int(varying.sum())) + + # Convert to arrays + n_varying: np.ndarray = np.array(n_varying_list) + n_always_dead: np.ndarray = np.array(n_always_dead_list) + n_always_alive: np.ndarray = np.array(n_always_alive_list) + n_total_per_module: np.ndarray = n_varying + n_always_dead + n_always_alive + + # Sort modules by total components (smallest to largest) + sort_idx: np.ndarray = np.argsort(n_total_per_module) + module_keys_sorted: list[str] = [module_keys[i] for i in sort_idx] + n_varying_sorted: np.ndarray = n_varying[sort_idx] + n_always_dead_sorted: np.ndarray = n_always_dead[sort_idx] + n_always_alive_sorted: np.ndarray = n_always_alive[sort_idx] + + # Create stacked bar chart + x_pos: np.ndarray = np.arange(len(module_keys_sorted)) + ax5.bar(x_pos, n_varying_sorted, label="Varying", 
color="C3") + ax5.bar(x_pos, n_always_alive_sorted, bottom=n_varying_sorted, label="Always Active", color="C2") + ax5.bar(x_pos, n_always_dead_sorted, bottom=n_varying_sorted + n_always_alive_sorted, label="Always Inactive", color="C1") + + ax5.set_title("Component Activity Distribution per Module") + ax5.set_xlabel("Module (sorted by total component count)") + ax5.set_ylabel("Number of Components (log scale)") + ax5.set_xticks(x_pos) + ax5.set_xticklabels(module_keys_sorted, rotation=45, ha='right') + # ax5.set_yscale('log') + ax5.legend(loc='upper left') + ax5.grid(True, alpha=0.3, axis='y') - # scatter: prevalence vs AP for all targets across layers - fig6 = plt.figure(figsize=(6, 5)) - ax6 = fig6.add_subplot(1, 1, 1) - ax6.set_title("Per-target AP vs prevalence") - x_list: list[float] = [] - y_list: list[float] = [] - for d in per_layer_stats: - x_list.extend(list(d["prev"])) - y_list.extend(list(d["ap"])) - ax6.scatter(x_list, y_list, alpha=0.6) - ax6.set_xlabel("prevalence") - ax6.set_ylabel("average precision") - fig6.tight_layout() + fig5.tight_layout() def plot_selected_trees( @@ -474,3 +640,61 @@ def plot_tree_statistics( if count > 0: ax6.text(i, j, str(count), ha="center", va="center") plt.colorbar(im, ax=ax6, label="log10(count+1)") + + # Heatmap: AP vs prevalence + # Need to compute prevalence for each tree from per_layer_stats + prevalence_list: list[float] = [] + ap_list_for_heatmap: list[float] = [] + + for layer_stats in per_layer_stats: + for prev, ap in zip(layer_stats["prev"], layer_stats["ap"], strict=True): + if not np.isnan(ap): + prevalence_list.append(prev) + ap_list_for_heatmap.append(ap) + + prevalence_arr: np.ndarray = np.array(prevalence_list) + ap_arr_for_heatmap: np.ndarray = np.array(ap_list_for_heatmap) + + # Prevalence bins (log scale) + prev_min: float = max(prevalence_arr.min(), 1e-4) # Avoid log(0) + prev_max: float = prevalence_arr.max() + prev_bins: Float[np.ndarray, "n_bins"] = np.logspace( + np.log10(prev_min), 
np.log10(prev_max), 10 + ) + + # AP bins (linear) + ap_bins_heatmap: Float[np.ndarray, "n_bins"] = np.linspace(0, 1, 11) + + heatmap_prev_ap: Float[np.ndarray, "prev_bins ap_bins"] + heatmap_prev_ap, _, _ = np.histogram2d( + prevalence_arr, ap_arr_for_heatmap, bins=[prev_bins, ap_bins_heatmap] + ) + + fig7, ax7 = plt.subplots(figsize=(8, 6)) + heatmap_log = np.log10(heatmap_prev_ap.T + 1) + im = ax7.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") + + # X-axis: prevalence (log scale) + ax7.set_xticks(range(len(prev_bins) - 1)) + ax7.set_xticklabels([f"{x:.3f}" for x in prev_bins[:-1]], rotation=45, ha="right") + ax7.set_xlabel("Prevalence (log scale)") + + # Y-axis: AP + ax7.set_yticks(range(len(ap_bins_heatmap) - 1)) + ax7.set_yticklabels([f"{x:.1f}" for x in ap_bins_heatmap[:-1]]) + ax7.set_ylabel("Average Precision") + + ax7.set_title( + r"Tree Performance vs Component Prevalence" + "\n" + r"AP = Average Precision, Prev = $\frac{n_{\text{active}}}{n_{\text{total}}}$" + ) + + # Add counts to cells + for i in range(len(prev_bins) - 1): + for j in range(len(ap_bins_heatmap) - 1): + count = int(heatmap_prev_ap[i, j]) + if count > 0: + ax7.text(i, j, str(count), ha="center", va="center", fontsize=8) + + plt.colorbar(im, ax=ax7, label="log10(count+1)") + fig7.tight_layout() diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 6d4427e17..c83434195 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -37,8 +37,10 @@ # ----------------------- configuration ----------------------- config = CIDTConfig( - batch_size=8, # 50 ~~ 16GB VRAM - n_batches=2, + # batch_size=50, # 50 ~~ 16GB VRAM max + # n_batches=8, + batch_size=16, + n_batches=4, activation_threshold=0.01, max_depth=8, random_state=42, @@ -120,20 +122,6 @@ layers_true=layers_true, ) -# %% -# ----------------------- plot: layer metrics ----------------------- -# Simplest - just bar charts and scatter plot of summary statistics - 
-plot_layer_metrics(per_layer_stats) -print("Layer metrics plots generated.") - -# %% -# ----------------------- plot: tree statistics ----------------------- -# Distributions of tree depth, leaf counts, and correlations with accuracy - -plot_tree_statistics(models, per_layer_stats) -print("Tree statistics plots generated.") - # %% # ----------------------- compute orderings ----------------------- # Generate sample ordering once for use in multiple plots @@ -148,29 +136,49 @@ sample_order: np.ndarray = greedy_sort(A_true_concat, axis=0) print(f"Computed sample ordering ({len(sample_order)} samples)") +# %% +# ----------------------- plot: layer metrics ----------------------- +# Scatter plots with jitter showing distribution of metrics per layer + +plot_layer_metrics( + per_layer_stats, + models, + module_keys, + component_acts_concat, + config.activation_threshold, +) +print("Layer metrics plots generated.") + +# %% +# ----------------------- plot: tree statistics ----------------------- +# Distributions of tree depth, leaf counts, and correlations with accuracy + +plot_tree_statistics(models, per_layer_stats) +print("Tree statistics plots generated.") + # %% # ----------------------- plot: activations ----------------------- # Heatmaps of true vs predicted activations (unsorted and sorted) # Unsorted version with layer boundaries -plot_activations( - layers_true=layers_true, - layers_pred=layers_pred, - module_keys=module_keys, - activation_threshold=config.activation_threshold, - sample_order=None, -) -print("Activation plots (unsorted) generated.") - -# Sorted version with diff plot -plot_activations( - layers_true=layers_true, - layers_pred=layers_pred, - module_keys=module_keys, - activation_threshold=config.activation_threshold, - sample_order=sample_order, -) -print("Activation plots (sorted by samples) generated.") +# plot_activations( +# layers_true=layers_true, +# layers_pred=layers_pred, +# module_keys=module_keys, +# 
activation_threshold=config.activation_threshold, +# sample_order=None, +# ) +# print("Activation plots (unsorted) generated.") + +# # Sorted version with diff plot +# plot_activations( +# layers_true=layers_true, +# layers_pred=layers_pred, +# module_keys=module_keys, +# activation_threshold=config.activation_threshold, +# sample_order=sample_order, +# ) +# print("Activation plots (sorted by samples) generated.") # %% # ----------------------- plot: covariance ----------------------- @@ -197,31 +205,31 @@ # ----------------------- generate feature names ----------------------- # Generate feature names with activation statistics and decoded directions -from spd.clustering.ci_dt.feature_names import generate_feature_names +# from spd.clustering.ci_dt.feature_names import generate_feature_names -module_keys = list(component_acts_concat.keys()) +# module_keys = list(component_acts_concat.keys()) -feature_names = generate_feature_names( - component_model=model, - component_acts=component_acts_concat, - layers_true=layers_true, - layers_pred=layers_pred, - tokenizer=cfg.task_config.tokenizer if hasattr(cfg.task_config, 'tokenizer') else None, - module_keys=module_keys, - top_k=3, -) -print("Feature names generated.") +# feature_names = generate_feature_names( +# component_model=model, +# component_acts=component_acts_concat, +# layers_true=layers_true, +# layers_pred=layers_pred, +# tokenizer=cfg.task_config.tokenizer if hasattr(cfg.task_config, 'tokenizer') else None, +# module_keys=module_keys, +# top_k=3, +# ) +# print("Feature names generated.") # %% # ----------------------- plot: worst trees ----------------------- # Decision tree visualization for worst performing trees -plot_selected_trees(worst_list, "Worst", models, feature_names=feature_names) -print("Worst trees plots generated.") +# plot_selected_trees(worst_list, "Worst", models, feature_names=feature_names) +# print("Worst trees plots generated.") -# %% -# ----------------------- plot: best trees 
----------------------- -# Decision tree visualization for best performing trees +# # %% +# # ----------------------- plot: best trees ----------------------- +# # Decision tree visualization for best performing trees -plot_selected_trees(best_list, "Best", models, feature_names=feature_names) -print("Best trees plots generated.") +# plot_selected_trees(best_list, "Best", models, feature_names=feature_names) +# print("Best trees plots generated.") From f75fd9fdee683ef7c8ad2d9683ad286d6f199fbd Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 16:48:23 +0100 Subject: [PATCH 51/77] wip --- spd/clustering/ci_dt/plot.py | 67 +++++++++++++++++++++++++++--------- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index f9483d201..3c8371ddc 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -453,28 +453,61 @@ def plot_layer_metrics( n_varying: np.ndarray = np.array(n_varying_list) n_always_dead: np.ndarray = np.array(n_always_dead_list) n_always_alive: np.ndarray = np.array(n_always_alive_list) - n_total_per_module: np.ndarray = n_varying + n_always_dead + n_always_alive - # Sort modules by total components (smallest to largest) - sort_idx: np.ndarray = np.argsort(n_total_per_module) - module_keys_sorted: list[str] = [module_keys[i] for i in sort_idx] - n_varying_sorted: np.ndarray = n_varying[sort_idx] - n_always_dead_sorted: np.ndarray = n_always_dead[sort_idx] - n_always_alive_sorted: np.ndarray = n_always_alive[sort_idx] - - # Create stacked bar chart - x_pos: np.ndarray = np.arange(len(module_keys_sorted)) - ax5.bar(x_pos, n_varying_sorted, label="Varying", color="C3") - ax5.bar(x_pos, n_always_alive_sorted, bottom=n_varying_sorted, label="Always Active", color="C2") - ax5.bar(x_pos, n_always_dead_sorted, bottom=n_varying_sorted + n_always_alive_sorted, label="Always Inactive", color="C1") + # For each module, sort the three categories by size 
(smallest to largest) + # This will be stacked bottom-to-top as smallest, medium, largest + x_pos: np.ndarray = np.arange(len(module_keys)) + + # Build stacked bars with sorted segments per module + bottom_vals: np.ndarray = np.zeros(len(module_keys)) + + # Collect all three categories with labels + categories: list[tuple[str, np.ndarray, str]] = [ + ("Varying", n_varying, "C2"), + ("Always Active", n_always_alive, "C1"), + ("Always Inactive", n_always_dead, "C0"), + ] + + # For each position, we need to stack in order of size + # We'll plot all bars for the smallest category first, then medium, then largest + for module_idx in range(len(module_keys)): + # Get values for this module + vals: list[tuple[float, str, str]] = [ + (n_varying[module_idx], "Varying", "C2"), + (n_always_alive[module_idx], "Always Active", "C1"), + (n_always_dead[module_idx], "Always Inactive", "C0"), + ] + # Sort by value (smallest to largest) + vals.sort(key=lambda x: x[0]) + + # Stack them + bottom: float = 0 + for val, label, color in vals: + if val > 0: # Only plot if non-zero + ax5.bar( + module_idx, + val, + bottom=bottom, + color=color, + label=label if module_idx == 0 else "", # Only label once + ) + bottom += val ax5.set_title("Component Activity Distribution per Module") - ax5.set_xlabel("Module (sorted by total component count)") + ax5.set_xlabel("Module") ax5.set_ylabel("Number of Components (log scale)") ax5.set_xticks(x_pos) - ax5.set_xticklabels(module_keys_sorted, rotation=45, ha='right') - # ax5.set_yscale('log') - ax5.legend(loc='upper left') + ax5.set_xticklabels(module_keys, rotation=45, ha='right') + ax5.set_yscale('log') + + # Create legend with correct labels + from matplotlib.patches import Patch + legend_elements = [ + Patch(facecolor='C2', label='Varying'), + Patch(facecolor='C1', label='Always Active'), + Patch(facecolor='C0', label='Always Inactive'), + ] + ax5.legend(handles=legend_elements, loc='upper left') ax5.grid(True, alpha=0.3, axis='y') 
fig5.tight_layout() From 23ba0550cd2c963e421d5d008e2c9c9a116eed49 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 17:12:15 +0100 Subject: [PATCH 52/77] wip --- spd/clustering/ci_dt/pipeline.py | 4 +- spd/clustering/ci_dt/plot.py | 83 +++++++++++++++++--------------- spd/clustering/ci_dt/run.py | 24 ++++++--- 3 files changed, 63 insertions(+), 48 deletions(-) diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 5467dad91..ea022eccf 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -4,8 +4,8 @@ import numpy as np import torch -from jaxtyping import Bool, Float -from torch import Tensor +from jaxtyping import Bool, Float, Int +from torch import Tensor, nn from torch.utils.data import DataLoader from muutils.dbg import dbg_tensor from tqdm import tqdm diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 3c8371ddc..f9172a8e3 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -269,19 +269,13 @@ def plot_covariance( def plot_layer_metrics( per_layer_stats: list[dict[str, Any]], - models: list[LayerModel], module_keys: list[str], - component_acts: dict[str, np.ndarray], - activation_threshold: float, ) -> None: """Plot distributions of metrics per layer with scatter plots and jitter. 
Args: per_layer_stats: List of dicts with metrics per layer - models: List of trained LayerModel objects (needed for tree depths) module_keys: List of module names for x-axis labels - component_acts: Dict of continuous activations per module - activation_threshold: Threshold used for binary conversion """ L: int = len(per_layer_stats) @@ -379,8 +373,17 @@ def plot_layer_metrics( ax3.legend() fig3.tight_layout() - # AP vs prevalence scatter with tree depth coloring - fig4, ax4 = plt.subplots(figsize=(8, 6)) + +def plot_ap_vs_prevalence( + per_layer_stats: list[dict[str, Any]], models: list[LayerModel] +) -> None: + """Plot AP vs prevalence scatter colored by tree depth. + + Args: + per_layer_stats: List of dicts with metrics per layer + models: List of trained LayerModel objects (needed for tree depths) + """ + fig, ax = plt.subplots(figsize=(8, 6)) prevalence_list: list[float] = [] ap_list: list[float] = [] @@ -399,7 +402,7 @@ def plot_layer_metrics( ap_arr: np.ndarray = np.array(ap_list) depth_arr: np.ndarray = np.array(depth_list) - scatter = ax4.scatter( + scatter = ax.scatter( prevalence_arr, ap_arr, c=depth_arr, @@ -409,23 +412,35 @@ def plot_layer_metrics( edgecolors='none', ) - ax4.set_title( + ax.set_title( r"Average Precision vs Component Prevalence" + "\n" r"Prevalence = $\frac{n_{\text{active samples}}}{n_{\text{total samples}}}$, colored by tree depth" ) - ax4.set_xlabel("Prevalence (log scale)") - ax4.set_ylabel("Average Precision") - ax4.set_xscale("log") - ax4.set_ylim(-0.05, 1.05) - ax4.grid(True, alpha=0.3) + ax.set_xlabel("Prevalence (log scale)") + ax.set_ylabel("Average Precision") + ax.set_xscale("log") + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3) - cbar = plt.colorbar(scatter, ax=ax4) + cbar = plt.colorbar(scatter, ax=ax) cbar.set_label("Tree Depth") - fig4.tight_layout() + fig.tight_layout() + - # Component activity breakdown per module - fig5, ax5 = plt.subplots(figsize=(12, 6)) +def plot_component_activity_breakdown( + 
component_acts: dict[str, np.ndarray], + module_keys: list[str], + activation_threshold: float, +) -> None: + """Plot stacked bar chart of component activity breakdown per module. + + Args: + component_acts: Dict of continuous activations per module + module_keys: List of module names for x-axis labels + activation_threshold: Threshold used for binary conversion + """ + fig, ax = plt.subplots(figsize=(12, 6)) # Compute counts for each module n_varying_list: list[int] = [] @@ -458,16 +473,6 @@ def plot_layer_metrics( # This will be stacked bottom-to-top as smallest, medium, largest x_pos: np.ndarray = np.arange(len(module_keys)) - # Build stacked bars with sorted segments per module - bottom_vals: np.ndarray = np.zeros(len(module_keys)) - - # Collect all three categories with labels - categories: list[tuple[str, np.ndarray, str]] = [ - ("Varying", n_varying, "C2"), - ("Always Active", n_always_alive, "C1"), - ("Always Inactive", n_always_dead, "C0"), - ] - # For each position, we need to stack in order of size # We'll plot all bars for the smallest category first, then medium, then largest for module_idx in range(len(module_keys)): @@ -484,7 +489,7 @@ def plot_layer_metrics( bottom: float = 0 for val, label, color in vals: if val > 0: # Only plot if non-zero - ax5.bar( + ax.bar( module_idx, val, bottom=bottom, @@ -493,12 +498,12 @@ def plot_layer_metrics( ) bottom += val - ax5.set_title("Component Activity Distribution per Module") - ax5.set_xlabel("Module") - ax5.set_ylabel("Number of Components (log scale)") - ax5.set_xticks(x_pos) - ax5.set_xticklabels(module_keys, rotation=45, ha='right') - ax5.set_yscale('log') + ax.set_title("Component Activity Distribution per Module") + ax.set_xlabel("Module") + ax.set_ylabel("Number of Components (log scale)") + ax.set_xticks(x_pos) + ax.set_xticklabels(module_keys, rotation=45, ha='right') + ax.set_yscale('log') # Create legend with correct labels from matplotlib.patches import Patch @@ -507,10 +512,10 @@ def 
plot_layer_metrics( Patch(facecolor='C1', label='Always Active'), Patch(facecolor='C0', label='Always Inactive'), ] - ax5.legend(handles=legend_elements, loc='upper left') - ax5.grid(True, alpha=0.3, axis='y') + ax.legend(handles=legend_elements, loc='upper left') + ax.grid(True, alpha=0.3, axis='y') - fig5.tight_layout() + fig.tight_layout() def plot_selected_trees( diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index c83434195..b914bb13d 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -18,6 +18,8 @@ from spd.clustering.ci_dt.plot import ( greedy_sort, plot_activations, + plot_ap_vs_prevalence, + plot_component_activity_breakdown, plot_covariance, plot_layer_metrics, plot_selected_trees, @@ -140,15 +142,23 @@ # ----------------------- plot: layer metrics ----------------------- # Scatter plots with jitter showing distribution of metrics per layer -plot_layer_metrics( - per_layer_stats, - models, - module_keys, - component_acts_concat, - config.activation_threshold, -) +plot_layer_metrics(per_layer_stats, module_keys) print("Layer metrics plots generated.") +# %% +# ----------------------- plot: AP vs prevalence ----------------------- + +plot_ap_vs_prevalence(per_layer_stats, models) +print("AP vs prevalence plot generated.") + +# %% +# ----------------------- plot: component activity breakdown ----------------------- + +plot_component_activity_breakdown( + component_acts_concat, module_keys, config.activation_threshold +) +print("Component activity breakdown plot generated.") + # %% # ----------------------- plot: tree statistics ----------------------- # Distributions of tree depth, leaf counts, and correlations with accuracy From dc15ee4d5dc17ec4cd91e722d9e23d6476df0b89 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 17:25:08 +0100 Subject: [PATCH 53/77] wip --- spd/clustering/ci_dt/attn.py | 434 +++++++++++++++++++++++++++++++++++ spd/clustering/ci_dt/plot.py | 127 ++++++---- 
spd/clustering/ci_dt/run.py | 30 ++- 3 files changed, 535 insertions(+), 56 deletions(-) create mode 100644 spd/clustering/ci_dt/attn.py diff --git a/spd/clustering/ci_dt/attn.py b/spd/clustering/ci_dt/attn.py new file mode 100644 index 000000000..d73ce535c --- /dev/null +++ b/spd/clustering/ci_dt/attn.py @@ -0,0 +1,434 @@ +# %% +"""Attention pattern visualization for CI decision tree analysis.""" + +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Float, Int +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.clustering.ci_dt.config import CIDTConfig +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, SPDRunInfo + + +# magic autoreload +%load_ext autoreload +%autoreload 2 + +# %% +# ----------------------- configuration ----------------------- + +config = CIDTConfig( + batch_size=16, + n_batches=4, + activation_threshold=0.01, + max_depth=8, + random_state=42, +) +device: str = "cuda" if torch.cuda.is_available() else "cpu" + +# %% +# ----------------------- load model ----------------------- + +wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" + +spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) +model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) +model.to(device) +cfg: Config = spd_run.config + +print(f"Loaded model from {wandb_run_path}") + +# %% +# ----------------------- load dataset ----------------------- + +# Create LM dataset and dataloader +assert isinstance(cfg.task_config, LMTaskConfig) +pretrained_model_name = cfg.pretrained_model_name +assert pretrained_model_name is not None + +dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + 
n_ctx=cfg.task_config.max_seq_len, + column_name=cfg.task_config.column_name, + is_tokenized=False, + streaming=False, + seed=0, +) +dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=config.batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=cfg.seed, + ddp_rank=0, + ddp_world_size=1, +) +print(f"Created LM dataset with {cfg.task_config.dataset_name}") + +# %% +# ----------------------- extract attention patterns ----------------------- + + +def extract_attention_patterns_multibatch( + model: ComponentModel, + device: torch.device | str, + dataloader: DataLoader, + n_batches: int, +) -> dict[str, Float[Tensor, "total_samples n_heads seq_len seq_len"]]: + """Extract attention patterns over multiple batches. + + Args: + model: ComponentModel containing the transformer + device: Device to run inference on + dataloader: DataLoader to get batches from + n_batches: Number of batches to process + + Returns: + Dictionary mapping layer names to attention patterns (on CPU) + Format: {layer_name: tensor of shape [total_samples, n_heads, seq_len, seq_len]} + """ + print(f"Extracting attention patterns for {n_batches} batches...") + all_attention_patterns: list[dict[str, Tensor]] = [] + + for batch_idx in tqdm(range(n_batches), desc="Batches", total=n_batches): + batch_data = next(iter(dataloader)) + input_ids: Int[Tensor, "batch seq_len"] = batch_data["input_ids"].to(device) + + # Get attention patterns on GPU + with torch.no_grad(): + outputs = model.target_model(input_ids, output_attentions=True) + + # Extract attention patterns + # outputs.attentions is a tuple of tensors, one per layer + # Each tensor has shape [batch, n_heads, seq_len, seq_len] + batch_attention: dict[str, Tensor] = {} + if hasattr(outputs, "attentions") and outputs.attentions is not None: + for layer_idx, attn_weights in enumerate(outputs.attentions): + layer_name = f"layer_{layer_idx}" + # Move to CPU immediately + batch_attention[layer_name] = 
attn_weights.cpu() + + all_attention_patterns.append(batch_attention) + + # Concatenate all batches on CPU + print("Concatenating batches...") + layer_names: list[str] = list(all_attention_patterns[0].keys()) + attention_patterns_concat: dict[str, Tensor] = { + layer_name: torch.cat( + [batch[layer_name] for batch in all_attention_patterns], dim=0 + ) + for layer_name in layer_names + } + + print(f"Extracted attention patterns for {len(layer_names)} layers") + return attention_patterns_concat + + +# Extract attention patterns +attention_patterns: dict[str, Float[Tensor, "total_samples n_heads seq_len seq_len"]] = ( + extract_attention_patterns_multibatch( + model=model, + device=device, + dataloader=dataloader, + n_batches=config.n_batches, + ) +) + +# Print shapes +print("\nAttention pattern shapes:") +for layer_name, attn in attention_patterns.items(): + print(f" {layer_name}: {attn.shape}") + +# %% +# ----------------------- compute attention statistics ----------------------- + + +def compute_attention_stats( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], +) -> dict[str, dict[str, Float[np.ndarray, "..."]]]: + """Compute statistics about attention patterns. 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + + Returns: + Dictionary with statistics per layer including: + - mean_pattern: Average attention pattern [n_heads, seq_len, seq_len] + - entropy: Entropy of attention distributions [samples, n_heads, seq_len] + - max_attention: Maximum attention value [samples, n_heads, seq_len] + - sparsity: Fraction of attention < 0.01 [samples, n_heads] + """ + stats: dict[str, dict[str, np.ndarray]] = {} + + for layer_name, attn in attention_patterns.items(): + # Convert to numpy for stats + attn_np: np.ndarray = attn.numpy() + + # Mean pattern across all samples + mean_pattern: np.ndarray = attn_np.mean(axis=0) # [n_heads, seq_len, seq_len] + + # Entropy per query position: -sum(p * log(p)) + # Add small epsilon to avoid log(0) + epsilon = 1e-10 + attn_safe = attn_np + epsilon + entropy: np.ndarray = -(attn_safe * np.log(attn_safe)).sum( + axis=-1 + ) # [samples, n_heads, seq_len] + + # Max attention per query position + max_attention: np.ndarray = attn_np.max(axis=-1) # [samples, n_heads, seq_len] + + # Sparsity: fraction of attention weights < 0.01 + sparsity: np.ndarray = (attn_np < 0.01).mean( + axis=(2, 3) + ) # [samples, n_heads] + + stats[layer_name] = { + "mean_pattern": mean_pattern, + "entropy": entropy, + "max_attention": max_attention, + "sparsity": sparsity, + } + + return stats + + +attention_stats = compute_attention_stats(attention_patterns) +print("Computed attention statistics") + +# %% +# ----------------------- plot: average attention patterns per layer ----------------------- + + +def plot_average_attention_per_layer( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], + max_layers: int | None = None, +) -> None: + """Plot average attention pattern for each layer (averaged over heads and samples). 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + max_layers: Maximum number of layers to plot (default: all) + """ + layer_names = sorted(attention_patterns.keys()) + if max_layers is not None: + layer_names = layer_names[:max_layers] + + n_layers = len(layer_names) + n_cols = min(4, n_layers) + n_rows = (n_layers + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows)) + if n_layers == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for idx, layer_name in enumerate(layer_names): + attn = attention_patterns[layer_name].numpy() + # Average over samples and heads + avg_attn = attn.mean(axis=(0, 1)) # [seq_len, seq_len] + + ax = axes[idx] + im = ax.imshow(avg_attn, cmap="viridis", aspect="auto") + ax.set_title(f"{layer_name}\n(avg over samples & heads)") + ax.set_xlabel("Key position") + ax.set_ylabel("Query position") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Hide unused subplots + for idx in range(n_layers, len(axes)): + axes[idx].axis("off") + + fig.tight_layout() + + +plot_average_attention_per_layer(attention_patterns, max_layers=None) +print("Average attention per layer plots generated.") + +# %% +# ----------------------- plot: per-head attention for selected layers ----------------------- + + +def plot_per_head_attention( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], + layer_names: list[str] | None = None, +) -> None: + """Plot attention pattern for each head in selected layers. 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + layer_names: List of layer names to plot (default: first layer) + """ + if layer_names is None: + layer_names = [sorted(attention_patterns.keys())[0]] + + for layer_name in layer_names: + if layer_name not in attention_patterns: + print(f"Warning: {layer_name} not found in attention patterns") + continue + + attn = attention_patterns[layer_name].numpy() + # Average over samples + avg_attn = attn.mean(axis=0) # [n_heads, seq_len, seq_len] + n_heads = avg_attn.shape[0] + + n_cols = min(4, n_heads) + n_rows = (n_heads + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows)) + if n_heads == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for head_idx in range(n_heads): + ax = axes[head_idx] + im = ax.imshow(avg_attn[head_idx], cmap="viridis", aspect="auto") + ax.set_title(f"Head {head_idx}") + ax.set_xlabel("Key position") + ax.set_ylabel("Query position") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Hide unused subplots + for idx in range(n_heads, len(axes)): + axes[idx].axis("off") + + fig.suptitle(f"{layer_name} - Per-Head Attention Patterns", fontsize=14, y=1.00) + fig.tight_layout() + + +# Plot first and last layers +all_layer_names = sorted(attention_patterns.keys()) +layers_to_plot = [all_layer_names[0], all_layer_names[-1]] +plot_per_head_attention(attention_patterns, layer_names=layers_to_plot) +print(f"Per-head attention plots generated for layers: {layers_to_plot}") + +# %% +# ----------------------- plot: attention entropy across layers ----------------------- + + +def plot_attention_entropy( + attention_stats: dict[str, dict[str, np.ndarray]], +) -> None: + """Plot attention entropy statistics across layers. 
+ + Args: + attention_stats: Dictionary of attention statistics per layer + """ + layer_names = sorted(attention_stats.keys()) + + # Collect mean entropy per layer (averaged over samples, heads, and query positions) + mean_entropies: list[float] = [] + for layer_name in layer_names: + entropy = attention_stats[layer_name]["entropy"] # [samples, n_heads, seq_len] + mean_entropies.append(float(entropy.mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(range(len(layer_names)), mean_entropies, marker="o") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Attention Entropy") + ax.set_title("Attention Entropy Across Layers\n(Higher = more uniform attention)") + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_entropy(attention_stats) +print("Attention entropy plot generated.") + +# %% +# ----------------------- plot: attention sparsity across layers ----------------------- + + +def plot_attention_sparsity( + attention_stats: dict[str, dict[str, np.ndarray]], +) -> None: + """Plot attention sparsity across layers. 
+ + Args: + attention_stats: Dictionary of attention statistics per layer + """ + layer_names = sorted(attention_stats.keys()) + + # Collect mean sparsity per layer (averaged over samples and heads) + mean_sparsities: list[float] = [] + for layer_name in layer_names: + sparsity = attention_stats[layer_name]["sparsity"] # [samples, n_heads] + mean_sparsities.append(float(sparsity.mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(range(len(layer_names)), mean_sparsities, marker="o", color="C1") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Sparsity (fraction < 0.01)") + ax.set_title( + "Attention Sparsity Across Layers\n(Higher = more sparse/focused attention)" + ) + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.set_ylim(0, 1) + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_sparsity(attention_stats) +print("Attention sparsity plot generated.") + +# %% +# ----------------------- plot: attention to first/last tokens ----------------------- + + +def plot_attention_to_special_positions( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], +) -> None: + """Plot how much attention each position pays to first and last tokens. 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + """ + layer_names = sorted(attention_patterns.keys()) + + # Collect attention to first and last tokens + attn_to_first: list[float] = [] + attn_to_last: list[float] = [] + + for layer_name in layer_names: + attn = attention_patterns[layer_name].numpy() + # Average over samples and heads + avg_attn = attn.mean(axis=(0, 1)) # [seq_len, seq_len] + + # Average attention to first token (across all query positions) + attn_to_first.append(float(avg_attn[:, 0].mean())) + + # Average attention to last token (across all query positions) + attn_to_last.append(float(avg_attn[:, -1].mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + x = range(len(layer_names)) + ax.plot(x, attn_to_first, marker="o", label="Attention to first token") + ax.plot(x, attn_to_last, marker="s", label="Attention to last token") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Attention Weight") + ax.set_title("Attention to Special Token Positions Across Layers") + ax.set_xticks(x) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_to_special_positions(attention_patterns) +print("Attention to special positions plot generated.") + +# %% diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index f9172a8e3..b535bb24d 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -267,24 +267,21 @@ def plot_covariance( fig.tight_layout() -def plot_layer_metrics( +def plot_average_precision( per_layer_stats: list[dict[str, Any]], module_keys: list[str], ) -> None: - """Plot distributions of metrics per layer with scatter plots and jitter. + """Plot distribution of average precision per layer with scatter plot and jitter. 
Args: per_layer_stats: List of dicts with metrics per layer module_keys: List of module names for x-axis labels """ L: int = len(per_layer_stats) - - # Prepare data: all values per layer with jitter for visualization np.random.seed(42) # Reproducible jitter jitter_amount: float = 0.15 - # AP per layer - fig1, ax1 = plt.subplots(figsize=(10, 5)) + fig, ax = plt.subplots(figsize=(10, 5)) for layer_idx, stats in enumerate(per_layer_stats): ap_values: np.ndarray = stats["ap"] # Remove NaN values @@ -295,83 +292,111 @@ def plot_layer_metrics( x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(ap_valid) ) - ax1.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors='none') + ax.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors='none') # Add mean line - ax1.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], [stats["mean_ap"], stats["mean_ap"]], 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') - ax1.set_title( + ax.set_title( r"Average Precision per Target Component" + "\n" r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where " r"$P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, " r"$R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" ) - ax1.set_xlabel("Target Module") - ax1.set_ylabel("Average Precision") - ax1.set_xticks(np.arange(1, L + 1)) + ax.set_xlabel("Target Module") + ax.set_ylabel("Average Precision") + ax.set_xticks(np.arange(1, L + 1)) # Only use module keys that correspond to target layers (skip input layer) - ax1.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') - ax1.set_ylim(-0.05, 1.05) - ax1.grid(True, alpha=0.3, axis='y') - ax1.legend() - fig1.tight_layout() - - # Accuracy per layer - fig2, ax2 = plt.subplots(figsize=(10, 5)) + ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3, axis='y') + ax.legend() + fig.tight_layout() + + +def plot_accuracy( + 
per_layer_stats: list[dict[str, Any]], + module_keys: list[str], +) -> None: + """Plot distribution of accuracy per layer with scatter plot and jitter. + + Args: + per_layer_stats: List of dicts with metrics per layer + module_keys: List of module names for x-axis labels + """ + L: int = len(per_layer_stats) + np.random.seed(42) # Reproducible jitter + jitter_amount: float = 0.15 + + fig, ax = plt.subplots(figsize=(10, 5)) for layer_idx, stats in enumerate(per_layer_stats): acc_values: np.ndarray = stats["acc"] acc_valid: np.ndarray = acc_values[~np.isnan(acc_values)] if len(acc_valid) > 0: - x_positions = np.ones(len(acc_valid)) * (layer_idx + 1) - x_jittered = x_positions + np.random.uniform( + x_positions: np.ndarray = np.ones(len(acc_valid)) * (layer_idx + 1) + x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(acc_valid) ) - ax2.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors='none') - ax2.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + ax.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors='none') + ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], [stats["mean_acc"], stats["mean_acc"]], 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') - ax2.set_title( + ax.set_title( r"Accuracy per Target Component" + "\n" r"$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}$" ) - ax2.set_xlabel("Target Module") - ax2.set_ylabel("Accuracy") - ax2.set_xticks(np.arange(1, L + 1)) - ax2.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') - ax2.set_ylim(-0.05, 1.05) - ax2.grid(True, alpha=0.3, axis='y') - ax2.legend() - fig2.tight_layout() - - # Balanced Accuracy per layer - fig3, ax3 = plt.subplots(figsize=(10, 5)) + ax.set_xlabel("Target Module") + ax.set_ylabel("Accuracy") + ax.set_xticks(np.arange(1, L + 1)) + ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3, 
axis='y') + ax.legend() + fig.tight_layout() + + +def plot_balanced_accuracy( + per_layer_stats: list[dict[str, Any]], + module_keys: list[str], +) -> None: + """Plot distribution of balanced accuracy per layer with scatter plot and jitter. + + Args: + per_layer_stats: List of dicts with metrics per layer + module_keys: List of module names for x-axis labels + """ + L: int = len(per_layer_stats) + np.random.seed(42) # Reproducible jitter + jitter_amount: float = 0.15 + + fig, ax = plt.subplots(figsize=(10, 5)) for layer_idx, stats in enumerate(per_layer_stats): bacc_values: np.ndarray = stats["bacc"] bacc_valid: np.ndarray = bacc_values[~np.isnan(bacc_values)] if len(bacc_valid) > 0: - x_positions = np.ones(len(bacc_valid)) * (layer_idx + 1) - x_jittered = x_positions + np.random.uniform( + x_positions: np.ndarray = np.ones(len(bacc_valid)) * (layer_idx + 1) + x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(bacc_valid) ) - ax3.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors='none') - ax3.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + ax.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors='none') + ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], [stats["mean_bacc"], stats["mean_bacc"]], 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') - ax3.set_title( + ax.set_title( r"Balanced Accuracy per Target Component" + "\n" r"$\text{Balanced Acc} = \frac{1}{2}\left(\frac{\text{TP}}{\text{TP}+\text{FN}} + \frac{\text{TN}}{\text{TN}+\text{FP}}\right)$" ) - ax3.set_xlabel("Target Module") - ax3.set_ylabel("Balanced Accuracy") - ax3.set_xticks(np.arange(1, L + 1)) - ax3.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') - ax3.set_ylim(-0.05, 1.05) - ax3.grid(True, alpha=0.3, axis='y') - ax3.legend() - fig3.tight_layout() + ax.set_xlabel("Target Module") + ax.set_ylabel("Balanced Accuracy") + ax.set_xticks(np.arange(1, L + 1)) + 
ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3, axis='y') + ax.legend() + fig.tight_layout() def plot_ap_vs_prevalence( @@ -432,6 +457,7 @@ def plot_component_activity_breakdown( component_acts: dict[str, np.ndarray], module_keys: list[str], activation_threshold: float, + logy: bool = False, ) -> None: """Plot stacked bar chart of component activity breakdown per module. @@ -503,7 +529,8 @@ def plot_component_activity_breakdown( ax.set_ylabel("Number of Components (log scale)") ax.set_xticks(x_pos) ax.set_xticklabels(module_keys, rotation=45, ha='right') - ax.set_yscale('log') + if logy: + ax.set_yscale('log') # Create legend with correct labels from matplotlib.patches import Patch diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index b914bb13d..064819dc6 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -17,11 +17,13 @@ ) from spd.clustering.ci_dt.plot import ( greedy_sort, + plot_accuracy, plot_activations, plot_ap_vs_prevalence, + plot_average_precision, + plot_balanced_accuracy, plot_component_activity_breakdown, plot_covariance, - plot_layer_metrics, plot_selected_trees, plot_tree_statistics, ) @@ -139,11 +141,22 @@ print(f"Computed sample ordering ({len(sample_order)} samples)") # %% -# ----------------------- plot: layer metrics ----------------------- -# Scatter plots with jitter showing distribution of metrics per layer +# ----------------------- plot: average precision ----------------------- -plot_layer_metrics(per_layer_stats, module_keys) -print("Layer metrics plots generated.") +plot_average_precision(per_layer_stats, module_keys) +print("Average precision plot generated.") + +# %% +# ----------------------- plot: accuracy ----------------------- + +plot_accuracy(per_layer_stats, module_keys) +print("Accuracy plot generated.") + +# %% +# ----------------------- plot: balanced accuracy ----------------------- + 
+plot_balanced_accuracy(per_layer_stats, module_keys) +print("Balanced accuracy plot generated.") # %% # ----------------------- plot: AP vs prevalence ----------------------- @@ -155,7 +168,12 @@ # ----------------------- plot: component activity breakdown ----------------------- plot_component_activity_breakdown( - component_acts_concat, module_keys, config.activation_threshold + component_acts_concat, module_keys, config.activation_threshold, + logy=False, +) +plot_component_activity_breakdown( + component_acts_concat, module_keys, config.activation_threshold, + logy=True, ) print("Component activity breakdown plot generated.") From 9325b1b1951f1ab6a8443d3d0d8b0feffe1ce53e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Thu, 23 Oct 2025 17:26:24 +0100 Subject: [PATCH 54/77] format --- spd/clustering/ci_dt/attn.py | 21 ++---- spd/clustering/ci_dt/core.py | 2 - spd/clustering/ci_dt/feature_names.py | 9 +-- spd/clustering/ci_dt/pipeline.py | 22 +++--- spd/clustering/ci_dt/plot.py | 103 ++++++++++++++------------ spd/clustering/ci_dt/run.py | 27 ++++--- 6 files changed, 87 insertions(+), 97 deletions(-) diff --git a/spd/clustering/ci_dt/attn.py b/spd/clustering/ci_dt/attn.py index d73ce535c..2ff5a7533 100644 --- a/spd/clustering/ci_dt/attn.py +++ b/spd/clustering/ci_dt/attn.py @@ -1,8 +1,6 @@ # %% """Attention pattern visualization for CI decision tree analysis.""" -from typing import Any - import matplotlib.pyplot as plt import numpy as np import torch @@ -17,10 +15,9 @@ from spd.experiments.lm.configs import LMTaskConfig from spd.models.component_model import ComponentModel, SPDRunInfo - # magic autoreload -%load_ext autoreload -%autoreload 2 +# %load_ext autoreload +# %autoreload 2 # %% # ----------------------- configuration ----------------------- @@ -99,7 +96,7 @@ def extract_attention_patterns_multibatch( print(f"Extracting attention patterns for {n_batches} batches...") all_attention_patterns: list[dict[str, Tensor]] = [] - for batch_idx in 
tqdm(range(n_batches), desc="Batches", total=n_batches): + for _batch_idx in tqdm(range(n_batches), desc="Batches", total=n_batches): batch_data = next(iter(dataloader)) input_ids: Int[Tensor, "batch seq_len"] = batch_data["input_ids"].to(device) @@ -123,9 +120,7 @@ def extract_attention_patterns_multibatch( print("Concatenating batches...") layer_names: list[str] = list(all_attention_patterns[0].keys()) attention_patterns_concat: dict[str, Tensor] = { - layer_name: torch.cat( - [batch[layer_name] for batch in all_attention_patterns], dim=0 - ) + layer_name: torch.cat([batch[layer_name] for batch in all_attention_patterns], dim=0) for layer_name in layer_names } @@ -188,9 +183,7 @@ def compute_attention_stats( max_attention: np.ndarray = attn_np.max(axis=-1) # [samples, n_heads, seq_len] # Sparsity: fraction of attention weights < 0.01 - sparsity: np.ndarray = (attn_np < 0.01).mean( - axis=(2, 3) - ) # [samples, n_heads] + sparsity: np.ndarray = (attn_np < 0.01).mean(axis=(2, 3)) # [samples, n_heads] stats[layer_name] = { "mean_pattern": mean_pattern, @@ -371,9 +364,7 @@ def plot_attention_sparsity( ax.plot(range(len(layer_names)), mean_sparsities, marker="o", color="C1") ax.set_xlabel("Layer") ax.set_ylabel("Mean Sparsity (fraction < 0.01)") - ax.set_title( - "Attention Sparsity Across Layers\n(Higher = more sparse/focused attention)" - ) + ax.set_title("Attention Sparsity Across Layers\n(Higher = more sparse/focused attention)") ax.set_xticks(range(len(layer_names))) ax.set_xticklabels(layer_names, rotation=45, ha="right") ax.set_ylim(0, 1) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index a5ceac916..9147fcfb0 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -1,12 +1,10 @@ """Core library functions for causal importance decision trees.""" -from collections import Counter from collections.abc import Sequence from dataclasses import dataclass import numpy as np from jaxtyping import Bool, Float -from 
muutils.dbg import dbg from sklearn.metrics import ( accuracy_score, average_precision_score, diff --git a/spd/clustering/ci_dt/feature_names.py b/spd/clustering/ci_dt/feature_names.py index 960ced3b8..11f3ffcf8 100644 --- a/spd/clustering/ci_dt/feature_names.py +++ b/spd/clustering/ci_dt/feature_names.py @@ -84,7 +84,7 @@ def decode_direction_top_k( # Decode tokens tokens = [] - for idx, val in zip(top_k_indices.tolist(), top_k_values.tolist()): + for idx, val in zip(top_k_indices.tolist(), top_k_values.tolist(), strict=False): token_str = tokenizer.decode([idx]) # Clean up token string for display token_str = repr(token_str)[1:-1] # Remove quotes and escape special chars @@ -112,7 +112,7 @@ def get_component_directions( # Get the component module component = component_model.components[module_key] - assert isinstance(component, (LinearComponents, EmbeddingComponents)), ( + assert isinstance(component, LinearComponents | EmbeddingComponents), ( f"Expected LinearComponents or EmbeddingComponents, got {type(component)}" ) @@ -206,10 +206,7 @@ def generate_feature_names( ) feature_name = ( - f"{comp_label}\n" - f"{act_info}\n" - f"R→E:{read_embed}\n" - f"W→U:{write_unembed}" + f"{comp_label}\n{act_info}\nR→E:{read_embed}\nW→U:{write_unembed}" ) except Exception as e: print(f"Warning: Could not decode component {comp_label}: {e}") diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index ea022eccf..2d07eed65 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -4,10 +4,10 @@ import numpy as np import torch -from jaxtyping import Bool, Float, Int -from torch import Tensor, nn -from torch.utils.data import DataLoader +from jaxtyping import Bool, Float from muutils.dbg import dbg_tensor +from torch import Tensor +from torch.utils.data import DataLoader from tqdm import tqdm from spd.clustering.activations import component_activations @@ -44,7 +44,7 @@ def compute_activations_multibatch( print(f"Computing 
activations for {n_batches} batches...") all_component_acts: list[dict[str, Tensor]] = [] - for batch_idx in tqdm(range(n_batches), desc="Batches", total=n_batches): + for _batch_idx in tqdm(range(n_batches), desc="Batches", total=n_batches): batch_data = next(iter(dataloader)) batch: Tensor = batch_data["input_ids"] @@ -63,15 +63,13 @@ def compute_activations_multibatch( print("Concatenating batches...") module_keys: list[str] = list(all_component_acts[0].keys()) component_acts_concat: dict[str, Tensor] = { - key: torch.cat([batch[key] for batch in all_component_acts], dim=0) - for key in module_keys + key: torch.cat([batch[key] for batch in all_component_acts], dim=0) for key in module_keys } # Apply seq_mean if needed (LM task) print("Applying seq_mean over sequence dimension...") component_acts_concat = { - key: act.mean(dim=1) if act.ndim == 3 else act - for key, act in component_acts_concat.items() + key: act.mean(dim=1) if act.ndim == 3 else act for key, act in component_acts_concat.items() } return component_acts_concat @@ -107,12 +105,12 @@ def convert_to_boolean_layers( # Filter out components that are always dead or always alive # (they provide no information for decision trees) - component_variance: Float[np.ndarray, "n_components"] = module_acts_bool.var(axis=0) - varying_mask: Bool[np.ndarray, "n_components"] = component_variance > 0 + component_variance: Float[np.ndarray, " n_components"] = module_acts_bool.var(axis=0) + varying_mask: Bool[np.ndarray, " n_components"] = component_variance > 0 # Count always-dead and always-alive components for diagnostics - always_dead_mask: Bool[np.ndarray, "n_components"] = ~module_acts_bool.any(axis=0) - always_alive_mask: Bool[np.ndarray, "n_components"] = module_acts_bool.all(axis=0) + always_dead_mask: Bool[np.ndarray, " n_components"] = ~module_acts_bool.any(axis=0) + always_alive_mask: Bool[np.ndarray, " n_components"] = module_acts_bool.all(axis=0) n_always_dead: int = int(always_dead_mask.sum()) 
n_always_alive: int = int(always_alive_mask.sum()) diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index b535bb24d..c491383ec 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np -from jaxtyping import Bool, Float, Int +from jaxtyping import Float, Int from sklearn.tree import plot_tree from spd.clustering.ci_dt.core import LayerModel, get_estimator_for @@ -35,7 +35,7 @@ def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: # Start from most central item (highest average similarity) n: int = similarity.shape[0] - avg_sim: Float[np.ndarray, "n"] = similarity.mean(axis=1) + avg_sim: Float[np.ndarray, n] = similarity.mean(axis=1) start_idx: int = int(np.argmax(avg_sim)) # Greedy ordering: always add nearest unvisited neighbor @@ -61,9 +61,7 @@ def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: return np.array(ordered, dtype=np.int64) -def add_component_labeling( - ax: plt.Axes, component_labels: list[str], axis: str = "x" -) -> None: +def add_component_labeling(ax: plt.Axes, component_labels: list[str], axis: str = "x") -> None: """Add component labeling using major/minor ticks to show module boundaries. Args: @@ -218,9 +216,9 @@ def plot_covariance( module_keys: List of module names for labeling component_order: Optional array of component indices for sorting. If None, plots unsorted. 
""" - A: Float[np.ndarray, "n_samples n_components"] = np.concatenate( - layers_true, axis=1 - ).astype(float) + A: Float[np.ndarray, "n_samples n_components"] = np.concatenate(layers_true, axis=1).astype( + float + ) # Apply component ordering if provided if component_order is not None: @@ -292,11 +290,15 @@ def plot_average_precision( x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(ap_valid) ) - ax.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors='none') + ax.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors="none") # Add mean line - ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_ap"], stats["mean_ap"]], - 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + ax.plot( + [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + [stats["mean_ap"], stats["mean_ap"]], + "r-", + linewidth=2, + label="Mean" if layer_idx == 0 else "", + ) ax.set_title( r"Average Precision per Target Component" + "\n" @@ -308,9 +310,9 @@ def plot_average_precision( ax.set_ylabel("Average Precision") ax.set_xticks(np.arange(1, L + 1)) # Only use module keys that correspond to target layers (skip input layer) - ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, axis='y') + ax.grid(True, alpha=0.3, axis="y") ax.legend() fig.tight_layout() @@ -338,10 +340,14 @@ def plot_accuracy( x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(acc_valid) ) - ax.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors='none') - ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_acc"], stats["mean_acc"]], - 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + ax.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors="none") + ax.plot( + [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + 
[stats["mean_acc"], stats["mean_acc"]], + "r-", + linewidth=2, + label="Mean" if layer_idx == 0 else "", + ) ax.set_title( r"Accuracy per Target Component" + "\n" @@ -350,9 +356,9 @@ def plot_accuracy( ax.set_xlabel("Target Module") ax.set_ylabel("Accuracy") ax.set_xticks(np.arange(1, L + 1)) - ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, axis='y') + ax.grid(True, alpha=0.3, axis="y") ax.legend() fig.tight_layout() @@ -380,10 +386,14 @@ def plot_balanced_accuracy( x_jittered: np.ndarray = x_positions + np.random.uniform( -jitter_amount, jitter_amount, len(bacc_valid) ) - ax.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors='none') - ax.plot([layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_bacc"], stats["mean_bacc"]], - 'r-', linewidth=2, label='Mean' if layer_idx == 0 else '') + ax.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors="none") + ax.plot( + [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + [stats["mean_bacc"], stats["mean_bacc"]], + "r-", + linewidth=2, + label="Mean" if layer_idx == 0 else "", + ) ax.set_title( r"Balanced Accuracy per Target Component" + "\n" @@ -392,16 +402,14 @@ def plot_balanced_accuracy( ax.set_xlabel("Target Module") ax.set_ylabel("Balanced Accuracy") ax.set_xticks(np.arange(1, L + 1)) - ax.set_xticklabels(module_keys[1:L+1], rotation=45, ha='right') + ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, axis='y') + ax.grid(True, alpha=0.3, axis="y") ax.legend() fig.tight_layout() -def plot_ap_vs_prevalence( - per_layer_stats: list[dict[str, Any]], models: list[LayerModel] -) -> None: +def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[LayerModel]) -> None: """Plot AP vs prevalence scatter colored by tree depth. 
Args: @@ -414,7 +422,7 @@ def plot_ap_vs_prevalence( ap_list: list[float] = [] depth_list: list[int] = [] - for layer_idx, (stats, model) in enumerate(zip(per_layer_stats, models, strict=True)): + for _layer_idx, (stats, model) in enumerate(zip(per_layer_stats, models, strict=True)): for target_idx, (prev, ap) in enumerate(zip(stats["prev"], stats["ap"], strict=True)): if not np.isnan(ap): prevalence_list.append(prev) @@ -434,7 +442,7 @@ def plot_ap_vs_prevalence( cmap="viridis", alpha=0.6, s=30, - edgecolors='none', + edgecolors="none", ) ax.set_title( @@ -476,7 +484,7 @@ def plot_component_activity_breakdown( for module_key in module_keys: acts: np.ndarray = component_acts[module_key] # Convert to numpy if needed - if hasattr(acts, 'cpu'): + if hasattr(acts, "cpu"): acts = acts.cpu().numpy() # Convert to boolean acts_bool: np.ndarray = (acts >= activation_threshold).astype(bool) @@ -528,19 +536,20 @@ def plot_component_activity_breakdown( ax.set_xlabel("Module") ax.set_ylabel("Number of Components (log scale)") ax.set_xticks(x_pos) - ax.set_xticklabels(module_keys, rotation=45, ha='right') + ax.set_xticklabels(module_keys, rotation=45, ha="right") if logy: - ax.set_yscale('log') + ax.set_yscale("log") # Create legend with correct labels from matplotlib.patches import Patch + legend_elements = [ - Patch(facecolor='C2', label='Varying'), - Patch(facecolor='C1', label='Always Active'), - Patch(facecolor='C0', label='Always Inactive'), + Patch(facecolor="C2", label="Varying"), + Patch(facecolor="C1", label="Always Active"), + Patch(facecolor="C0", label="Always Inactive"), ] - ax.legend(handles=legend_elements, loc='upper left') - ax.grid(True, alpha=0.3, axis='y') + ax.legend(handles=legend_elements, loc="upper left") + ax.grid(True, alpha=0.3, axis="y") fig.tight_layout() @@ -578,7 +587,7 @@ def plot_selected_trees( def extract_tree_stats( models: list[LayerModel], per_layer_stats: list[dict[str, Any]], -) -> dict[str, Float[np.ndarray, "n_trees"]]: +) -> dict[str, 
Float[np.ndarray, " n_trees"]]: """Extract depth, leaf count, and accuracy for all trees across all layers.""" depths: list[int] = [] leaf_counts: list[int] = [] @@ -603,9 +612,7 @@ def extract_tree_stats( } -def plot_tree_statistics( - models: list[LayerModel], per_layer_stats: list[dict[str, Any]] -) -> None: +def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[str, Any]]) -> None: """Plot distributions of tree depth, leaf count, and their correlations with accuracy.""" stats = extract_tree_stats(models, per_layer_stats) @@ -632,10 +639,10 @@ def plot_tree_statistics( # Heatmap: depth vs accuracy valid_mask: np.ndarray = ~np.isnan(stats["accuracy"]) - depth_bins: Int[np.ndarray, "n_bins"] = np.arange( + depth_bins: Int[np.ndarray, " n_bins"] = np.arange( int(stats["depth"].min()), int(stats["depth"].max()) + 2 ) - acc_bins: Float[np.ndarray, "n_bins"] = np.linspace(0, 1, 11) + acc_bins: Float[np.ndarray, " n_bins"] = np.linspace(0, 1, 11) heatmap_depth_acc: Float[np.ndarray, "depth_bins acc_bins"] heatmap_depth_acc, _, _ = np.histogram2d( stats["depth"][valid_mask], stats["accuracy"][valid_mask], bins=[depth_bins, acc_bins] @@ -660,7 +667,7 @@ def plot_tree_statistics( plt.colorbar(im, ax=ax4, label="log10(count+1)") # Heatmap: leaf count vs accuracy - leaf_bins: Int[np.ndarray, "n_bins"] = np.linspace( + leaf_bins: Int[np.ndarray, " n_bins"] = np.linspace( int(stats["n_leaves"].min()), int(stats["n_leaves"].max()) + 1, 11, dtype=int ) heatmap_leaf_acc: Float[np.ndarray, "leaf_bins acc_bins"] @@ -723,12 +730,12 @@ def plot_tree_statistics( # Prevalence bins (log scale) prev_min: float = max(prevalence_arr.min(), 1e-4) # Avoid log(0) prev_max: float = prevalence_arr.max() - prev_bins: Float[np.ndarray, "n_bins"] = np.logspace( + prev_bins: Float[np.ndarray, " n_bins"] = np.logspace( np.log10(prev_min), np.log10(prev_max), 10 ) # AP bins (linear) - ap_bins_heatmap: Float[np.ndarray, "n_bins"] = np.linspace(0, 1, 11) + ap_bins_heatmap: 
Float[np.ndarray, " n_bins"] = np.linspace(0, 1, 11) heatmap_prev_ap: Float[np.ndarray, "prev_bins ap_bins"] heatmap_prev_ap, _, _ = np.histogram2d( diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 064819dc6..35413abde 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -1,11 +1,9 @@ # %% """Main execution script for causal importance decision tree training.""" -from typing import Any - import numpy as np import torch -from jaxtyping import Bool, Float +from jaxtyping import Bool from torch import Tensor from spd.clustering.ci_dt.config import CIDTConfig @@ -18,13 +16,11 @@ from spd.clustering.ci_dt.plot import ( greedy_sort, plot_accuracy, - plot_activations, plot_ap_vs_prevalence, plot_average_precision, plot_balanced_accuracy, plot_component_activity_breakdown, plot_covariance, - plot_selected_trees, plot_tree_statistics, ) from spd.configs import Config @@ -32,10 +28,9 @@ from spd.experiments.lm.configs import LMTaskConfig from spd.models.component_model import ComponentModel, SPDRunInfo - # magic autoreload -%load_ext autoreload -%autoreload 2 +# %load_ext autoreload +# %autoreload 2 # %% # ----------------------- configuration ----------------------- @@ -43,8 +38,8 @@ config = CIDTConfig( # batch_size=50, # 50 ~~ 16GB VRAM max # n_batches=8, - batch_size=16, - n_batches=4, + batch_size=16, + n_batches=4, activation_threshold=0.01, max_depth=8, random_state=42, @@ -168,12 +163,16 @@ # ----------------------- plot: component activity breakdown ----------------------- plot_component_activity_breakdown( - component_acts_concat, module_keys, config.activation_threshold, - logy=False, + component_acts_concat, + module_keys, + config.activation_threshold, + logy=False, ) plot_component_activity_breakdown( - component_acts_concat, module_keys, config.activation_threshold, - logy=True, + component_acts_concat, + module_keys, + config.activation_threshold, + logy=True, ) print("Component activity breakdown plot 
generated.") From d860bfa7b93870d49e554a462ebbee5ba21662b5 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 10:48:35 +0100 Subject: [PATCH 55/77] wip --- spd/clustering/ci_dt/__init__.py | 7 -- spd/clustering/ci_dt/feature_names.py | 97 +-------------------------- spd/clustering/ci_dt/run.py | 71 ++++++++------------ 3 files changed, 28 insertions(+), 147 deletions(-) diff --git a/spd/clustering/ci_dt/__init__.py b/spd/clustering/ci_dt/__init__.py index 2097aa211..3f8e91e98 100644 --- a/spd/clustering/ci_dt/__init__.py +++ b/spd/clustering/ci_dt/__init__.py @@ -13,10 +13,6 @@ proba_for_layer, train_trees, ) -from spd.clustering.ci_dt.plot import ( - plot_layer_metrics, - plot_selected_trees, -) __all__ = [ # Config @@ -32,7 +28,4 @@ "layer_metrics", "proba_for_layer", "get_estimator_for", - # Plot - "plot_layer_metrics", - "plot_selected_trees", ] diff --git a/spd/clustering/ci_dt/feature_names.py b/spd/clustering/ci_dt/feature_names.py index 11f3ffcf8..4f96332d7 100644 --- a/spd/clustering/ci_dt/feature_names.py +++ b/spd/clustering/ci_dt/feature_names.py @@ -123,99 +123,4 @@ def get_component_directions( read_direction = V[:, component_idx] # [d_in] write_direction = U[component_idx, :] # [d_out] - return read_direction, write_direction - - -def generate_feature_names( - component_model: ComponentModel, - component_acts: dict[str, Tensor], - layers_true: list[Bool[np.ndarray, "n_samples n_components"]], - layers_pred: list[np.ndarray], - tokenizer, - module_keys: list[str], - top_k: int = 3, -) -> list[list[str]]: - """Generate feature names for all layers with activation and decoding info. 
- - Args: - component_model: The ComponentModel containing components - component_acts: Dictionary of continuous activations [n_samples, n_components] - layers_true: List of boolean activations per layer - layers_pred: List of predicted boolean activations per layer - tokenizer: Tokenizer for decoding directions - module_keys: List of module keys in order (matches layers_true) - top_k: Number of top tokens to show for each direction - - Returns: - List of feature name lists, one per layer. feature_names[k] contains - names for all features used to predict layer k (concatenation of layers 0..k-1). - """ - try: - embed, unembed = get_embed_unembed_matrices(component_model) - except ValueError as e: - print(f"Warning: Could not extract embed/unembed matrices: {e}") - embed = None - unembed = None - - feature_names_per_layer: list[list[str]] = [] - - # For each target layer k (k=1..L-1), we need feature names for concat(layers[:k]) - for k in range(1, len(layers_true)): - feature_names_k: list[str] = [] - - # Iterate through all previous layers (0..k-1) - for layer_idx in range(k): - module_key = module_keys[layer_idx] - n_components = layers_true[layer_idx].shape[1] - - # Get continuous activations for this module - module_acts = component_acts[module_key].numpy() # [n_samples, n_total_components] - - # Map from filtered component indices to original component indices - # layers_true[layer_idx] has only the varying components - varying_mask = module_acts.var(axis=0) > 0 - varying_indices = np.where(varying_mask)[0] - - for filtered_idx in range(n_components): - original_idx = varying_indices[filtered_idx] - - # Get actual and predicted activation stats - actual_acts = module_acts[:, original_idx] - actual_mean = float(actual_acts.mean()) - actual_std = float(actual_acts.std()) - - # Component label - comp_label = f"L{layer_idx}C{original_idx}" - - # Get activation info - act_info = f"μ={actual_mean:.2f},σ={actual_std:.2f}" - - # Get direction decodings if available - 
if embed is not None and unembed is not None: - try: - read_dir, write_dir = get_component_directions( - component_model, module_key, original_idx - ) - - # Decode using both embed and unembed - read_embed = decode_direction_top_k( - read_dir, embed, unembed, tokenizer, k=top_k, use_embed=True - ) - write_unembed = decode_direction_top_k( - write_dir, embed, unembed, tokenizer, k=top_k, use_embed=False - ) - - feature_name = ( - f"{comp_label}\n{act_info}\nR→E:{read_embed}\nW→U:{write_unembed}" - ) - except Exception as e: - print(f"Warning: Could not decode component {comp_label}: {e}") - feature_name = f"{comp_label}\n{act_info}" - else: - feature_name = f"{comp_label}\n{act_info}" - - feature_names_k.append(feature_name) - - feature_names_per_layer.append(feature_names_k) - - return feature_names_per_layer + return read_direction, write_direction \ No newline at end of file diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 35413abde..69b4f1360 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -16,11 +16,13 @@ from spd.clustering.ci_dt.plot import ( greedy_sort, plot_accuracy, + plot_activations, plot_ap_vs_prevalence, plot_average_precision, plot_balanced_accuracy, plot_component_activity_breakdown, plot_covariance, + plot_selected_trees, plot_tree_statistics, ) from spd.configs import Config @@ -38,10 +40,10 @@ config = CIDTConfig( # batch_size=50, # 50 ~~ 16GB VRAM max # n_batches=8, - batch_size=16, + batch_size=32, n_batches=4, activation_threshold=0.01, - max_depth=8, + max_depth=3, random_state=42, ) device: str = "cuda" if torch.cuda.is_available() else "cpu" @@ -188,24 +190,24 @@ # Heatmaps of true vs predicted activations (unsorted and sorted) # Unsorted version with layer boundaries -# plot_activations( -# layers_true=layers_true, -# layers_pred=layers_pred, -# module_keys=module_keys, -# activation_threshold=config.activation_threshold, -# sample_order=None, -# ) -# print("Activation plots 
(unsorted) generated.") +plot_activations( + layers_true=layers_true, + layers_pred=layers_pred, + module_keys=module_keys, + activation_threshold=config.activation_threshold, + sample_order=None, +) +print("Activation plots (unsorted) generated.") # # Sorted version with diff plot -# plot_activations( -# layers_true=layers_true, -# layers_pred=layers_pred, -# module_keys=module_keys, -# activation_threshold=config.activation_threshold, -# sample_order=sample_order, -# ) -# print("Activation plots (sorted by samples) generated.") +plot_activations( + layers_true=layers_true, + layers_pred=layers_pred, + module_keys=module_keys, + activation_threshold=config.activation_threshold, + sample_order=sample_order, +) +print("Activation plots (sorted by samples) generated.") # %% # ----------------------- plot: covariance ----------------------- @@ -228,35 +230,16 @@ ) print("Covariance plot (sorted by components) generated.") -# %% -# ----------------------- generate feature names ----------------------- -# Generate feature names with activation statistics and decoded directions - -# from spd.clustering.ci_dt.feature_names import generate_feature_names - -# module_keys = list(component_acts_concat.keys()) - -# feature_names = generate_feature_names( -# component_model=model, -# component_acts=component_acts_concat, -# layers_true=layers_true, -# layers_pred=layers_pred, -# tokenizer=cfg.task_config.tokenizer if hasattr(cfg.task_config, 'tokenizer') else None, -# module_keys=module_keys, -# top_k=3, -# ) -# print("Feature names generated.") - # %% # ----------------------- plot: worst trees ----------------------- # Decision tree visualization for worst performing trees -# plot_selected_trees(worst_list, "Worst", models, feature_names=feature_names) -# print("Worst trees plots generated.") +plot_selected_trees(worst_list, "Worst", models) +print("Worst trees plots generated.") -# # %% -# # ----------------------- plot: best trees ----------------------- -# # Decision tree 
visualization for best performing trees +# %% +# ----------------------- plot: best trees ----------------------- +# Decision tree visualization for best performing trees -# plot_selected_trees(best_list, "Best", models, feature_names=feature_names) -# print("Best trees plots generated.") +plot_selected_trees(best_list, "Best", models) +print("Best trees plots generated.") From 55c920684705d71f4f876dbfcdc9faf53d02d5b4 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 11:43:36 +0100 Subject: [PATCH 56/77] fix bug where we were avg across seq --- spd/clustering/ci_dt/config.py | 1 + spd/clustering/ci_dt/core.py | 3 ++- spd/clustering/ci_dt/pipeline.py | 37 ++++++++++++++++++++------------ spd/clustering/ci_dt/plot.py | 10 ++++++++- spd/clustering/ci_dt/run.py | 7 +++--- 5 files changed, 39 insertions(+), 19 deletions(-) diff --git a/spd/clustering/ci_dt/config.py b/spd/clustering/ci_dt/config.py index 0edf24e2b..5980d654b 100644 --- a/spd/clustering/ci_dt/config.py +++ b/spd/clustering/ci_dt/config.py @@ -9,6 +9,7 @@ class CIDTConfig: batch_size: int = 10 # Number of samples per batch for GPU inference n_batches: int = 25 # Number of batches to process (total samples = batch_size * n_batches) + n_ctx: int = 64 # Context length (sequence length) for tokenization activation_threshold: float = 0.01 # Threshold for boolean conversion max_depth: int = 8 # Maximum depth for decision trees random_state: int = 7 # Random state for reproducibility diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 9147fcfb0..754b15e90 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -12,6 +12,7 @@ ) from sklearn.multioutput import MultiOutputClassifier from sklearn.tree import DecisionTreeClassifier +from tqdm import tqdm @dataclass @@ -59,7 +60,7 @@ def train_trees( """Train one decision tree per component per target layer using previous layers as features.""" XYs = build_xy(layers) models: list[LayerModel] = 
[] - for k, (X_k, Y_k) in enumerate(XYs, start=1): + for k, (X_k, Y_k) in tqdm(enumerate(XYs, start=1), total=len(XYs), desc="Training trees"): base = DecisionTreeClassifier( max_depth=max_depth, min_samples_leaf=min_samples_leaf, diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 2d07eed65..7de9d2584 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -30,7 +30,7 @@ def compute_activations_multibatch( After all batches: - Concatenate along batch dimension - - Apply seq_mean if ndim==3 (for LM tasks) + - Keep sequence dimension for per-token analysis (no seq_mean) Args: model: ComponentModel to get activations from @@ -39,7 +39,8 @@ def compute_activations_multibatch( n_batches: Number of batches to process Returns: - Dictionary mapping module keys to concatenated activations (on CPU, seq_mean applied) + Dictionary mapping module keys to concatenated activations + (on CPU, shape: batch, seq_len, n_components) """ print(f"Computing activations for {n_batches} batches...") all_component_acts: list[dict[str, Tensor]] = [] @@ -66,11 +67,9 @@ def compute_activations_multibatch( key: torch.cat([batch[key] for batch in all_component_acts], dim=0) for key in module_keys } - # Apply seq_mean if needed (LM task) - print("Applying seq_mean over sequence dimension...") - component_acts_concat = { - key: act.mean(dim=1) if act.ndim == 3 else act for key, act in component_acts_concat.items() - } + print("Activation shapes (keeping sequence dimension for per-token analysis):") + for key in module_keys[:3]: # Show first 3 for brevity + print(f" {key}: {component_acts_concat[key].shape}") return component_acts_concat @@ -82,22 +81,31 @@ def convert_to_boolean_layers( ) -> list[Bool[np.ndarray, "n_samples n_components"]]: """Convert activations to boolean, filter constant (always dead/alive) components. + Handles 3D activations (batch, seq_len, n_components) by flattening to 2D (batch*seq_len, n_components). 
+ Args: - component_acts: Dictionary of continuous activations per module (on CPU) + component_acts: Dictionary of continuous activations per module (on CPU, shape: batch, seq_len, n_components or batch, n_components) activation_threshold: Threshold for converting to boolean Returns: - List of boolean numpy arrays, one per module (layer) + List of boolean numpy arrays, one per module (layer), shape (batch*seq_len, n_varying_components) """ print("\nConverting to boolean and filtering constant components...") layers_true: list[Bool[np.ndarray, "n_samples n_components"]] = [] module_keys: list[str] = list(component_acts.keys()) for module_key in module_keys: - # Convert to numpy and boolean - module_acts_np: Float[np.ndarray, "n_samples n_components"] = component_acts[ - module_key - ].numpy() + # Convert to numpy + module_acts_tensor: Tensor = component_acts[module_key] + + # Flatten if 3D (batch, seq_len, n_components) -> (batch*seq_len, n_components) + if module_acts_tensor.ndim == 3: + batch_size, seq_len, n_components = module_acts_tensor.shape + module_acts_np: Float[np.ndarray, "n_samples n_components"] = ( + module_acts_tensor.reshape(batch_size * seq_len, n_components).numpy() + ) + else: + module_acts_np = module_acts_tensor.numpy() module_acts_bool: Bool[np.ndarray, "n_samples n_components"] = ( module_acts_np >= activation_threshold @@ -130,7 +138,8 @@ def convert_to_boolean_layers( dbg_tensor(module_acts_bool) dbg_tensor(module_acts_varying) - print(f"\nCreated {len(layers_true)} layers for decision tree training") + n_samples: int = layers_true[0].shape[0] if layers_true else 0 + print(f"Created {len(layers_true)} layers with {n_samples} samples for decision tree training") return layers_true diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index c491383ec..beab49264 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -435,6 +435,10 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], 
models: list[La ap_arr: np.ndarray = np.array(ap_list) depth_arr: np.ndarray = np.array(depth_list) + # Plot baseline: for uncorrelated variables, expected AP = prevalence + prev_range: np.ndarray = np.logspace(np.log10(prevalence_arr.min()), np.log10(prevalence_arr.max()), 100) + ax.plot(prev_range, prev_range, 'k--', alpha=0.5, linewidth=1.5, label='Random baseline (AP = prevalence)', zorder=1) + scatter = ax.scatter( prevalence_arr, ap_arr, @@ -443,17 +447,21 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La alpha=0.6, s=30, edgecolors="none", + markeredgewidth=0, + zorder=2, ) ax.set_title( r"Average Precision vs Component Prevalence" + "\n" - r"Prevalence = $\frac{n_{\text{active samples}}}{n_{\text{total samples}}}$, colored by tree depth" + r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where $P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, $R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" + "\n" + r"Colored by tree depth" ) ax.set_xlabel("Prevalence (log scale)") ax.set_ylabel("Average Precision") ax.set_xscale("log") ax.set_ylim(-0.05, 1.05) ax.grid(True, alpha=0.3) + ax.legend(loc='lower right') cbar = plt.colorbar(scatter, ax=ax) cbar.set_label("Tree Depth") diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 69b4f1360..0c8b9e02f 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -40,8 +40,9 @@ config = CIDTConfig( # batch_size=50, # 50 ~~ 16GB VRAM max # n_batches=8, - batch_size=32, - n_batches=4, + batch_size=16, + n_batches=2, + n_ctx=64, activation_threshold=0.01, max_depth=3, random_state=42, @@ -72,7 +73,7 @@ name=cfg.task_config.dataset_name, hf_tokenizer_path=pretrained_model_name, split=cfg.task_config.train_data_split, - n_ctx=cfg.task_config.max_seq_len, + n_ctx=config.n_ctx, column_name=cfg.task_config.column_name, is_tokenized=False, streaming=False, From 2db397e593c40419410eea9a476496af6d935ba2 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 
2025 12:16:55 +0100 Subject: [PATCH 57/77] wip --- spd/clustering/ci_dt/pipeline.py | 6 +++-- spd/clustering/ci_dt/plot.py | 11 +++++++-- spd/clustering/ci_dt/run.py | 41 ++++++++++++++++---------------- 3 files changed, 34 insertions(+), 24 deletions(-) diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 7de9d2584..2afae338c 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -100,10 +100,12 @@ def convert_to_boolean_layers( # Flatten if 3D (batch, seq_len, n_components) -> (batch*seq_len, n_components) if module_acts_tensor.ndim == 3: - batch_size, seq_len, n_components = module_acts_tensor.shape + print(f" {module_key}: original shape = {module_acts_tensor.shape}") + # Keep last dimension (n_components) intact, flatten first two dimensions module_acts_np: Float[np.ndarray, "n_samples n_components"] = ( - module_acts_tensor.reshape(batch_size * seq_len, n_components).numpy() + module_acts_tensor.reshape(-1, module_acts_tensor.shape[-1]).numpy() ) + print(f" {module_key}: flattened shape = {module_acts_np.shape}") else: module_acts_np = module_acts_tensor.numpy() diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index beab49264..c26a27d6d 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -6,6 +6,7 @@ import numpy as np from jaxtyping import Float, Int from sklearn.tree import plot_tree +import torch from spd.clustering.ci_dt.core import LayerModel, get_estimator_for @@ -447,7 +448,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La alpha=0.6, s=30, edgecolors="none", - markeredgewidth=0, + linewidths=0, zorder=2, ) @@ -470,7 +471,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La def plot_component_activity_breakdown( - component_acts: dict[str, np.ndarray], + component_acts: dict[str, np.ndarray|torch.Tensor], module_keys: list[str], activation_threshold: float, logy: bool 
= False, @@ -494,6 +495,12 @@ def plot_component_activity_breakdown( # Convert to numpy if needed if hasattr(acts, "cpu"): acts = acts.cpu().numpy() + + # Flatten if 3D (batch, seq_len, n_components) -> (batch*seq_len, n_components) + # This treats each token position as a separate sample, consistent with decision tree training + if acts.ndim == 3: + acts = acts.reshape(-1, acts.shape[-1]) + # Convert to boolean acts_bool: np.ndarray = (acts >= activation_threshold).astype(bool) diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 0c8b9e02f..9e9a118bf 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -31,8 +31,8 @@ from spd.models.component_model import ComponentModel, SPDRunInfo # magic autoreload -# %load_ext autoreload -# %autoreload 2 +%load_ext autoreload +%autoreload 2 # %% # ----------------------- configuration ----------------------- @@ -98,6 +98,25 @@ dataloader=dataloader, n_batches=config.n_batches, ) +# Get module keys for labeling +module_keys: list[str] = list(component_acts_concat.keys()) + +# %% +# ----------------------- plot: component activity breakdown ----------------------- + +plot_component_activity_breakdown( + component_acts_concat, + module_keys, + config.activation_threshold, + logy=False, +) +plot_component_activity_breakdown( + component_acts_concat, + module_keys, + config.activation_threshold, + logy=True, +) +print("Component activity breakdown plot generated.") # %% # ----------------------- convert to boolean layers ----------------------- @@ -128,9 +147,6 @@ # ----------------------- compute orderings ----------------------- # Generate sample ordering once for use in multiple plots -# Get module keys for labeling -module_keys: list[str] = list(component_acts_concat.keys()) - # Concatenate true activations for ordering A_true_concat: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) @@ -162,22 +178,7 @@ plot_ap_vs_prevalence(per_layer_stats, models) print("AP vs 
prevalence plot generated.") -# %% -# ----------------------- plot: component activity breakdown ----------------------- -plot_component_activity_breakdown( - component_acts_concat, - module_keys, - config.activation_threshold, - logy=False, -) -plot_component_activity_breakdown( - component_acts_concat, - module_keys, - config.activation_threshold, - logy=True, -) -print("Component activity breakdown plot generated.") # %% # ----------------------- plot: tree statistics ----------------------- From 79c9f44888ddea6f7c7e3134a01595e1938a6ba4 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 14:32:45 +0100 Subject: [PATCH 58/77] wip --- spd/clustering/ci_dt/core.py | 114 +++++++++---- spd/clustering/ci_dt/pipeline.py | 16 +- spd/clustering/ci_dt/plot.py | 280 ++++++++++++++++--------------- spd/clustering/ci_dt/run.py | 81 +++------ 4 files changed, 260 insertions(+), 231 deletions(-) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 754b15e90..8b537ce7a 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -2,6 +2,8 @@ from collections.abc import Sequence from dataclasses import dataclass +from typing import Literal +import warnings import numpy as np from jaxtyping import Bool, Float @@ -76,27 +78,17 @@ def extract_prob_class_1( proba_list: list[np.ndarray], model: MultiOutputClassifier, ) -> np.ndarray: - """Extract P(y=1) for each output, handling constant components. + """Extract P(y=1) for each output. - When a component is always 0 or always 1 in training data, - sklearn only returns probabilities for the observed class. - This function handles all cases correctly. + Assumes constant components are filtered out, so both classes should always be present. 
""" result: list[np.ndarray] = [] for i, p in enumerate(proba_list): estimator = model.estimators_[i] classes = estimator.classes_ - if len(classes) == 1: - # Only one class observed during training - if classes[0] == 0: - # Only saw class 0, so P(y=1) = 0 - result.append(np.zeros(p.shape[0])) - else: # classes[0] == 1 - # Only saw class 1, so P(y=1) = 1 - result.append(np.ones(p.shape[0])) - else: - # Saw both classes, extract P(y=1) from second column - result.append(p[:, 1]) + assert len(classes) == 2, f"Expected 2 classes but got {len(classes)} for output {i}" + # Extract P(y=1) from second column + result.append(p[:, 1]) return np.stack(result, axis=1) @@ -136,35 +128,89 @@ def predict_all( return out +MetricKey = Literal["ap", "acc", "bacc", "prev", "tpr", "tnr", "precision", "npv", "f1"] + def layer_metrics( Y_true: Bool[np.ndarray, "n t"], Y_prob: Float[np.ndarray, "n t"], Y_pred: Bool[np.ndarray, "n t"], -) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Return per-target AP, acc, bacc, prevalence.""" +) -> dict[MetricKey, np.ndarray]: + """Return per-target metrics: AP, acc, bacc, prevalence, TPR, TNR, precision, NPV, F1. + + Returns: + Dictionary with keys: + - ap: Average precision + - acc: Accuracy + - bacc: Balanced accuracy + - prev: Prevalence (fraction of positive samples) + - tpr: True Positive Rate (Recall/Sensitivity) + - tnr: True Negative Rate (Specificity) + - precision: Precision (when we predict active, how often are we right?) + - npv: Negative Predictive Value (when we predict inactive, how often are we right?) + - f1: F1 score + + Each value is an array of length T (number of target components). 
+ """ T: int = Y_true.shape[1] - ap: np.ndarray = np.zeros(T) - acc: np.ndarray = np.zeros(T) - bacc: np.ndarray = np.zeros(T) - prev: np.ndarray = np.zeros(T) + + ap: Float[np.ndarray, " t"] = np.full(T, np.nan) + acc: Float[np.ndarray, " t"] = np.full(T, np.nan) + bacc: Float[np.ndarray, " t"] = np.full(T, np.nan) + prev: Float[np.ndarray, " t"] = np.full(T, np.nan) + tpr: Float[np.ndarray, " t"] = np.full(T, np.nan) + tnr: Float[np.ndarray, " t"] = np.full(T, np.nan) + precision: Float[np.ndarray, " t"] = np.full(T, np.nan) + npv: Float[np.ndarray, " t"] = np.full(T, np.nan) + f1: Float[np.ndarray, " t"] = np.full(T, np.nan) + for j in range(T): y: np.ndarray = Y_true[:, j].astype(int) p: np.ndarray = Y_prob[:, j] yhat: np.ndarray = Y_pred[:, j].astype(int) prev[j] = float(y.mean()) - try: - ap[j] = average_precision_score(y, p) - except Exception: - ap[j] = np.nan - try: - acc[j] = accuracy_score(y, yhat) - except Exception: - acc[j] = np.nan - try: - bacc[j] = balanced_accuracy_score(y, yhat) - except Exception: - bacc[j] = np.nan - return ap, acc, bacc, prev + + # Compute confusion matrix elements + tp: int = int(((y == 1) & (yhat == 1)).sum()) + tn: int = int(((y == 0) & (yhat == 0)).sum()) + fp: int = int(((y == 0) & (yhat == 1)).sum()) + fn: int = int(((y == 1) & (yhat == 0)).sum()) + + # TPR (Recall/Sensitivity) = TP / (TP + FN) + tpr[j] = tp / (tp + fn) + + # TNR (Specificity) = TN / (TN + FP) + tnr[j] = tn / (tn + fp) + + # Precision (PPV) = TP / (TP + FP) - when we predict active, how often are we right? + if (tp + fp) > 0: + precision[j] = tp / (tp + fp) + else: + precision[j] = np.nan + warnings.warn(f"Precision failed: {tp=}, {fp=}, {tp+fp=}") + + # Negative Predictive Value = TN / (TN + FN) - when we predict inactive, how often are we right? 
+ npv[j] = tn / (tn + fn) + + # F1 = 2 * (precision * recall) / (precision + recall) + f1[j] = 2 * (precision[j] * tpr[j]) / (precision[j] + tpr[j]) + + # Sklearn metrics + ap[j] = average_precision_score(y, p) + acc[j] = accuracy_score(y, yhat) + bacc[j] = balanced_accuracy_score(y, yhat) + + + return { + "ap": ap, + "acc": acc, + "bacc": bacc, + "prev": prev, + "tpr": tpr, + "tnr": tnr, + "precision": precision, + "npv": npv, + "f1": f1, + } def proba_for_layer(lm: LayerModel, X: np.ndarray) -> np.ndarray: diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 2afae338c..4b0c10750 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -169,19 +169,17 @@ def compute_tree_metrics( for lm, (Xk, Yk) in zip(models, XYs_demo, strict=True): Pk: np.ndarray = proba_for_layer(lm, Xk) Yhat_k: np.ndarray = Pk >= 0.5 - ap, acc, bacc, prev = layer_metrics(Yk, Pk, Yhat_k) + metrics = layer_metrics(Yk, Pk, Yhat_k) per_layer_stats.append( { - "ap": ap, - "acc": acc, - "bacc": bacc, - "prev": prev, - "mean_ap": float(np.nanmean(ap)), - "mean_acc": float(np.nanmean(acc)), - "mean_bacc": float(np.nanmean(bacc)), + **metrics, + **{ + f"mean_{key}": float(np.nanmean(values)) + for key, values in metrics.items() + } } ) - for j, apj in enumerate(ap): + for j, apj in enumerate(metrics["ap"]): all_triplets.append((lm.layer_index, j, float(apj))) # identify best and worst trees across all outputs by AP diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index c26a27d6d..fc7cfa6c0 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -8,7 +8,145 @@ from sklearn.tree import plot_tree import torch -from spd.clustering.ci_dt.core import LayerModel, get_estimator_for +from spd.clustering.ci_dt.core import LayerModel, MetricKey, get_estimator_for + +METRIC_DISPLAY_INFO: dict[MetricKey, dict[str, str]] = { + "ap": { + "ylabel": "Average Precision", + "title": ( + r"Average Precision 
per Target Component" + "\n" + r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where " + r"$P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, " + r"$R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" + ), + "color": "C0", + }, + "acc": { + "ylabel": "Accuracy", + "title": ( + r"Accuracy per Target Component" + "\n" + r"$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}$" + ), + "color": "C1", + }, + "bacc": { + "ylabel": "Balanced Accuracy", + "title": ( + r"Balanced Accuracy per Target Component" + "\n" + r"$\text{Balanced Acc} = \frac{1}{2}\left(\frac{\text{TP}}{\text{TP}+\text{FN}} + \frac{\text{TN}}{\text{TN}+\text{FP}}\right)$" + ), + "color": "C2", + }, + "prev": { + "ylabel": "Prevalence", + "title": ( + r"Component Prevalence" + "\n" + r"$\text{Prevalence} = \frac{\text{TP}+\text{FN}}{\text{TP}+\text{TN}+\text{FP}+\text{FN}}$ (fraction of samples where component is active)" + ), + "color": "C5", + }, + "tpr": { + "ylabel": "TPR", + "title": ( + r"True Positive Rate (TPR / Recall / Sensitivity)" + "\n" + r"$\text{TPR} = \frac{\text{TP}}{\text{TP}+\text{FN}}$ (how well we predict active components)" + ), + "color": "C0", + }, + "tnr": { + "ylabel": "TNR", + "title": ( + r"True Negative Rate (TNR / Specificity)" + "\n" + r"$\text{TNR} = \frac{\text{TN}}{\text{TN}+\text{FP}}$ (how well we predict inactive components)" + ), + "color": "C1", + }, + "precision": { + "ylabel": "Precision", + "title": ( + r"Precision (Positive Predictive Value)" + "\n" + r"$\text{PPV} = \frac{\text{TP}}{\text{TP}+\text{FP}}$ (when we predict active, how often are we right?)" + ), + "color": "C2", + }, + "npv": { + "ylabel": "NPV", + "title": ( + r"Negative Predictive Value (NPV)" + "\n" + r"$\text{NPV} = \frac{\text{TN}}{\text{TN}+\text{FN}}$ (when we predict inactive, how often are we right?)" + ), + "color": "C3", + }, + "f1": { + "ylabel": "F1 Score", + "title": ( + r"F1 Score per Target Component" + "\n" + r"$F1 = 2 \cdot \frac{\text{Precision} 
\cdot \text{Recall}}{\text{Precision} + \text{Recall}}$ (harmonic mean)" + ), + "color": "C4", + }, +} + + +def _plot_metric_scatter( + ax: plt.Axes, + per_layer_stats: list[dict[str, Any]], + module_keys: list[str], + metric_key: MetricKey, + jitter_amount: float = 0.15, +) -> None: + """Helper function to plot jittered scatter with mean lines for a metric. + + Handles all formatting including axis labels, ticks, grid, and module name cleaning. + Display properties (title, ylabel, color) are looked up from METRIC_DISPLAY_INFO. + + Args: + ax: Matplotlib axis to plot on + per_layer_stats: List of dicts with metrics per layer + module_keys: List of module names for x-axis labels + metric_key: Key for metric (e.g., "tpr", "npv", "f1") + jitter_amount: Amount of horizontal jitter for scatter points + """ + # Look up display properties + display_info = METRIC_DISPLAY_INFO[metric_key] + mean_key = f"mean_{metric_key}" + ylabel = display_info["ylabel"] + title = display_info["title"] + color = display_info["color"] + + L: int = len(per_layer_stats) + np.random.seed(42) + + # Plot scatter and means + for layer_idx, stats in enumerate(per_layer_stats): + values: np.ndarray = stats[metric_key] + valid: np.ndarray = values[~np.isnan(values)] + if len(valid) > 0: + x_positions: np.ndarray = np.ones(len(valid)) * (layer_idx + 1) + x_jittered: np.ndarray = x_positions + np.random.uniform( + -jitter_amount, jitter_amount, len(valid) + ) + ax.scatter(x_jittered, valid, alpha=0.5, s=20, color=color, edgecolors="none") + ax.plot( + [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], + [stats[mean_key], stats[mean_key]], + "r-", + linewidth=2, + label="Mean" if layer_idx == 0 else "", + ) + + # Clean module names + clean_keys = [k.removeprefix("model.layers.").replace("_proj", "") for k in module_keys] + + # Formatting + ax.set_title(title) + ax.set_xlabel("Target Module") + ax.set_ylabel(ylabel) + ax.set_xticks(np.arange(1, L + 1)) + ax.set_xticklabels(clean_keys[1 : L + 1], rotation=45, 
ha="right") + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3, axis="y") + ax.legend() def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: @@ -266,147 +404,27 @@ def plot_covariance( fig.tight_layout() -def plot_average_precision( +def plot_metric( per_layer_stats: list[dict[str, Any]], module_keys: list[str], + metric_key: MetricKey, ) -> None: - """Plot distribution of average precision per layer with scatter plot and jitter. - - Args: - per_layer_stats: List of dicts with metrics per layer - module_keys: List of module names for x-axis labels - """ - L: int = len(per_layer_stats) - np.random.seed(42) # Reproducible jitter - jitter_amount: float = 0.15 + """Plot distribution of a metric per layer with scatter plot and jitter. - fig, ax = plt.subplots(figsize=(10, 5)) - for layer_idx, stats in enumerate(per_layer_stats): - ap_values: np.ndarray = stats["ap"] - # Remove NaN values - ap_valid: np.ndarray = ap_values[~np.isnan(ap_values)] - if len(ap_valid) > 0: - # Add horizontal jitter - x_positions: np.ndarray = np.ones(len(ap_valid)) * (layer_idx + 1) - x_jittered: np.ndarray = x_positions + np.random.uniform( - -jitter_amount, jitter_amount, len(ap_valid) - ) - ax.scatter(x_jittered, ap_valid, alpha=0.5, s=20, color="C0", edgecolors="none") - # Add mean line - ax.plot( - [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_ap"], stats["mean_ap"]], - "r-", - linewidth=2, - label="Mean" if layer_idx == 0 else "", - ) - - ax.set_title( - r"Average Precision per Target Component" + "\n" - r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where " - r"$P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, " - r"$R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" - ) - ax.set_xlabel("Target Module") - ax.set_ylabel("Average Precision") - ax.set_xticks(np.arange(1, L + 1)) - # Only use module keys that correspond to target layers (skip input layer) - ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") - ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, 
axis="y") - ax.legend() - fig.tight_layout() - - -def plot_accuracy( - per_layer_stats: list[dict[str, Any]], - module_keys: list[str], -) -> None: - """Plot distribution of accuracy per layer with scatter plot and jitter. - - Args: - per_layer_stats: List of dicts with metrics per layer - module_keys: List of module names for x-axis labels - """ - L: int = len(per_layer_stats) - np.random.seed(42) # Reproducible jitter - jitter_amount: float = 0.15 - - fig, ax = plt.subplots(figsize=(10, 5)) - for layer_idx, stats in enumerate(per_layer_stats): - acc_values: np.ndarray = stats["acc"] - acc_valid: np.ndarray = acc_values[~np.isnan(acc_values)] - if len(acc_valid) > 0: - x_positions: np.ndarray = np.ones(len(acc_valid)) * (layer_idx + 1) - x_jittered: np.ndarray = x_positions + np.random.uniform( - -jitter_amount, jitter_amount, len(acc_valid) - ) - ax.scatter(x_jittered, acc_valid, alpha=0.5, s=20, color="C1", edgecolors="none") - ax.plot( - [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_acc"], stats["mean_acc"]], - "r-", - linewidth=2, - label="Mean" if layer_idx == 0 else "", - ) - - ax.set_title( - r"Accuracy per Target Component" + "\n" - r"$\text{Accuracy} = \frac{\text{TP} + \text{TN}}{\text{TP} + \text{TN} + \text{FP} + \text{FN}}$" - ) - ax.set_xlabel("Target Module") - ax.set_ylabel("Accuracy") - ax.set_xticks(np.arange(1, L + 1)) - ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") - ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, axis="y") - ax.legend() - fig.tight_layout() - - -def plot_balanced_accuracy( - per_layer_stats: list[dict[str, Any]], - module_keys: list[str], -) -> None: - """Plot distribution of balanced accuracy per layer with scatter plot and jitter. + Display properties (title, ylabel, color) are looked up from METRIC_DISPLAY_INFO. 
Args: per_layer_stats: List of dicts with metrics per layer module_keys: List of module names for x-axis labels + metric_key: Key for metric to plot (e.g., "tpr", "npv", "f1", "ap", "bacc") """ - L: int = len(per_layer_stats) - np.random.seed(42) # Reproducible jitter - jitter_amount: float = 0.15 - fig, ax = plt.subplots(figsize=(10, 5)) - for layer_idx, stats in enumerate(per_layer_stats): - bacc_values: np.ndarray = stats["bacc"] - bacc_valid: np.ndarray = bacc_values[~np.isnan(bacc_values)] - if len(bacc_valid) > 0: - x_positions: np.ndarray = np.ones(len(bacc_valid)) * (layer_idx + 1) - x_jittered: np.ndarray = x_positions + np.random.uniform( - -jitter_amount, jitter_amount, len(bacc_valid) - ) - ax.scatter(x_jittered, bacc_valid, alpha=0.5, s=20, color="C2", edgecolors="none") - ax.plot( - [layer_idx + 1 - 0.3, layer_idx + 1 + 0.3], - [stats["mean_bacc"], stats["mean_bacc"]], - "r-", - linewidth=2, - label="Mean" if layer_idx == 0 else "", - ) - - ax.set_title( - r"Balanced Accuracy per Target Component" + "\n" - r"$\text{Balanced Acc} = \frac{1}{2}\left(\frac{\text{TP}}{\text{TP}+\text{FN}} + \frac{\text{TN}}{\text{TN}+\text{FP}}\right)$" + _plot_metric_scatter( + ax=ax, + per_layer_stats=per_layer_stats, + module_keys=module_keys, + metric_key=metric_key, ) - ax.set_xlabel("Target Module") - ax.set_ylabel("Balanced Accuracy") - ax.set_xticks(np.arange(1, L + 1)) - ax.set_xticklabels(module_keys[1 : L + 1], rotation=45, ha="right") - ax.set_ylim(-0.05, 1.05) - ax.grid(True, alpha=0.3, axis="y") - ax.legend() fig.tight_layout() diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 9e9a118bf..bd94ab3f6 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -15,13 +15,11 @@ ) from spd.clustering.ci_dt.plot import ( greedy_sort, - plot_accuracy, plot_activations, plot_ap_vs_prevalence, - plot_average_precision, - plot_balanced_accuracy, plot_component_activity_breakdown, plot_covariance, + plot_metric, 
plot_selected_trees, plot_tree_statistics, ) @@ -40,9 +38,9 @@ config = CIDTConfig( # batch_size=50, # 50 ~~ 16GB VRAM max # n_batches=8, - batch_size=16, + batch_size=2, n_batches=2, - n_ctx=64, + n_ctx=16, activation_threshold=0.01, max_depth=3, random_state=42, @@ -116,7 +114,6 @@ config.activation_threshold, logy=True, ) -print("Component activity breakdown plot generated.") # %% # ----------------------- convert to boolean layers ----------------------- @@ -155,43 +152,29 @@ print(f"Computed sample ordering ({len(sample_order)} samples)") # %% -# ----------------------- plot: average precision ----------------------- - -plot_average_precision(per_layer_stats, module_keys) -print("Average precision plot generated.") - -# %% -# ----------------------- plot: accuracy ----------------------- - -plot_accuracy(per_layer_stats, module_keys) -print("Accuracy plot generated.") - -# %% -# ----------------------- plot: balanced accuracy ----------------------- - -plot_balanced_accuracy(per_layer_stats, module_keys) -print("Balanced accuracy plot generated.") - -# %% -# ----------------------- plot: AP vs prevalence ----------------------- +# ----------------------- plots: metrics ----------------------- +plot_metric(per_layer_stats, module_keys, "ap") +plot_metric(per_layer_stats, module_keys, "acc") +plot_metric(per_layer_stats, module_keys, "bacc") +plot_metric(per_layer_stats, module_keys, "prev") +plot_metric(per_layer_stats, module_keys, "tpr") +plot_metric(per_layer_stats, module_keys, "tnr") +plot_metric(per_layer_stats, module_keys, "precision") +plot_metric(per_layer_stats, module_keys, "npv") +plot_metric(per_layer_stats, module_keys, "f1") plot_ap_vs_prevalence(per_layer_stats, models) -print("AP vs prevalence plot generated.") # %% -# ----------------------- plot: tree statistics ----------------------- -# Distributions of tree depth, leaf counts, and correlations with accuracy +# ----------------------- plots: tree statistics ----------------------- 
plot_tree_statistics(models, per_layer_stats) -print("Tree statistics plots generated.") # %% -# ----------------------- plot: activations ----------------------- -# Heatmaps of true vs predicted activations (unsorted and sorted) +# ----------------------- plots: activations ----------------------- -# Unsorted version with layer boundaries plot_activations( layers_true=layers_true, layers_pred=layers_pred, @@ -199,49 +182,33 @@ activation_threshold=config.activation_threshold, sample_order=None, ) -print("Activation plots (unsorted) generated.") -# # Sorted version with diff plot -plot_activations( - layers_true=layers_true, - layers_pred=layers_pred, - module_keys=module_keys, - activation_threshold=config.activation_threshold, - sample_order=sample_order, -) -print("Activation plots (sorted by samples) generated.") +# plot_activations( +# layers_true=layers_true, +# layers_pred=layers_pred, +# module_keys=module_keys, +# activation_threshold=config.activation_threshold, +# sample_order=sample_order, +# ) # %% -# ----------------------- plot: covariance ----------------------- -# Covariance matrix - can be slow with many components +# ----------------------- plots: covariance ----------------------- -# Unsorted version with layer boundaries plot_covariance( layers_true=layers_true, module_keys=module_keys, component_order=None, ) -print("Covariance plot (unsorted) generated.") -# Sorted version by component similarity component_order: np.ndarray = greedy_sort(A_true_concat, axis=1) plot_covariance( layers_true=layers_true, module_keys=module_keys, component_order=component_order, ) -print("Covariance plot (sorted by components) generated.") # %% -# ----------------------- plot: worst trees ----------------------- -# Decision tree visualization for worst performing trees +# ----------------------- plots: decision trees ----------------------- plot_selected_trees(worst_list, "Worst", models) -print("Worst trees plots generated.") - -# %% -# ----------------------- 
plot: best trees ----------------------- -# Decision tree visualization for best performing trees - plot_selected_trees(best_list, "Best", models) -print("Best trees plots generated.") From 40df505c543bda9d5e21938728159a217640fc78 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 16:33:17 +0100 Subject: [PATCH 59/77] remove idx_in_ensemble, always auto-assigned now see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454317036 --- spd/clustering/clustering_run_config.py | 21 +-- spd/clustering/configs/crc/example.yaml | 1 - spd/clustering/dataset.py | 2 +- spd/clustering/ensemble_registry.py | 21 +-- spd/clustering/scripts/run_clustering.py | 29 ++--- tests/clustering/test_ensemble_registry.py | 145 +++------------------ 6 files changed, 42 insertions(+), 177 deletions(-) diff --git a/spd/clustering/clustering_run_config.py b/spd/clustering/clustering_run_config.py index f82e00203..95d72f9bd 100644 --- a/spd/clustering/clustering_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -4,9 +4,9 @@ import hashlib import json from pathlib import Path -from typing import Any, Literal, Self +from typing import Any -from pydantic import Field, NonNegativeInt, PositiveInt, field_validator, model_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.merge_config import MergeConfig @@ -31,10 +31,6 @@ class LoggingIntervals(BaseConfig): ) -ClusteringEnsembleIndex = NonNegativeInt | Literal[-1] -"index in an ensemble; -1 will cause register_clustering_run() to auto-assign the next available index" - - class ClusteringRunConfig(BaseConfig): """Configuration for a single clustering run. @@ -58,12 +54,6 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) - # TODO: given our use of `register_clustering_run()` and the atomic guarantees of that, do we even need this index? 
- # probably still nice to have for clarity - idx_in_ensemble: ClusteringEnsembleIndex | None = Field( - default=None, description="Index of this run in the ensemble" - ) - merge_config: MergeConfig = Field(description="Merge algorithm configuration") logging_intervals: LoggingIntervals = Field( default_factory=LoggingIntervals, @@ -108,13 +98,6 @@ def validate_model_path(cls, v: str) -> str: raise ValueError(f"model_path must start with 'wandb:', got: {v}") return v - @model_validator(mode="after") - def validate_ensemble_id_index(self) -> Self: - assert (self.idx_in_ensemble is None) == (self.ensemble_id is None), ( - "If ensemble_id is None, idx_in_ensemble must also be None" - ) - return self - @property def wandb_decomp_model(self) -> str: """Extract the WandB run ID of the source decomposition.""" diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml index 3729941ce..9345307d2 100644 --- a/spd/clustering/configs/crc/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -1,7 +1,6 @@ model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) -# idx_in_ensemble: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c514aa69f..ea9b9f904 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -23,7 +23,7 @@ def load_dataset( ) -> BatchTensor: """Load a single batch for clustering. - Each run gets its own dataset batch, seeded by idx_in_ensemble. 
+ Each run gets its own dataset batch, seeded by index in ensemble. Args: model_path: Path to decomposed model diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index 540312d8e..c54fe408b 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -6,7 +6,6 @@ import sqlite3 from contextlib import contextmanager -from spd.clustering.clustering_run_config import ClusteringEnsembleIndex from spd.settings import SPD_CACHE_DIR # SQLite database path @@ -40,9 +39,7 @@ def _get_connection(): conn.close() -def register_clustering_run( - pipeline_run_id: str, idx: ClusteringEnsembleIndex, clustering_run_id: str -) -> int: +def register_clustering_run(pipeline_run_id: str, clustering_run_id: str) -> int: """Register a clustering run as part of a pipeline ensemble. Args: @@ -57,16 +54,12 @@ def register_clustering_run( # Use BEGIN IMMEDIATE for thread-safe auto-increment conn.execute("BEGIN IMMEDIATE") - assigned_idx: int - if idx == -1: - # Auto-assign next available index - cursor = conn.execute( - "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", - (pipeline_run_id,), - ) - assigned_idx = cursor.fetchone()[0] - else: - assigned_idx = idx + # Auto-assign next available index, we rely on atomicity of the transaction here + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx: int = cursor.fetchone()[0] conn.execute( "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 6f8dfd722..54f0805c6 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -228,25 +228,23 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info(f"Clustering run ID: {clustering_run_id}") # Register with ensemble if 
this is part of a pipeline + assigned_idx: int | None if run_config.ensemble_id: - assert run_config.idx_in_ensemble is not None, ( - "idx_in_ensemble must be set when ensemble_id is provided! to auto-assign, set idx_in_ensemble = -1.\n" - f"{'!' * 50}\nNOTE: this should be an unreachable state -- such a case should have been caught by the pydantic validator.\n{'!' * 50}" + assigned_idx = register_clustering_run( + pipeline_run_id=run_config.ensemble_id, + clustering_run_id=clustering_run_id, ) - assigned_idx: int = register_clustering_run( - run_config.ensemble_id, - run_config.idx_in_ensemble, - clustering_run_id, - ) - - # Update config if index was auto-assigned - if run_config.idx_in_ensemble == -1: - run_config = replace_pydantic_model(run_config, {"idx_in_ensemble": assigned_idx}) - logger.info(f"Auto-assigned ensemble index: {assigned_idx}") logger.info( f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" ) + # IMPORTANT: set dataset seed based on assigned index + run_config = replace_pydantic_model( + run_config, + {"dataset_seed": run_config.dataset_seed + assigned_idx}, + ) + else: + assigned_idx = None # save config run_config.to_file(storage.config_path) @@ -292,7 +290,7 @@ def main(run_config: ClusteringRunConfig) -> Path: f"task:{task_name}", f"model:{run_config.wandb_decomp_model}", f"ensemble_id:{run_config.ensemble_id}", - f"idx:{run_config.idx_in_ensemble}", + f"assigned_idx:{assigned_idx}", ], ) # logger.info(f"WandB run: {wandb_run.url}") @@ -426,9 +424,6 @@ def cli() -> None: } # Handle ensemble-related overrides - if args.idx_in_ensemble is not None: - overrides["dataset_seed"] = run_config.dataset_seed + args.idx_in_ensemble - overrides["idx_in_ensemble"] = args.idx_in_ensemble if args.pipeline_run_id is not None: overrides["ensemble_id"] = args.pipeline_run_id diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py index bb2936cfd..c903af801 
100644 --- a/tests/clustering/test_ensemble_registry.py +++ b/tests/clustering/test_ensemble_registry.py @@ -24,58 +24,27 @@ def _temp_registry_db(monkeypatch: Any): # pyright: ignore[reportUnusedFunction class TestRegisterClusteringRun: """Test register_clustering_run() function.""" - def test_register_with_explicit_index(self, _temp_registry_db: Any): - """Test registering a run with an explicit index.""" + def test_register_single_run(self, _temp_registry_db: Any): + """Test registering a single run.""" pipeline_id = "pipeline_001" - idx = 0 run_id = "run_001" - assigned_idx = register_clustering_run(pipeline_id, idx, run_id) + assigned_idx = register_clustering_run(pipeline_id, run_id) - # Should return the same index - assert assigned_idx == idx - - # Verify in database - runs = get_clustering_runs(pipeline_id) - assert runs == [(0, "run_001")] - - def test_register_multiple_explicit_indices(self, _temp_registry_db: Any): - """Test registering multiple runs with explicit indices.""" - pipeline_id = "pipeline_002" - - idx0 = register_clustering_run(pipeline_id, 0, "run_001") - idx1 = register_clustering_run(pipeline_id, 1, "run_002") - idx2 = register_clustering_run(pipeline_id, 2, "run_003") - - assert idx0 == 0 - assert idx1 == 1 - assert idx2 == 2 - - # Verify order in database - runs = get_clustering_runs(pipeline_id) - assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] - - def test_auto_assign_single_index(self, _temp_registry_db: Any): - """Test auto-assigning a single index.""" - pipeline_id = "pipeline_003" - run_id = "run_001" - - assigned_idx = register_clustering_run(pipeline_id, -1, run_id) - - # First auto-assigned index should be 0 + # First index should be 0 assert assigned_idx == 0 # Verify in database runs = get_clustering_runs(pipeline_id) assert runs == [(0, "run_001")] - def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): - """Test auto-assigning multiple indices sequentially.""" - pipeline_id = "pipeline_004" + 
def test_register_multiple_runs(self, _temp_registry_db: Any): + """Test registering multiple runs sequentially.""" + pipeline_id = "pipeline_002" - idx0 = register_clustering_run(pipeline_id, -1, "run_001") - idx1 = register_clustering_run(pipeline_id, -1, "run_002") - idx2 = register_clustering_run(pipeline_id, -1, "run_003") + idx0 = register_clustering_run(pipeline_id, "run_001") + idx1 = register_clustering_run(pipeline_id, "run_002") + idx2 = register_clustering_run(pipeline_id, "run_003") # Should auto-assign 0, 1, 2 assert idx0 == 0 @@ -86,95 +55,21 @@ def test_auto_assign_multiple_indices(self, _temp_registry_db: Any): runs = get_clustering_runs(pipeline_id) assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] - def test_auto_assign_after_explicit_indices(self, _temp_registry_db: Any): - """Test that auto-assignment continues from max existing index.""" - pipeline_id = "pipeline_005" - - # Register explicit indices - register_clustering_run(pipeline_id, 0, "run_001") - register_clustering_run(pipeline_id, 1, "run_002") - - # Auto-assign should get index 2 - idx = register_clustering_run(pipeline_id, -1, "run_003") - assert idx == 2 - - # Auto-assign again should get index 3 - idx = register_clustering_run(pipeline_id, -1, "run_004") - assert idx == 3 - - # Verify in database - runs = get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (1, "run_002"), - (2, "run_003"), - (3, "run_004"), - ] - - def test_auto_assign_with_gaps(self, _temp_registry_db: Any): - """Test that auto-assignment uses max+1, even with gaps.""" - pipeline_id = "pipeline_006" - - # Register with gaps: 0, 5, 10 - register_clustering_run(pipeline_id, 0, "run_001") - register_clustering_run(pipeline_id, 5, "run_002") - register_clustering_run(pipeline_id, 10, "run_003") - - # Auto-assign should get index 11 (max + 1) - idx = register_clustering_run(pipeline_id, -1, "run_004") - assert idx == 11 - - # Verify in database (ordered by idx) - runs = 
get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (5, "run_002"), - (10, "run_003"), - (11, "run_004"), - ] - - def test_mixed_explicit_and_auto_assign(self, _temp_registry_db: Any): - """Test mixing explicit and auto-assigned indices.""" - pipeline_id = "pipeline_007" - - # Mix of explicit and auto-assigned - idx0 = register_clustering_run(pipeline_id, -1, "run_001") # auto: 0 - idx1 = register_clustering_run(pipeline_id, 5, "run_002") # explicit: 5 - idx2 = register_clustering_run(pipeline_id, -1, "run_003") # auto: 6 - idx3 = register_clustering_run(pipeline_id, 2, "run_004") # explicit: 2 - idx4 = register_clustering_run(pipeline_id, -1, "run_005") # auto: 7 - - assert idx0 == 0 - assert idx1 == 5 - assert idx2 == 6 - assert idx3 == 2 - assert idx4 == 7 - - # Verify in database (ordered by idx) - runs = get_clustering_runs(pipeline_id) - assert runs == [ - (0, "run_001"), - (2, "run_004"), - (5, "run_002"), - (6, "run_003"), - (7, "run_005"), - ] - def test_different_pipelines_independent(self, _temp_registry_db: Any): """Test that different pipelines have independent index sequences.""" pipeline_a = "pipeline_a" pipeline_b = "pipeline_b" # Both should start at 0 when auto-assigning - idx_a0 = register_clustering_run(pipeline_a, -1, "run_a1") - idx_b0 = register_clustering_run(pipeline_b, -1, "run_b1") + idx_a0 = register_clustering_run(pipeline_a, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, "run_b1") assert idx_a0 == 0 assert idx_b0 == 0 # Both should increment independently - idx_a1 = register_clustering_run(pipeline_a, -1, "run_a2") - idx_b1 = register_clustering_run(pipeline_b, -1, "run_b2") + idx_a1 = register_clustering_run(pipeline_a, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, "run_b2") assert idx_a1 == 1 assert idx_b1 == 1 @@ -199,17 +94,17 @@ def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): """Test that runs are returned sorted by index.""" pipeline_id = "pipeline_sort" - # Register out 
of order - register_clustering_run(pipeline_id, 5, "run_005") - register_clustering_run(pipeline_id, 1, "run_001") - register_clustering_run(pipeline_id, 3, "run_003") - register_clustering_run(pipeline_id, 0, "run_000") + # Register runs (indices will be auto-assigned in order) + register_clustering_run(pipeline_id, "run_000") + register_clustering_run(pipeline_id, "run_001") + register_clustering_run(pipeline_id, "run_002") + register_clustering_run(pipeline_id, "run_003") # Should be returned in sorted order runs = get_clustering_runs(pipeline_id) assert runs == [ (0, "run_000"), (1, "run_001"), + (2, "run_002"), (3, "run_003"), - (5, "run_005"), ] From cf64a7972127fe59c8360ec720599c842b0b3e11 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 17:56:27 +0100 Subject: [PATCH 60/77] only allow passing clustering run config path, not inline see discussion at https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 have tried to make this change as isolated as possible -- i think this was a useful feature and we may want to add it back at some point --- .../configs/pipeline-dev-simplestories.yaml | 20 +-- spd/clustering/scripts/run_pipeline.py | 84 ++------- tests/clustering/test_pipeline_config.py | 168 ++---------------- 3 files changed, 21 insertions(+), 251 deletions(-) diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index dfee51d64..6d181424a 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,22 +6,4 @@ slurm_partition: null wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" create_git_snapshot: false -# run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" -run_clustering_config: - model_path: "wandb:goodfire/spd/runs/lxs77xye" - batch_size: 16 - wandb_project: "spd-cluster" - logging_intervals: - stat: 5 - tensor: 100 - 
plot: 10000 - artifact: 10000 - merge_config: - activation_threshold: 0.1 - alpha: 1.0 - iters: null - merge_pair_sampling_method: "range" - merge_pair_sampling_kwargs: - threshold: 0.001 - filter_dead_threshold: 0.1 - module_name_filter: null \ No newline at end of file +run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 5396cb640..614d7ac17 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -71,13 +71,8 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path | None = Field( - default=None, - description="Path to ClusteringRunConfig file. Mutually exclusive with run_clustering_config.", - ) - run_clustering_config: ClusteringRunConfig | None = Field( - default=None, - description="Inline ClusteringRunConfig. 
Mutually exclusive with run_clustering_config_path.", + run_clustering_config_path: Path = Field( + description="Path to ClusteringRunConfig file.", ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") distances_methods: list[DistancesMethod] = Field( @@ -101,29 +96,13 @@ class ClusteringPipelineConfig(BaseConfig): ) @model_validator(mode="after") - def validate_crc_fields(self) -> "ClusteringPipelineConfig": - """Validate that exactly one of run_clustering_config_path or run_clustering_config is provided.""" - has_path: bool = self.run_clustering_config_path is not None - has_inline: bool = self.run_clustering_config is not None - - if not has_path and not has_inline: - raise ValueError( - "Must specify exactly one of 'run_clustering_config_path' or 'run_clustering_config'" - ) - - if has_path: - if has_inline: - raise ValueError( - "Cannot specify both 'run_clustering_config_path' and 'run_clustering_config'. " - "Use only one." - ) - else: - # Ensure the path exists - # pyright ignore because it doesn't recognize that has_path implies not None - if not self.run_clustering_config_path.exists(): # pyright: ignore[reportOptionalMemberAccess] - raise ValueError( - f"run_clustering_config_path does not exist: {self.run_clustering_config_path = }" - ) + def validate_crc(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of run_clustering_config_path points to a valid `ClusteringRunConfig`.""" + assert self.run_clustering_config_path.exists(), ( + f"run_clustering_config_path does not exist: {self.run_clustering_config_path}" + ) + # Try to load ClusteringRunConfig + assert ClusteringRunConfig.from_file(self.run_clustering_config_path) return self @@ -137,49 +116,6 @@ def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesM return v - def get_config_path(self) -> Path: - """Get the path to the ClusteringRunConfig file. - - - If run_clustering_config_path is provided, returns it directly. 
- - If run_clustering_config is provided, caches it to a deterministic path - based on its content hash and returns that path. - - if the config file already exists in the cache, assert that it is identical. - - Returns: - Path to the (potentially newly created) ClusteringRunConfig file - """ - if self.run_clustering_config_path is not None: - assert self.run_clustering_config_path.exists(), ( - f"no file at run_clustering_config_path: {self.run_clustering_config_path = }" - ) - return self.run_clustering_config_path - - assert self.run_clustering_config is not None, ( - "Either run_clustering_config_path or run_clustering_config must be set" - ) - - # Generate deterministic hash from config - hash_b64: str = self.run_clustering_config.stable_hash_b64() - - # Create cache directory - cache_dir: Path = SPD_CACHE_DIR / "clustering_run_configs" - cache_dir.mkdir(parents=True, exist_ok=True) - - # Write config to cache if it doesn't exist - config_path: Path = cache_dir / f"{hash_b64}.json" - if not config_path.exists(): - self.run_clustering_config.to_file(config_path) - logger.info(f"Cached inline config to {config_path}") - else: - # Verify that existing file matches - existing_config = ClusteringRunConfig.from_file(config_path) - if existing_config != self.run_clustering_config: - raise ValueError( - f"Hash collision detected for config hash {hash_b64} at {config_path}\n{existing_config=}\n{self.run_clustering_config=}" - ) - - return config_path - def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: """Create WandB workspace view for clustering runs. 
@@ -234,7 +170,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.get_config_path().as_posix(), + pipeline_config.run_clustering_config_path.as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 8d527bd6c..264078392 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -2,12 +2,13 @@ from pathlib import Path +import pydantic_core import pytest from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.merge_config import MergeConfig from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig -from spd.settings import REPO_ROOT, SPD_CACHE_DIR +from spd.settings import REPO_ROOT class TestClusteringRunConfigStableHash: @@ -69,10 +70,11 @@ def test_stable_hash_b64(self): class TestClusteringPipelineConfigValidation: """Test ClusteringPipelineConfig validation logic.""" - def test_error_when_neither_field_provided(self): - """Test that error is raised when neither path nor inline config is provided.""" - with pytest.raises(ValueError, match="Must specify exactly one"): + def test_error_when_path_does_not_exist(self): + """Test that error is raised when run_clustering_config_path does not exist.""" + with pytest.raises(pydantic_core._pydantic_core.ValidationError): ClusteringPipelineConfig( + run_clustering_config_path=Path("nonexistent/path.json"), n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -82,32 +84,8 @@ def test_error_when_neither_field_provided(self): create_git_snapshot=False, ) - def test_error_when_both_fields_provided(self): - """Test that error is raised when both path and inline config are provided.""" - with pytest.raises(ValueError, match="Cannot specify both"): - ClusteringPipelineConfig( - 
run_clustering_config_path=Path("some/path.json"), - run_clustering_config=ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - merge_config=MergeConfig(), - dataset_seed=0, - ), - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - slurm_job_name_prefix=None, - slurm_partition=None, - wandb_entity="test", - create_git_snapshot=False, - ) - - -class TestClusteringPipelineConfigGetConfigPath: - """Test ClusteringPipelineConfig.get_config_path() method.""" - - def test_returns_path_directly_when_using_path_field(self): - """Test that get_config_path returns the path directly when using run_clustering_config_path.""" + def test_valid_config_with_existing_path(self): + """Test that config is valid when path points to existing file.""" expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( @@ -119,133 +97,7 @@ def test_returns_path_directly_when_using_path_field(self): create_git_snapshot=False, ) - assert config.get_config_path() == expected_path - - def test_creates_cached_file_when_using_inline_config(self): - """Test that get_config_path creates a cached file when using inline config.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - config_path = config.get_config_path() - - # Check that file exists - assert config_path.exists() - - # Check that it's in the expected directory - expected_cache_dir = SPD_CACHE_DIR / "clustering_run_configs" - assert config_path.parent == expected_cache_dir - - # Check that filename is the hash - expected_hash = inline_config.stable_hash_b64() - assert config_path.name == 
f"{expected_hash}.json" - - # Check that file contents match the config - loaded_config = ClusteringRunConfig.from_file(config_path) - assert loaded_config == inline_config - - # Clean up - config_path.unlink() - - def test_reuses_existing_cached_file(self): - """Test that get_config_path reuses existing cached file with same hash.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - config1 = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - # First call creates the file - config_path1 = config1.get_config_path() - assert config_path1.exists() - - # Record modification time - mtime1 = config_path1.stat().st_mtime - - # Create another config with same inline config - config2 = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=3, # Different n_runs shouldn't matter - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - wandb_entity="test", - create_git_snapshot=False, - ) - - # Second call should reuse the file - config_path2 = config2.get_config_path() - - assert config_path1 == config_path2 - assert config_path2.stat().st_mtime == mtime1 # File not modified - - # Clean up - config_path1.unlink() - - def test_hash_collision_detection(self): - """Test that hash collision is detected when existing file differs.""" - inline_config = ClusteringRunConfig( - model_path="wandb:test/project/run1", - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - - # Create a fake collision by manually creating a file with same hash - hash_value = inline_config.stable_hash_b64() - cache_dir = SPD_CACHE_DIR / "clustering_run_configs" - cache_dir.mkdir(parents=True, exist_ok=True) - collision_path = cache_dir / f"{hash_value}.json" - - # Write a 
different config to the file - different_config = ClusteringRunConfig( - model_path="wandb:test/project/run2", # Different! - batch_size=32, - dataset_seed=0, - merge_config=MergeConfig(), - ) - different_config.to_file(collision_path) - - try: - config = ClusteringPipelineConfig( - run_clustering_config=inline_config, - n_runs=2, - distances_methods=["perm_invariant_hamming"], - base_output_dir=Path("/tmp/test"), - slurm_job_name_prefix=None, - slurm_partition=None, - wandb_entity="test", - create_git_snapshot=False, - ) - - # Should raise ValueError about hash collision - with pytest.raises(ValueError, match="Hash collision detected"): - config.get_config_path() - finally: - # Clean up - if collision_path.exists(): - collision_path.unlink() + assert config.run_clustering_config_path == expected_path def _get_config_files(path: Path): @@ -269,7 +121,7 @@ def test_config_validate_pipeline(self, config_file: Path): """Test that each pipeline config file is valid.""" print(config_file) _config = ClusteringPipelineConfig.from_file(config_file) - crc_path = _config.get_config_path() + crc_path = _config.run_clustering_config_path print(f"{crc_path = }") assert crc_path.exists() From 2a9f731b6e9453f1d439f23a862cd69085a02191 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 17:58:46 +0100 Subject: [PATCH 61/77] rename run_clustering_config_path -> clustering_run_config_path old name didnt make sense, since it should be a path to a file with a `ClusteringRunConfig` see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 --- spd/clustering/configs/README.md | 2 +- .../configs/pipeline-dev-simplestories.yaml | 2 +- spd/clustering/configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/scripts/run_pipeline.py | 12 ++++++------ tests/clustering/test_pipeline_config.py | 10 +++++----- 7 files changed, 16 insertions(+), 16 deletions(-) 
diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md index 51db8e8a0..e1ac41f47 100644 --- a/spd/clustering/configs/README.md +++ b/spd/clustering/configs/README.md @@ -1 +1 @@ -this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `run_clustering_config_path` field in the pipeline configs. \ No newline at end of file +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `clustering_run_config_path` field in the pipeline configs. \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 6d181424a..1868b5887 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -6,4 +6,4 @@ slurm_partition: null wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" create_git_snapshot: false -run_clustering_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file +clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index db72fa3c0..37833c82c 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml 
b/spd/clustering/configs/pipeline-test-simplestories.yaml index 24e686023..9872062d2 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/test-simplestories.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 297b47d7b..3a533885d 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/crc/example.yaml" +clustering_run_config_path: "spd/clustering/configs/crc/example.yaml" n_runs: 2 distances_methods: ["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index 614d7ac17..179bc8bca 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -71,7 +71,7 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path = Field( + clustering_run_config_path: Path = Field( description="Path to ClusteringRunConfig file.", ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") @@ -97,12 +97,12 @@ class ClusteringPipelineConfig(BaseConfig): @model_validator(mode="after") def validate_crc(self) -> "ClusteringPipelineConfig": - """Validate that exactly one of run_clustering_config_path points to a valid `ClusteringRunConfig`.""" - assert self.run_clustering_config_path.exists(), ( - f"run_clustering_config_path does not exist: 
{self.run_clustering_config_path}" + """Validate that exactly one of clustering_run_config_path points to a valid `ClusteringRunConfig`.""" + assert self.clustering_run_config_path.exists(), ( + f"clustering_run_config_path does not exist: {self.clustering_run_config_path}" ) # Try to load ClusteringRunConfig - assert ClusteringRunConfig.from_file(self.run_clustering_config_path) + assert ClusteringRunConfig.from_file(self.clustering_run_config_path) return self @@ -170,7 +170,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.run_clustering_config_path.as_posix(), + pipeline_config.clustering_run_config_path.as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py index 264078392..ca6bad6ee 100644 --- a/tests/clustering/test_pipeline_config.py +++ b/tests/clustering/test_pipeline_config.py @@ -71,10 +71,10 @@ class TestClusteringPipelineConfigValidation: """Test ClusteringPipelineConfig validation logic.""" def test_error_when_path_does_not_exist(self): - """Test that error is raised when run_clustering_config_path does not exist.""" + """Test that error is raised when clustering_run_config_path does not exist.""" with pytest.raises(pydantic_core._pydantic_core.ValidationError): ClusteringPipelineConfig( - run_clustering_config_path=Path("nonexistent/path.json"), + clustering_run_config_path=Path("nonexistent/path.json"), n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -89,7 +89,7 @@ def test_valid_config_with_existing_path(self): expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") config = ClusteringPipelineConfig( - run_clustering_config_path=expected_path, + clustering_run_config_path=expected_path, n_runs=2, distances_methods=["perm_invariant_hamming"], base_output_dir=Path("/tmp/test"), @@ -97,7 +97,7 @@ def 
test_valid_config_with_existing_path(self): create_git_snapshot=False, ) - assert config.run_clustering_config_path == expected_path + assert config.clustering_run_config_path == expected_path def _get_config_files(path: Path): @@ -121,7 +121,7 @@ def test_config_validate_pipeline(self, config_file: Path): """Test that each pipeline config file is valid.""" print(config_file) _config = ClusteringPipelineConfig.from_file(config_file) - crc_path = _config.run_clustering_config_path + crc_path = _config.clustering_run_config_path print(f"{crc_path = }") assert crc_path.exists() From 26c65209d998d670610338e7ac8bbae7191d64a4 Mon Sep 17 00:00:00 2001 From: mivanit Date: Fri, 24 Oct 2025 10:11:30 -0700 Subject: [PATCH 62/77] [clustering] config refactor (#227) ## Core changes - `ClusteringRunConfig.idx_in_ensemble` is removed entirely. index is now auto assigned - ~~pipeline config can contain merge run config either inline or as path~~ - ~~quite annoying when doing experiments to have to go and edit two files~~ - no longer a feature, but path will be validated on config load. 
see discussion in https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 - remove component popping - change brought over from #206 via commit [a1f1146](https://github.com/goodfire-ai/spd/pull/227/commits/a1f1146d0480b4ee08cfc2a7070be6170b9394d1) - clustering runs will now actually use the run id for the wandb run id - spd decomp runs don't yet ## Housekeeping - separate folders for pipeline configs and merge run configs - clearer to the user, and enables testing configs without trying to infer what type they are - more validators on the configs - some config tests, importantly validating all configs in the `spd/clustering/configs/` dir - often was previously the case that config schema or var names would change, old configs would go out of date, and this would not be immediately obvious - tests for ensemble registry ## Key questions: - [x] is this way of handling "incline config" vs "path to config" fine, or should I do it the more pydantic-y way? - only path to config, see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 - [x] we can in principle now get rid of `idx_in_ensemble` entirely, since the database read/write should handle uniqueness of indexes. I have kept it in for now, but happy to get rid of it - got rid of it, see [40df505](https://github.com/goodfire-ai/spd/pull/227/commits/40df505c543bda9d5e21938728159a217640fc78) - ~~when a config already exists at the expected path, should we compare strings or the loaded object? the former is probably faster since pydantic validations add overhead. 
also, hash collisions are vanishingly unlikely here~~ N/A, removed inline config functionality Commits: * allow specifying either config path or mrc cfg in pipeline cfg * [wip] reorg configs * added default `None` for slurm partition and job name prefix * refactor configs, add config tests * fix tests * allow `None` or `-1` idx_in_ensemble - idx_in_ensemble is None iff ensemble_id is None - idx_in_ensemble == -1 will make register_clustering_run() auto-assign next avalible index - added tests for ensemble registry * whoops, wrong name on fixture * fix idx passed in tests when not needed * rename "mrc" -> "crc" in paths I forgot its no longer called "MergeRunConfig" * rename merge_run_config.py -> clustering_run_config.py * fix pyright * fix idx_in_ensemble being passed in tests * rename cache dir 'merge_run_configs' -> 'clustering_run_configs' * remove component popping changes brought in from PR https://github.com/goodfire-ai/spd/pull/206 branch clustering/refactor-multi-batch commit [9cbb52f](https://github.com/goodfire-ai/spd/pull/206/commits/9cbb52fd09cac8d79481a16de0a9e4c517960a33) * dont pass batch size, change not brought in here * fix history_path extension and storage usage * dev pipeline * better config validation tests * set default base output dir * wandb use run id for clustering, TODO for spd decomp * basedpyright 1.32.0 causes issues, esp w/ wandb https://github.com/goodfire-ai/spd/actions/runs/18719611602/job/53388090437 * remove idx_in_ensemble, always auto-assigned now see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454317036 * only allow passing clustering run config path, not inline see discussion at https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 have tried to make this change as isolated as possible -- i think this was a useful feature and we may want to add it back at some point * rename run_clustering_config_path -> clustering_run_config_path old name didnt make sense, since it should be a path to a file with 
a `ClusteringRunConfig` see https://github.com/goodfire-ai/spd/pull/227#discussion_r2454299922 --- ...run_config.py => clustering_run_config.py} | 45 +++--- spd/clustering/compute_costs.py | 110 +------------- spd/clustering/configs/README.md | 1 + spd/clustering/configs/{ => crc}/example.yaml | 2 - .../configs/{ => crc}/resid_mlp1.json | 6 +- .../configs/{ => crc}/resid_mlp2.json | 5 +- .../configs/{ => crc}/simplestories_dev.json | 3 +- .../configs/{ => crc}/test-resid_mlp1.json | 1 - .../configs/{ => crc}/test-simplestories.json | 1 - .../configs/pipeline-dev-simplestories.yaml | 12 +- .../configs/pipeline-test-resid_mlp1.yaml | 2 +- .../configs/pipeline-test-simplestories.yaml | 2 +- spd/clustering/configs/pipeline_config.yaml | 2 +- spd/clustering/configs/resid_mlp3.json | 23 --- spd/clustering/dataset.py | 2 +- spd/clustering/ensemble_registry.py | 21 ++- spd/clustering/merge.py | 51 +------ spd/clustering/merge_config.py | 5 - spd/clustering/scripts/calc_distances.py | 13 +- spd/clustering/scripts/run_clustering.py | 41 +++--- spd/clustering/scripts/run_pipeline.py | 38 ++++- spd/utils/wandb_utils.py | 1 + tests/clustering/scripts/cluster_resid_mlp.py | 1 - tests/clustering/scripts/cluster_ss.py | 4 +- tests/clustering/test_calc_distances.py | 1 - tests/clustering/test_ensemble_registry.py | 110 ++++++++++++++ tests/clustering/test_merge_config.py | 2 - tests/clustering/test_merge_integration.py | 36 ----- tests/clustering/test_pipeline_config.py | 137 ++++++++++++++++++ .../test_run_clustering_happy_path.py | 4 +- 30 files changed, 376 insertions(+), 306 deletions(-) rename spd/clustering/{merge_run_config.py => clustering_run_config.py} (80%) create mode 100644 spd/clustering/configs/README.md rename spd/clustering/configs/{ => crc}/example.yaml (87%) rename spd/clustering/configs/{ => crc}/resid_mlp1.json (74%) rename spd/clustering/configs/{ => crc}/resid_mlp2.json (82%) rename spd/clustering/configs/{ => crc}/simplestories_dev.json (82%) rename 
spd/clustering/configs/{ => crc}/test-resid_mlp1.json (93%) rename spd/clustering/configs/{ => crc}/test-simplestories.json (94%) delete mode 100644 spd/clustering/configs/resid_mlp3.json create mode 100644 tests/clustering/test_ensemble_registry.py create mode 100644 tests/clustering/test_pipeline_config.py diff --git a/spd/clustering/merge_run_config.py b/spd/clustering/clustering_run_config.py similarity index 80% rename from spd/clustering/merge_run_config.py rename to spd/clustering/clustering_run_config.py index 60a5244d6..95d72f9bd 100644 --- a/spd/clustering/merge_run_config.py +++ b/spd/clustering/clustering_run_config.py @@ -1,9 +1,12 @@ """ClusteringRunConfig""" +import base64 +import hashlib +import json from pathlib import Path -from typing import Any, Self +from typing import Any -from pydantic import Field, PositiveInt, model_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig from spd.clustering.merge_config import MergeConfig @@ -51,8 +54,6 @@ class ClusteringRunConfig(BaseConfig): default=None, description="Ensemble identifier for WandB grouping", ) - idx_in_ensemble: int = Field(0, description="Index of this run in the ensemble") - merge_config: MergeConfig = Field(description="Merge algorithm configuration") logging_intervals: LoggingIntervals = Field( default_factory=LoggingIntervals, @@ -69,16 +70,6 @@ class ClusteringRunConfig(BaseConfig): description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", ) - # TODO: no way to check this without knowing task - # @model_validator(mode="after") - # def validate_streaming_compatibility(self) -> Self: - # """Ensure dataset_streaming is only enabled for compatible tasks.""" - # if self.dataset_streaming and self.task_name != "lm": - # raise ValueError( - # f"Streaming dataset loading only supported for 'lm' task, got '{self.task_name}'" - # ) - # return self - @model_validator(mode="before") def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: experiment_key: str | None = values.get("experiment_key") @@ -100,12 +91,12 @@ def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: return values - @model_validator(mode="after") - def validate_model_path(self) -> Self: + @field_validator("model_path") + def validate_model_path(cls, v: str) -> str: """Validate that model_path is a proper WandB path.""" - if not self.model_path.startswith("wandb:"): - raise ValueError(f"model_path must start with 'wandb:', got: {self.model_path}") - return self + if not v.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {v}") + return v @property def wandb_decomp_model(self) -> str: @@ -127,3 +118,19 @@ def model_dump_with_properties(self) -> dict[str, Any]: ) return base_dump + + def stable_hash_b64(self) -> str: + """Generate a stable, deterministic base64-encoded hash of this config. + + Uses SHA256 hash of the JSON representation with sorted keys for determinism. + Returns URL-safe base64 encoding without padding. 
+ + Returns: + URL-safe base64-encoded hash (without padding) + """ + config_dict: dict[str, Any] = self.model_dump(mode="json") + config_json: str = json.dumps(config_dict, indent=2, sort_keys=True) + hash_digest: bytes = hashlib.sha256(config_json.encode()).digest() + # Use base64 URL-safe encoding and strip padding for filesystem safety + hash_b64: str = base64.urlsafe_b64encode(hash_digest).decode().rstrip("=") + return hash_b64 diff --git a/spd/clustering/compute_costs.py b/spd/clustering/compute_costs.py index ba1ff274c..f1b3425d1 100644 --- a/spd/clustering/compute_costs.py +++ b/spd/clustering/compute_costs.py @@ -1,7 +1,7 @@ import math import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from spd.clustering.consts import ClusterCoactivationShaped, MergePair @@ -187,111 +187,3 @@ def recompute_coacts_merge_pair( coact_new, activation_mask_new, ) - - -def recompute_coacts_pop_group( - coact: ClusterCoactivationShaped, - merges: GroupMerge, - component_idx: int, - activation_mask: Bool[Tensor, "n_samples k_groups"], - activation_mask_orig: Bool[Tensor, "n_samples n_components"], -) -> tuple[ - GroupMerge, - Float[Tensor, "k_groups+1 k_groups+1"], - Bool[Tensor, "n_samples k_groups+1"], -]: - # sanity check dims - # ================================================== - - k_groups: int = coact.shape[0] - n_samples: int = activation_mask.shape[0] - k_groups_new: int = k_groups + 1 - assert coact.shape[1] == k_groups, "Coactivation matrix must be square" - assert activation_mask.shape[1] == k_groups, ( - "Activation mask must match coactivation matrix shape" - ) - assert n_samples == activation_mask_orig.shape[0], ( - "Activation mask original must match number of samples" - ) - - # get the activations we need - # ================================================== - # which group does the component belong to? 
- group_idx: int = int(merges.group_idxs[component_idx].item()) - group_size_old: int = int(merges.components_per_group[group_idx].item()) - group_size_new: int = group_size_old - 1 - - # activations of component we are popping out - acts_pop: Bool[Tensor, " samples"] = activation_mask_orig[:, component_idx] - - # activations of the "remainder" -- everything other than the component we are popping out, - # in the group we're popping it out of - acts_remainder: Bool[Tensor, " samples"] = ( - activation_mask_orig[ - :, [i for i in merges.components_in_group(group_idx) if i != component_idx] - ] - .max(dim=-1) - .values - ) - - # assemble the new activation mask - # ================================================== - # first concat the popped-out component onto the end - activation_mask_new: Bool[Tensor, " samples k_groups+1"] = torch.cat( - [activation_mask, acts_pop.unsqueeze(1)], - dim=1, - ) - # then replace the group we are popping out of with the remainder - activation_mask_new[:, group_idx] = acts_remainder - - # assemble the new coactivation matrix - # ================================================== - coact_new: Float[Tensor, "k_groups+1 k_groups+1"] = torch.full( - (k_groups_new, k_groups_new), - fill_value=float("nan"), - dtype=coact.dtype, - device=coact.device, - ) - # copy in the old coactivation matrix - coact_new[:k_groups, :k_groups] = coact.clone() - # compute new coactivations we need - coact_pop: Float[Tensor, " k_groups"] = acts_pop.float() @ activation_mask_new.float() - coact_remainder: Float[Tensor, " k_groups"] = ( - acts_remainder.float() @ activation_mask_new.float() - ) - - # replace the relevant rows and columns - coact_new[group_idx, :] = coact_remainder - coact_new[:, group_idx] = coact_remainder - coact_new[-1, :] = coact_pop - coact_new[:, -1] = coact_pop - - # assemble the new group merge - # ================================================== - group_idxs_new: Int[Tensor, " k_groups+1"] = merges.group_idxs.clone() - # the 
popped-out component is now its own group - new_group_idx: int = k_groups_new - 1 - group_idxs_new[component_idx] = new_group_idx - merge_new: GroupMerge = GroupMerge( - group_idxs=group_idxs_new, - k_groups=k_groups_new, - ) - - # sanity check - assert merge_new.components_per_group.shape == (k_groups_new,), ( - "New merge must have k_groups+1 components" - ) - assert merge_new.components_per_group[new_group_idx] == 1, ( - "New group must have exactly one component" - ) - assert merge_new.components_per_group[group_idx] == group_size_new, ( - "Old group must have one less component" - ) - - # return - # ================================================== - return ( - merge_new, - coact_new, - activation_mask_new, - ) diff --git a/spd/clustering/configs/README.md b/spd/clustering/configs/README.md new file mode 100644 index 000000000..e1ac41f47 --- /dev/null +++ b/spd/clustering/configs/README.md @@ -0,0 +1 @@ +this folder `configs/` contains files with `ClusteringPipelineConfig`s. the folder `configs/crc/` contains files with `ClusteringRunConfig`s, which may be referenced by the `clustering_run_config_path` field in the pipeline configs. 
\ No newline at end of file diff --git a/spd/clustering/configs/example.yaml b/spd/clustering/configs/crc/example.yaml similarity index 87% rename from spd/clustering/configs/example.yaml rename to spd/clustering/configs/crc/example.yaml index efa36d693..9345307d2 100644 --- a/spd/clustering/configs/example.yaml +++ b/spd/clustering/configs/crc/example.yaml @@ -1,7 +1,6 @@ model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) -# idx_in_ensemble: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) # ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) @@ -12,7 +11,6 @@ merge_config: merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' merge_pair_sampling_kwargs: threshold: 0.05 # For range sampler: fraction of the range of costs to sample from - pop_component_prob: 0 # Probability of popping a component. i recommend 0 if you're doing an ensemble anyway filter_dead_threshold: 0.001 # Threshold for filtering dead components module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules diff --git a/spd/clustering/configs/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json similarity index 74% rename from spd/clustering/configs/resid_mlp1.json rename to spd/clustering/configs/crc/resid_mlp1.json index a7d118ac7..1e13ce23e 100644 --- a/spd/clustering/configs/resid_mlp1.json +++ b/spd/clustering/configs/crc/resid_mlp1.json @@ -5,17 +5,13 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0, "module_name_filter": null }, "experiment_key": "resid_mlp1", - "distances_methods": ["perm_invariant_hamming"], - "n_batches": 8, "batch_size": 128, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json similarity index 82% rename from spd/clustering/configs/resid_mlp2.json rename to spd/clustering/configs/crc/resid_mlp2.json index 2be350979..edc4849e2 100644 --- a/spd/clustering/configs/resid_mlp2.json +++ b/spd/clustering/configs/crc/resid_mlp2.json @@ -5,16 +5,13 @@ "iters": 100, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.01, "module_name_filter": null }, "experiment_key": "resid_mlp2", - "n_batches": 16, "batch_size": 1024, - "wandb_enabled": true, "wandb_project": "spd-cluster", - "intervals": { + "logging_intervals": { "stat": 1, "tensor": 5, "plot": 5, diff --git a/spd/clustering/configs/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json similarity index 82% rename from spd/clustering/configs/simplestories_dev.json rename to spd/clustering/configs/crc/simplestories_dev.json index f585e848f..e1647b6e4 100644 --- a/spd/clustering/configs/simplestories_dev.json +++ 
b/spd/clustering/configs/crc/simplestories_dev.json @@ -4,8 +4,7 @@ "alpha": 1.0, "iters": 100, "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.01}, - "pop_component_prob": 0, + "merge_pair_sampling_kwargs": {"threshold": 0.001}, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json similarity index 93% rename from spd/clustering/configs/test-resid_mlp1.json rename to spd/clustering/configs/crc/test-resid_mlp1.json index 01b510200..4b3a26ff8 100644 --- a/spd/clustering/configs/test-resid_mlp1.json +++ b/spd/clustering/configs/crc/test-resid_mlp1.json @@ -5,7 +5,6 @@ "iters": 16, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.1, "module_name_filter": null }, diff --git a/spd/clustering/configs/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json similarity index 94% rename from spd/clustering/configs/test-simplestories.json rename to spd/clustering/configs/crc/test-simplestories.json index 147634edb..911f71529 100644 --- a/spd/clustering/configs/test-simplestories.json +++ b/spd/clustering/configs/crc/test-simplestories.json @@ -5,7 +5,6 @@ "iters": 5, "merge_pair_sampling_method": "range", "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, "filter_dead_threshold": 0.9, "module_name_filter": "model.layers.0" }, diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml index 6909c5841..1868b5887 100644 --- a/spd/clustering/configs/pipeline-dev-simplestories.yaml +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -1,9 +1,9 @@ -run_clustering_config_path: "spd/clustering/configs/simplestories_dev.json" -n_runs: 4 -distances_methods: ["matching_dist", "matching_dist_vec", "perm_invariant_hamming"] 
-base_output_dir: "tests/.temp/clustering" +n_runs: 2 +distances_methods: ["matching_dist"] +# base_output_dir: "tests/.temp/clustering" slurm_job_name_prefix: null slurm_partition: null -wandb_project: null # wandb fails in CI +wandb_project: "spd-cluster" # wandb fails in CI wandb_entity: "goodfire" -create_git_snapshot: false \ No newline at end of file +create_git_snapshot: false +clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml index a413a5438..37833c82c 100644 --- a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-resid_mlp1.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" n_runs: 3 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml index e406628c4..9872062d2 100644 --- a/spd/clustering/configs/pipeline-test-simplestories.yaml +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/test-simplestories.json" +clustering_run_config_path: "spd/clustering/configs/crc/test-simplestories.json" n_runs: 2 distances_methods: ["matching_dist"] base_output_dir: "tests/.temp/clustering" diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml index 6a40c9b29..3a533885d 100644 --- a/spd/clustering/configs/pipeline_config.yaml +++ b/spd/clustering/configs/pipeline_config.yaml @@ -1,4 +1,4 @@ -run_clustering_config_path: "spd/clustering/configs/example.yaml" +clustering_run_config_path: "spd/clustering/configs/crc/example.yaml" n_runs: 2 distances_methods: 
["perm_invariant_hamming"] base_output_dir: "/mnt/polished-lake/spd/clustering" diff --git a/spd/clustering/configs/resid_mlp3.json b/spd/clustering/configs/resid_mlp3.json deleted file mode 100644 index 5d87e08d5..000000000 --- a/spd/clustering/configs/resid_mlp3.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "merge_config": { - "activation_threshold": 0.01, - "alpha": 1, - "iters": 350, - "merge_pair_sampling_method": "range", - "merge_pair_sampling_kwargs": {"threshold": 0.05}, - "pop_component_prob": 0, - "filter_dead_threshold": 0.01, - "module_name_filter": null - }, - "experiment_key": "resid_mlp3", - "n_batches": 4, - "batch_size": 1024, - "wandb_enabled": true, - "wandb_project": "spd-cluster", - "intervals": { - "stat": 1, - "tensor": 32, - "plot": 32, - "artifact": 32 - } -} \ No newline at end of file diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py index c514aa69f..ea9b9f904 100644 --- a/spd/clustering/dataset.py +++ b/spd/clustering/dataset.py @@ -23,7 +23,7 @@ def load_dataset( ) -> BatchTensor: """Load a single batch for clustering. - Each run gets its own dataset batch, seeded by idx_in_ensemble. + Each run gets its own dataset batch, seeded by index in ensemble. Args: model_path: Path to decomposed model diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py index 7756877d8..c54fe408b 100644 --- a/spd/clustering/ensemble_registry.py +++ b/spd/clustering/ensemble_registry.py @@ -39,21 +39,36 @@ def _get_connection(): conn.close() -def register_clustering_run(pipeline_run_id: str, idx: int, clustering_run_id: str) -> None: +def register_clustering_run(pipeline_run_id: str, clustering_run_id: str) -> int: """Register a clustering run as part of a pipeline ensemble. Args: pipeline_run_id: The ensemble/pipeline run ID - idx: Index of this run in the ensemble + idx: Index of this run in the ensemble. If -1, auto-assigns the next available index. 
clustering_run_id: The individual clustering run ID + + Returns: + The index assigned to this run (either the provided idx or the auto-assigned one) """ with _get_connection() as conn: + # Use BEGIN IMMEDIATE for thread-safe auto-increment + conn.execute("BEGIN IMMEDIATE") + + # Auto-assign next available index, we rely on atomicity of the transaction here + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx: int = cursor.fetchone()[0] + conn.execute( "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", - (pipeline_run_id, idx, clustering_run_id), + (pipeline_run_id, assigned_idx, clustering_run_id), ) conn.commit() + return assigned_idx + def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: """Get all clustering runs for a pipeline ensemble. diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py index fd982b83f..dba55c878 100644 --- a/spd/clustering/merge.py +++ b/spd/clustering/merge.py @@ -8,7 +8,7 @@ from typing import Protocol import torch -from jaxtyping import Bool, Float, Int +from jaxtyping import Bool, Float from torch import Tensor from tqdm import tqdm @@ -16,7 +16,6 @@ compute_mdl_cost, compute_merge_costs, recompute_coacts_merge_pair, - recompute_coacts_pop_group, ) from spd.clustering.consts import ( ActivationsTensor, @@ -76,24 +75,6 @@ def merge_iteration( # determine number of iterations based on config and number of components num_iters: int = merge_config.get_num_iters(c_components) - # pop logic setup - # -------------------------------------------------- - # for speed, we precompute whether to pop components and which components to pop - # if we are not popping, we don't need these variables and can also delete other things - do_pop: bool = merge_config.pop_component_prob > 0.0 - if do_pop: - # at each iteration, we will pop a component with probability `pop_component_prob` - iter_pop: 
Bool[Tensor, " iters"] = ( - torch.rand(num_iters, device=coact.device) < merge_config.pop_component_prob - ) - # we pick a subcomponent at random, and if we decide to pop, we pop that one out of its group - # if the component is a singleton, nothing happens. this naturally biases towards popping - # less at the start and more at the end, since the effective probability of popping a component - # is actually something like `pop_component_prob * (c_components - k_groups) / c_components` - pop_component_idx: Int[Tensor, " iters"] = torch.randint( - 0, c_components, (num_iters,), device=coact.device - ) - # initialize vars # -------------------------------------------------- # start with an identity merge @@ -110,12 +91,6 @@ def merge_iteration( labels=component_labels, ) - # free up memory - if not do_pop: - del coact - del activation_mask_orig - activation_mask_orig = None - # merge iteration # ================================================== pbar: tqdm[int] = tqdm( @@ -124,30 +99,6 @@ def merge_iteration( total=num_iters, ) for iter_idx in pbar: - # pop components - # -------------------------------------------------- - if do_pop and iter_pop[iter_idx]: # pyright: ignore[reportPossiblyUnboundVariable] - # we split up the group which our chosen component belongs to - pop_component_idx_i: int = int(pop_component_idx[iter_idx].item()) # pyright: ignore[reportPossiblyUnboundVariable] - n_components_in_pop_grp: int = int( - current_merge.components_per_group[ # pyright: ignore[reportArgumentType] - current_merge.group_idxs[pop_component_idx_i].item() - ] - ) - - # but, if the component is the only one in its group, there is nothing to do - if n_components_in_pop_grp > 1: - current_merge, current_coact, current_act_mask = recompute_coacts_pop_group( - coact=current_coact, - merges=current_merge, - component_idx=pop_component_idx_i, - activation_mask=current_act_mask, - # this complains if `activation_mask_orig is None`, but this is only the case - # if `do_pop` is 
False, which it won't be here. we do this to save memory - activation_mask_orig=activation_mask_orig, # pyright: ignore[reportArgumentType] - ) - k_groups = current_coact.shape[0] - # compute costs, figure out what to merge # -------------------------------------------------- # HACK: this is messy diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py index 3bf8b6d5b..f471879b2 100644 --- a/spd/clustering/merge_config.py +++ b/spd/clustering/merge_config.py @@ -23,7 +23,6 @@ "iters", "merge_pair_sampling_method", "merge_pair_sampling_kwargs", - "pop_component_prob", "filter_dead_threshold", ] @@ -65,10 +64,6 @@ class MergeConfig(BaseConfig): default_factory=lambda: {"threshold": 0.05}, description="Keyword arguments for the merge pair sampling method.", ) - pop_component_prob: Probability = Field( - default=0, - description="Probability of popping a component in each iteration. If 0, no components are popped.", - ) filter_dead_threshold: float = Field( default=0.001, description="Threshold for filtering out dead components. 
If a component's activation is below this threshold, it is considered dead and not included in the merge.", diff --git a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py index 709d3c1c6..993335671 100644 --- a/spd/clustering/scripts/calc_distances.py +++ b/spd/clustering/scripts/calc_distances.py @@ -25,8 +25,10 @@ from spd.clustering.math.merge_distances import compute_distances from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble from spd.clustering.plotting.merge import plot_dists_distribution +from spd.clustering.scripts.run_clustering import ClusteringRunStorage from spd.log import logger from spd.settings import SPD_CACHE_DIR +from spd.utils.run_utils import ExecutionStamp # Set spawn method for CUDA compatibility with multiprocessing # Must be done before any CUDA operations @@ -57,7 +59,16 @@ def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None: # Load histories from individual clustering run directories histories: list[MergeHistory] = [] for idx, clustering_run_id in clustering_runs: - history_path = SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" + history_path = ClusteringRunStorage( + ExecutionStamp( + run_id=clustering_run_id, + snapshot_branch="", + commit_hash="", + run_type="cluster", + ) + ).history_path + + # SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.npz" if not history_path.exists(): raise FileNotFoundError( f"History not found for run {clustering_run_id}: {history_path}" diff --git a/spd/clustering/scripts/run_clustering.py b/spd/clustering/scripts/run_clustering.py index 7c614407a..54f0805c6 100644 --- a/spd/clustering/scripts/run_clustering.py +++ b/spd/clustering/scripts/run_clustering.py @@ -31,6 +31,7 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import ( ActivationsTensor, BatchTensor, @@ -43,7 +44,6 @@ from spd.clustering.math.semilog 
import semilog from spd.clustering.merge import merge_iteration from spd.clustering.merge_history import MergeHistory -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_merge_history_cluster_sizes, plot_merge_iteration from spd.clustering.storage import StorageBase @@ -66,7 +66,8 @@ class ClusteringRunStorage(StorageBase): # Relative path constants _CONFIG = "clustering_run_config.json" - _HISTORY = "history.npz" + # we are saving a zip file with things in it besides npy files -- hence, `.zip` and not `.npz` + _HISTORY = "history.zip" def __init__(self, execution_stamp: ExecutionStamp) -> None: super().__init__(execution_stamp) @@ -227,19 +228,29 @@ def main(run_config: ClusteringRunConfig) -> Path: logger.info(f"Clustering run ID: {clustering_run_id}") # Register with ensemble if this is part of a pipeline + assigned_idx: int | None if run_config.ensemble_id: - assert run_config.idx_in_ensemble is not None, ( - "idx_in_ensemble must be set when ensemble_id is provided" - ) - register_clustering_run( - run_config.ensemble_id, - run_config.idx_in_ensemble, - clustering_run_id, + assigned_idx = register_clustering_run( + pipeline_run_id=run_config.ensemble_id, + clustering_run_id=clustering_run_id, ) + logger.info( - f"Registered with pipeline {run_config.ensemble_id} at index {run_config.idx_in_ensemble} in {_ENSEMBLE_REGISTRY_DB}" + f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}" ) + # IMPORTANT: set dataset seed based on assigned index + run_config = replace_pydantic_model( + run_config, + {"dataset_seed": run_config.dataset_seed + assigned_idx}, + ) + else: + assigned_idx = None + + # save config + run_config.to_file(storage.config_path) + logger.info(f"Config saved to {storage.config_path}") + # start logger.info("Starting clustering run") logger.info(f"Output directory: 
{storage.base_dir}") device = get_device() @@ -269,6 +280,7 @@ def main(run_config: ClusteringRunConfig) -> Path: wandb_run: Run | None = None if run_config.wandb_project is not None: wandb_run = wandb.init( + id=clustering_run_id, entity=run_config.wandb_entity, project=run_config.wandb_project, group=run_config.ensemble_id, @@ -278,7 +290,7 @@ def main(run_config: ClusteringRunConfig) -> Path: f"task:{task_name}", f"model:{run_config.wandb_decomp_model}", f"ensemble_id:{run_config.ensemble_id}", - f"idx:{run_config.idx_in_ensemble}", + f"assigned_idx:{assigned_idx}", ], ) # logger.info(f"WandB run: {wandb_run.url}") @@ -347,9 +359,7 @@ def main(run_config: ClusteringRunConfig) -> Path: log_callback=log_callback, ) - # 8. Save merge history and config - run_config.to_file(storage.config_path) - logger.info(f"Config saved to {storage.config_path}") + # 8. Save merge history history.save(storage.history_path) logger.info(f"History saved to {storage.history_path}") @@ -414,9 +424,6 @@ def cli() -> None: } # Handle ensemble-related overrides - if args.idx_in_ensemble is not None: - overrides["dataset_seed"] = run_config.dataset_seed + args.idx_in_ensemble - overrides["idx_in_ensemble"] = args.idx_in_ensemble if args.pipeline_run_id is not None: overrides["ensemble_id"] = args.pipeline_run_id diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py index cde83ffa1..179bc8bca 100644 --- a/spd/clustering/scripts/run_pipeline.py +++ b/spd/clustering/scripts/run_pipeline.py @@ -25,12 +25,14 @@ from typing import Any import wandb_workspaces.workspaces as ws -from pydantic import Field, PositiveInt, field_validator +from pydantic import Field, PositiveInt, field_validator, model_validator from spd.base_config import BaseConfig +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.consts import DistancesMethod from spd.clustering.storage import StorageBase from spd.log import logger +from spd.settings 
import SPD_CACHE_DIR from spd.utils.command_utils import run_script_array_local from spd.utils.general_utils import replace_pydantic_model from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str @@ -69,20 +71,40 @@ def distances_path(self, method: DistancesMethod) -> Path: class ClusteringPipelineConfig(BaseConfig): """Configuration for submitting an ensemble of clustering runs to SLURM.""" - run_clustering_config_path: Path = Field(description="Path to ClusteringRunConfig file.") + clustering_run_config_path: Path = Field( + description="Path to ClusteringRunConfig file.", + ) n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") distances_methods: list[DistancesMethod] = Field( description="List of method(s) to use for calculating distances" ) - base_output_dir: Path = Field(description="Base directory for outputs of clustering runs.") - slurm_job_name_prefix: str | None = Field(description="Prefix for SLURM job names") - slurm_partition: str | None = Field(description="SLURM partition to use") + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering_pipeline", + description="Base directory for outputs of clustering ensemble pipeline runs.", + ) + slurm_job_name_prefix: str | None = Field( + default=None, description="Prefix for SLURM job names" + ) + slurm_partition: str | None = Field(default=None, description="SLURM partition to use") wandb_project: str | None = Field( default=None, description="Weights & Biases project name (set to None to disable WandB logging)", ) - wandb_entity: str = Field(description="WandB entity (team/user) name") - create_git_snapshot: bool = Field(description="Create a git snapshot for the run") + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + create_git_snapshot: bool = Field( + default=False, description="Create a git snapshot for the run" + ) + + @model_validator(mode="after") + def validate_crc(self) -> 
"ClusteringPipelineConfig": + """Validate that exactly one of clustering_run_config_path points to a valid `ClusteringRunConfig`.""" + assert self.clustering_run_config_path.exists(), ( + f"clustering_run_config_path does not exist: {self.clustering_run_config_path}" + ) + # Try to load ClusteringRunConfig + assert ClusteringRunConfig.from_file(self.clustering_run_config_path) + + return self @field_validator("distances_methods") @classmethod @@ -148,7 +170,7 @@ def generate_clustering_commands( "python", "spd/clustering/scripts/run_clustering.py", "--config", - pipeline_config.run_clustering_config_path.as_posix(), + pipeline_config.clustering_run_config_path.as_posix(), "--pipeline-run-id", pipeline_run_id, "--idx-in-ensemble", diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index da7382644..d883785d5 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -153,6 +153,7 @@ def init_wandb[T_config: BaseConfig]( """ load_dotenv(override=True) + # TODO: pass run id from ExecutionStamp wandb.init( project=project, entity=os.getenv("WANDB_ENTITY"), diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py index 1d2a69c93..bbfb5259e 100644 --- a/tests/clustering/scripts/cluster_resid_mlp.py +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -120,7 +120,6 @@ iters=int(PROCESSED_ACTIVATIONS.n_components_alive * 0.9), merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, - pop_component_prob=0, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py index acb6f394e..0b7f8de97 100644 --- a/tests/clustering/scripts/cluster_ss.py +++ b/tests/clustering/scripts/cluster_ss.py @@ -16,11 +16,11 @@ component_activations, process_activations, ) +from spd.clustering.clustering_run_config import ClusteringRunConfig from spd.clustering.dataset import load_dataset from spd.clustering.merge 
import merge_iteration from spd.clustering.merge_config import MergeConfig from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble -from spd.clustering.merge_run_config import ClusteringRunConfig from spd.clustering.plotting.activations import plot_activations from spd.clustering.plotting.merge import plot_dists_distribution from spd.models.component_model import ComponentModel, SPDRunInfo @@ -51,7 +51,6 @@ model_path=MODEL_PATH, batch_size=2, dataset_seed=42, - idx_in_ensemble=0, dataset_streaming=True, # no effect since we do this manually ) @@ -104,7 +103,6 @@ iters=2, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, module_name_filter=FILTER_MODULES, filter_dead_threshold=FILTER_DEAD_THRESHOLD, ) diff --git a/tests/clustering/test_calc_distances.py b/tests/clustering/test_calc_distances.py index d8971df05..b06350f4b 100644 --- a/tests/clustering/test_calc_distances.py +++ b/tests/clustering/test_calc_distances.py @@ -11,7 +11,6 @@ def test_merge_history_normalization_happy_path(): iters=3, alpha=1.0, activation_threshold=None, - pop_component_prob=0.0, ) histories = [] diff --git a/tests/clustering/test_ensemble_registry.py b/tests/clustering/test_ensemble_registry.py new file mode 100644 index 000000000..c903af801 --- /dev/null +++ b/tests/clustering/test_ensemble_registry.py @@ -0,0 +1,110 @@ +"""Tests for ensemble_registry module.""" + +import tempfile +from pathlib import Path +from typing import Any + +import pytest + +from spd.clustering.ensemble_registry import ( + get_clustering_runs, + register_clustering_run, +) + + +@pytest.fixture +def _temp_registry_db(monkeypatch: Any): # pyright: ignore[reportUnusedFunction] + """Create a temporary registry database for testing.""" + with tempfile.TemporaryDirectory() as tmpdir: + temp_db_path = Path(tmpdir) / "test_registry.db" + monkeypatch.setattr("spd.clustering.ensemble_registry._ENSEMBLE_REGISTRY_DB", temp_db_path) + yield 
temp_db_path + + +class TestRegisterClusteringRun: + """Test register_clustering_run() function.""" + + def test_register_single_run(self, _temp_registry_db: Any): + """Test registering a single run.""" + pipeline_id = "pipeline_001" + run_id = "run_001" + + assigned_idx = register_clustering_run(pipeline_id, run_id) + + # First index should be 0 + assert assigned_idx == 0 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001")] + + def test_register_multiple_runs(self, _temp_registry_db: Any): + """Test registering multiple runs sequentially.""" + pipeline_id = "pipeline_002" + + idx0 = register_clustering_run(pipeline_id, "run_001") + idx1 = register_clustering_run(pipeline_id, "run_002") + idx2 = register_clustering_run(pipeline_id, "run_003") + + # Should auto-assign 0, 1, 2 + assert idx0 == 0 + assert idx1 == 1 + assert idx2 == 2 + + # Verify in database + runs = get_clustering_runs(pipeline_id) + assert runs == [(0, "run_001"), (1, "run_002"), (2, "run_003")] + + def test_different_pipelines_independent(self, _temp_registry_db: Any): + """Test that different pipelines have independent index sequences.""" + pipeline_a = "pipeline_a" + pipeline_b = "pipeline_b" + + # Both should start at 0 when auto-assigning + idx_a0 = register_clustering_run(pipeline_a, "run_a1") + idx_b0 = register_clustering_run(pipeline_b, "run_b1") + + assert idx_a0 == 0 + assert idx_b0 == 0 + + # Both should increment independently + idx_a1 = register_clustering_run(pipeline_a, "run_a2") + idx_b1 = register_clustering_run(pipeline_b, "run_b2") + + assert idx_a1 == 1 + assert idx_b1 == 1 + + # Verify in database + runs_a = get_clustering_runs(pipeline_a) + runs_b = get_clustering_runs(pipeline_b) + + assert runs_a == [(0, "run_a1"), (1, "run_a2")] + assert runs_b == [(0, "run_b1"), (1, "run_b2")] + + +class TestGetClusteringRuns: + """Test get_clustering_runs() function.""" + + def test_get_empty_pipeline(self, _temp_registry_db: Any): + """Test 
getting runs from a pipeline that doesn't exist.""" + runs = get_clustering_runs("nonexistent_pipeline") + assert runs == [] + + def test_get_runs_sorted_by_index(self, _temp_registry_db: Any): + """Test that runs are returned sorted by index.""" + pipeline_id = "pipeline_sort" + + # Register runs (indices will be auto-assigned in order) + register_clustering_run(pipeline_id, "run_000") + register_clustering_run(pipeline_id, "run_001") + register_clustering_run(pipeline_id, "run_002") + register_clustering_run(pipeline_id, "run_003") + + # Should be returned in sorted order + runs = get_clustering_runs(pipeline_id) + assert runs == [ + (0, "run_000"), + (1, "run_001"), + (2, "run_002"), + (3, "run_003"), + ] diff --git a/tests/clustering/test_merge_config.py b/tests/clustering/test_merge_config.py index 9f191075b..63f4e88f7 100644 --- a/tests/clustering/test_merge_config.py +++ b/tests/clustering/test_merge_config.py @@ -74,7 +74,6 @@ def test_config_with_all_parameters(self): iters=200, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.5}, - pop_component_prob=0.1, filter_dead_threshold=0.001, module_name_filter="model.layers", ) @@ -84,7 +83,6 @@ def test_config_with_all_parameters(self): assert config.iters == 200 assert config.merge_pair_sampling_method == "mcmc" assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} - assert config.pop_component_prob == 0.1 assert config.filter_dead_threshold == 0.001 assert config.module_name_filter == "model.layers" diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py index 14811b7c5..8492300de 100644 --- a/tests/clustering/test_merge_integration.py +++ b/tests/clustering/test_merge_integration.py @@ -25,7 +25,6 @@ def test_merge_with_range_sampler(self): iters=5, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -59,7 +58,6 @@ def 
test_merge_with_mcmc_sampler(self): iters=5, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0}, - pop_component_prob=0, filter_dead_threshold=0.001, ) @@ -77,37 +75,6 @@ def test_merge_with_mcmc_sampler(self): assert history.merges.k_groups[-1].item() < n_components assert history.merges.k_groups[-1].item() >= 2 - def test_merge_with_popping(self): - """Test merge iteration with component popping.""" - # Create test data - n_samples = 100 - n_components = 15 - activations = torch.rand(n_samples, n_components) - component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) - - # Configure with popping enabled - config = MergeConfig( - activation_threshold=0.1, - alpha=1.0, - iters=10, - merge_pair_sampling_method="range", - merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0.3, # 30% chance of popping - filter_dead_threshold=0.001, - ) - - # Run merge iteration - history = merge_iteration( - activations=activations, merge_config=config, component_labels=component_labels - ) - - # Check results - assert history is not None - # First entry is after first merge, so should be n_components - 1 - assert history.merges.k_groups[0].item() == n_components - 1 - # Final group count depends on pops, but should be less than initial - assert history.merges.k_groups[-1].item() < n_components - def test_merge_comparison_samplers(self): """Compare behavior of different samplers with same data.""" # Create test data with clear structure @@ -128,7 +95,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum - pop_component_prob=0, ) history_range = merge_iteration( @@ -144,7 +110,6 @@ def test_merge_comparison_samplers(self): iters=3, merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp - pop_component_prob=0, ) history_mcmc = merge_iteration( @@ -173,7 +138,6 @@ 
def test_merge_with_small_components(self): iters=1, # Just one merge merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0}, - pop_component_prob=0, ) history = merge_iteration( diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py new file mode 100644 index 000000000..ca6bad6ee --- /dev/null +++ b/tests/clustering/test_pipeline_config.py @@ -0,0 +1,137 @@ +"""Tests for ClusteringPipelineConfig and ClusteringRunConfig with inline config support.""" + +from pathlib import Path + +import pydantic_core +import pytest + +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.merge_config import MergeConfig +from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig +from spd.settings import REPO_ROOT + + +class TestClusteringRunConfigStableHash: + """Test ClusteringRunConfig.stable_hash_b64() method.""" + + def test_stable_hash_b64(self): + """Test that stable_hash_b64 is deterministic, unique, and URL-safe.""" + # Create 4 configs: 2 identical, 2 different + config1 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config3 = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different model_path + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config4 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig( + activation_threshold=0.2 + ), # Different merge_config to test nested fields + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + hash3 = config3.stable_hash_b64() + hash4 = config4.stable_hash_b64() + + # Identical configs produce identical hashes + assert hash1 == hash2 + + # Different configs 
produce different hashes + assert hash1 != hash3 + assert hash1 != hash4 + assert hash3 != hash4 + + # Hashes are strings + assert isinstance(hash1, str) + assert len(hash1) > 0 + + # Hashes are URL-safe base64 (no padding, URL-safe chars only) + assert "=" not in hash1 + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in hash1) + + +class TestClusteringPipelineConfigValidation: + """Test ClusteringPipelineConfig validation logic.""" + + def test_error_when_path_does_not_exist(self): + """Test that error is raised when clustering_run_config_path does not exist.""" + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + ClusteringPipelineConfig( + clustering_run_config_path=Path("nonexistent/path.json"), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + def test_valid_config_with_existing_path(self): + """Test that config is valid when path points to existing file.""" + expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") + + config = ClusteringPipelineConfig( + clustering_run_config_path=expected_path, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.clustering_run_config_path == expected_path + + +def _get_config_files(path: Path): + """Helper to get all config files.""" + pipeline_config_files = ( + list(path.glob("*.yaml")) + list(path.glob("*.yml")) + list(path.glob("*.json")) + ) + assert len(pipeline_config_files) > 0, f"No pipeline files found in {path}" + return pipeline_config_files + + +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / 
"clustering" / "configs"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline(self, config_file: Path): + """Test that each pipeline config file is valid.""" + print(config_file) + _config = ClusteringPipelineConfig.from_file(config_file) + crc_path = _config.clustering_run_config_path + print(f"{crc_path = }") + assert crc_path.exists() + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline_clustering_run(self, config_file: Path): + """Test that each clustering run config file is valid.""" + print(config_file) + _config = ClusteringRunConfig.from_file(config_file) + assert isinstance(_config, ClusteringRunConfig) diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py index 91a7cf2ad..5e2cbbd1c 100644 --- a/tests/clustering/test_run_clustering_happy_path.py +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -3,8 +3,8 @@ import pytest +from spd.clustering.clustering_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.merge_config import MergeConfig -from spd.clustering.merge_run_config import ClusteringRunConfig, LoggingIntervals from spd.clustering.scripts.run_clustering import main @@ -16,7 +16,6 @@ def test_run_clustering_happy_path(): model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run batch_size=4, dataset_seed=0, - idx_in_ensemble=0, base_output_dir=Path(temp_dir), ensemble_id=None, merge_config=MergeConfig( @@ -25,7 +24,6 @@ def test_run_clustering_happy_path(): iters=3, merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.05}, - pop_component_prob=0, ), wandb_project=None, wandb_entity="goodfire", From 1f0725cb8ff970912f04bd7e165222c4dde7171e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Fri, 24 Oct 2025 18:14:46 +0100 Subject: [PATCH 63/77] deps --- uv.lock | 88 
++++++++++++++++++++++++++++----------------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/uv.lock b/uv.lock index 774fdd55e..4f14f1ffd 100644 --- a/uv.lock +++ b/uv.lock @@ -331,7 +331,7 @@ wheels = [ [[package]] name = "datasets" -version = "4.2.0" +version = "4.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dill" }, @@ -349,9 +349,9 @@ dependencies = [ { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/48/0186fbc4b86a4f9ecaf04eb01e877e78b53bfa0b03be9c84b2298431ba33/datasets-4.2.0.tar.gz", hash = "sha256:8333a7db9f3bb8044c1b819a35d4e3e2809596c837793b0921382efffdc36e78", size = 582256, upload-time = "2025-10-09T16:10:15.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/47/325206ac160f7699ed9f1798afa8f8f8d5189b03bf3815654859ac1d5cba/datasets-4.3.0.tar.gz", hash = "sha256:bc9118ed9afd92346c5be7ed3aaa00177eb907c25467f9d072a0d22777efbd2b", size = 582801, upload-time = "2025-10-23T16:31:51.547Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/91/9e/0bbbd09b116fd8ee2d3617e28e6598551d2f0f24d3a2ce99cc87ec85aeb0/datasets-4.2.0-py3-none-any.whl", hash = "sha256:fdc43aaf4a73b31f64f80f72f195ab413a1141ed15555d675b2fd17926f8b026", size = 506316, upload-time = "2025-10-09T16:10:13.375Z" }, + { url = "https://files.pythonhosted.org/packages/ca/51/409a8184ed35453d9cbb3d6b20d524b1115c2c2d117b85d5e9b06cd70b45/datasets-4.3.0-py3-none-any.whl", hash = "sha256:0ea157e72138b3ca6c7d2415f19a164ecf7d4c4fa72da2a570da286882e96903", size = 506846, upload-time = "2025-10-23T16:31:49.965Z" }, ] [[package]] @@ -592,7 +592,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.35.3" +version = "0.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -604,9 +604,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/10/7e/a0a97de7c73671863ca6b3f61fa12518caf35db37825e43d63a70956738c/huggingface_hub-0.35.3.tar.gz", hash = "sha256:350932eaa5cc6a4747efae85126ee220e4ef1b54e29d31c3b45c5612ddf0b32a", size = 461798, upload-time = "2025-09-29T14:29:58.625Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/a0/651f93d154cb72323358bf2bbae3e642bdb5d2f1bfc874d096f7cb159fa0/huggingface_hub-0.35.3-py3-none-any.whl", hash = "sha256:0e3a01829c19d86d03793e4577816fe3bdfc1602ac62c7fb220d593d351224ba", size = 564262, upload-time = "2025-09-29T14:29:55.813Z" }, + { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" }, ] [[package]] @@ -1406,24 +1406,24 @@ wheels = [ [[package]] name = "pyarrow" -version = "21.0.0" +version = "22.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ef/c2/ea068b8f00905c06329a3dfcd40d0fcc2b7d0f2e355bdb25b65e0a0e4cd4/pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc", size = 1133487, upload-time = "2025-07-18T00:57:31.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/53/04a7fdc63e6056116c9ddc8b43bc28c12cdd181b85cbeadb79278475f3ae/pyarrow-22.0.0.tar.gz", hash = "sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9", size = 1151151, upload-time = "2025-10-24T12:30:00.762Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/16/ca/c7eaa8e62db8fb37ce942b1ea0c6d7abfe3786ca193957afa25e71b81b66/pyarrow-21.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e99310a4ebd4479bcd1964dff9e14af33746300cb014aa4a3781738ac63baf4a", size = 31154306, upload-time = "2025-07-18T00:56:04.42Z" }, - { url = "https://files.pythonhosted.org/packages/ce/e8/e87d9e3b2489302b3a1aea709aaca4b781c5252fcb812a17ab6275a9a484/pyarrow-21.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d2fe8e7f3ce329a71b7ddd7498b3cfac0eeb200c2789bd840234f0dc271a8efe", size = 32680622, upload-time = "2025-07-18T00:56:07.505Z" }, - { url = "https://files.pythonhosted.org/packages/84/52/79095d73a742aa0aba370c7942b1b655f598069489ab387fe47261a849e1/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f522e5709379d72fb3da7785aa489ff0bb87448a9dc5a75f45763a795a089ebd", size = 41104094, upload-time = "2025-07-18T00:56:10.994Z" }, - { url = "https://files.pythonhosted.org/packages/89/4b/7782438b551dbb0468892a276b8c789b8bbdb25ea5c5eb27faadd753e037/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:69cbbdf0631396e9925e048cfa5bce4e8c3d3b41562bbd70c685a8eb53a91e61", size = 42825576, upload-time = "2025-07-18T00:56:15.569Z" }, - { url = "https://files.pythonhosted.org/packages/b3/62/0f29de6e0a1e33518dec92c65be0351d32d7ca351e51ec5f4f837a9aab91/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:731c7022587006b755d0bdb27626a1a3bb004bb56b11fb30d98b6c1b4718579d", size = 43368342, upload-time = "2025-07-18T00:56:19.531Z" }, - { url = "https://files.pythonhosted.org/packages/90/c7/0fa1f3f29cf75f339768cc698c8ad4ddd2481c1742e9741459911c9ac477/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc56bc708f2d8ac71bd1dcb927e458c93cec10b98eb4120206a4091db7b67b99", size = 45131218, upload-time = "2025-07-18T00:56:23.347Z" }, - { url = 
"https://files.pythonhosted.org/packages/01/63/581f2076465e67b23bc5a37d4a2abff8362d389d29d8105832e82c9c811c/pyarrow-21.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:186aa00bca62139f75b7de8420f745f2af12941595bbbfa7ed3870ff63e25636", size = 26087551, upload-time = "2025-07-18T00:56:26.758Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ab/357d0d9648bb8241ee7348e564f2479d206ebe6e1c47ac5027c2e31ecd39/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:a7a102574faa3f421141a64c10216e078df467ab9576684d5cd696952546e2da", size = 31290064, upload-time = "2025-07-18T00:56:30.214Z" }, - { url = "https://files.pythonhosted.org/packages/3f/8a/5685d62a990e4cac2043fc76b4661bf38d06efed55cf45a334b455bd2759/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:1e005378c4a2c6db3ada3ad4c217b381f6c886f0a80d6a316fe586b90f77efd7", size = 32727837, upload-time = "2025-07-18T00:56:33.935Z" }, - { url = "https://files.pythonhosted.org/packages/fc/de/c0828ee09525c2bafefd3e736a248ebe764d07d0fd762d4f0929dbc516c9/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65f8e85f79031449ec8706b74504a316805217b35b6099155dd7e227eef0d4b6", size = 41014158, upload-time = "2025-07-18T00:56:37.528Z" }, - { url = "https://files.pythonhosted.org/packages/6e/26/a2865c420c50b7a3748320b614f3484bfcde8347b2639b2b903b21ce6a72/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3a81486adc665c7eb1a2bde0224cfca6ceaba344a82a971ef059678417880eb8", size = 42667885, upload-time = "2025-07-18T00:56:41.483Z" }, - { url = "https://files.pythonhosted.org/packages/0a/f9/4ee798dc902533159250fb4321267730bc0a107d8c6889e07c3add4fe3a5/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:fc0d2f88b81dcf3ccf9a6ae17f89183762c8a94a5bdcfa09e05cfe413acf0503", size = 43276625, upload-time = "2025-07-18T00:56:48.002Z" }, - { url = 
"https://files.pythonhosted.org/packages/5a/da/e02544d6997037a4b0d22d8e5f66bc9315c3671371a8b18c79ade1cefe14/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6299449adf89df38537837487a4f8d3bd91ec94354fdd2a7d30bc11c48ef6e79", size = 44951890, upload-time = "2025-07-18T00:56:52.568Z" }, - { url = "https://files.pythonhosted.org/packages/e5/4e/519c1bc1876625fe6b71e9a28287c43ec2f20f73c658b9ae1d485c0c206e/pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10", size = 26371006, upload-time = "2025-07-18T00:56:56.379Z" }, + { url = "https://files.pythonhosted.org/packages/a6/d6/d0fac16a2963002fc22c8fa75180a838737203d558f0ed3b564c4a54eef5/pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a", size = 34204629, upload-time = "2025-10-24T10:06:20.274Z" }, + { url = "https://files.pythonhosted.org/packages/c6/9c/1d6357347fbae062ad3f17082f9ebc29cc733321e892c0d2085f42a2212b/pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901", size = 35985783, upload-time = "2025-10-24T10:06:27.301Z" }, + { url = "https://files.pythonhosted.org/packages/ff/c0/782344c2ce58afbea010150df07e3a2f5fdad299cd631697ae7bd3bac6e3/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691", size = 45020999, upload-time = "2025-10-24T10:06:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8b/5362443737a5307a7b67c1017c42cd104213189b4970bf607e05faf9c525/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e0a15757fccb38c410947df156f9749ae4a3c89b2393741a50521f39a8cf202a", size = 47724601, upload-time = "2025-10-24T10:06:43.551Z" }, + { url = 
"https://files.pythonhosted.org/packages/69/4d/76e567a4fc2e190ee6072967cb4672b7d9249ac59ae65af2d7e3047afa3b/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cedb9dd9358e4ea1d9bce3665ce0797f6adf97ff142c8e25b46ba9cdd508e9b6", size = 48001050, upload-time = "2025-10-24T10:06:52.284Z" }, + { url = "https://files.pythonhosted.org/packages/01/5e/5653f0535d2a1aef8223cee9d92944cb6bccfee5cf1cd3f462d7cb022790/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:252be4a05f9d9185bb8c18e83764ebcfea7185076c07a7a662253af3a8c07941", size = 50307877, upload-time = "2025-10-24T10:07:02.405Z" }, + { url = "https://files.pythonhosted.org/packages/2d/f8/1d0bd75bf9328a3b826e24a16e5517cd7f9fbf8d34a3184a4566ef5a7f29/pyarrow-22.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:a4893d31e5ef780b6edcaf63122df0f8d321088bb0dee4c8c06eccb1ca28d145", size = 27977099, upload-time = "2025-10-24T10:08:07.259Z" }, + { url = "https://files.pythonhosted.org/packages/90/81/db56870c997805bf2b0f6eeeb2d68458bf4654652dccdcf1bf7a42d80903/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:f7fe3dbe871294ba70d789be16b6e7e52b418311e166e0e3cba9522f0f437fb1", size = 34336685, upload-time = "2025-10-24T10:07:11.47Z" }, + { url = "https://files.pythonhosted.org/packages/1c/98/0727947f199aba8a120f47dfc229eeb05df15bcd7a6f1b669e9f882afc58/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:ba95112d15fd4f1105fb2402c4eab9068f0554435e9b7085924bcfaac2cc306f", size = 36032158, upload-time = "2025-10-24T10:07:18.626Z" }, + { url = "https://files.pythonhosted.org/packages/96/b4/9babdef9c01720a0785945c7cf550e4acd0ebcd7bdd2e6f0aa7981fa85e2/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c064e28361c05d72eed8e744c9605cbd6d2bb7481a511c74071fd9b24bc65d7d", size = 44892060, upload-time = "2025-10-24T10:07:26.002Z" }, + { url = 
"https://files.pythonhosted.org/packages/f8/ca/2f8804edd6279f78a37062d813de3f16f29183874447ef6d1aadbb4efa0f/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6f9762274496c244d951c819348afbcf212714902742225f649cf02823a6a10f", size = 47504395, upload-time = "2025-10-24T10:07:34.09Z" }, + { url = "https://files.pythonhosted.org/packages/b9/f0/77aa5198fd3943682b2e4faaf179a674f0edea0d55d326d83cb2277d9363/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a9d9ffdc2ab696f6b15b4d1f7cec6658e1d788124418cb30030afbae31c64746", size = 48066216, upload-time = "2025-10-24T10:07:43.528Z" }, + { url = "https://files.pythonhosted.org/packages/79/87/a1937b6e78b2aff18b706d738c9e46ade5bfcf11b294e39c87706a0089ac/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ec1a15968a9d80da01e1d30349b2b0d7cc91e96588ee324ce1b5228175043e95", size = 50288552, upload-time = "2025-10-24T10:07:53.519Z" }, + { url = "https://files.pythonhosted.org/packages/60/ae/b5a5811e11f25788ccfdaa8f26b6791c9807119dffcf80514505527c384c/pyarrow-22.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:bba208d9c7decf9961998edf5c65e3ea4355d5818dd6cd0f6809bec1afb951cc", size = 28262504, upload-time = "2025-10-24T10:08:00.932Z" }, ] [[package]] @@ -1736,28 +1736,28 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9e/58/6ca66896635352812de66f71cdf9ff86b3a4f79071ca5730088c0cd0fc8d/ruff-0.14.1.tar.gz", hash = "sha256:1dd86253060c4772867c61791588627320abcb6ed1577a90ef432ee319729b69", size = 5513429, upload-time = "2025-10-16T18:05:41.766Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/39/9cc5ab181478d7a18adc1c1e051a84ee02bec94eb9bdfd35643d7c74ca31/ruff-0.14.1-py3-none-linux_armv6l.whl", hash = "sha256:083bfc1f30f4a391ae09c6f4f99d83074416b471775b59288956f5bc18e82f8b", size = 12445415, upload-time = "2025-10-16T18:04:48.227Z" }, - { url = 
"https://files.pythonhosted.org/packages/ef/2e/1226961855ccd697255988f5a2474890ac7c5863b080b15bd038df820818/ruff-0.14.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f6fa757cd717f791009f7669fefb09121cc5f7d9bd0ef211371fad68c2b8b224", size = 12784267, upload-time = "2025-10-16T18:04:52.515Z" }, - { url = "https://files.pythonhosted.org/packages/c1/ea/fd9e95863124ed159cd0667ec98449ae461de94acda7101f1acb6066da00/ruff-0.14.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d6191903d39ac156921398e9c86b7354d15e3c93772e7dbf26c9fcae59ceccd5", size = 11781872, upload-time = "2025-10-16T18:04:55.396Z" }, - { url = "https://files.pythonhosted.org/packages/1e/5a/e890f7338ff537dba4589a5e02c51baa63020acfb7c8cbbaea4831562c96/ruff-0.14.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed04f0e04f7a4587244e5c9d7df50e6b5bf2705d75059f409a6421c593a35896", size = 12226558, upload-time = "2025-10-16T18:04:58.166Z" }, - { url = "https://files.pythonhosted.org/packages/a6/7a/8ab5c3377f5bf31e167b73651841217542bcc7aa1c19e83030835cc25204/ruff-0.14.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c9e6cf6cd4acae0febbce29497accd3632fe2025c0c583c8b87e8dbdeae5f61", size = 12187898, upload-time = "2025-10-16T18:05:01.455Z" }, - { url = "https://files.pythonhosted.org/packages/48/8d/ba7c33aa55406955fc124e62c8259791c3d42e3075a71710fdff9375134f/ruff-0.14.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fa2458527794ecdfbe45f654e42c61f2503a230545a91af839653a0a93dbc6", size = 12939168, upload-time = "2025-10-16T18:05:04.397Z" }, - { url = "https://files.pythonhosted.org/packages/b4/c2/70783f612b50f66d083380e68cbd1696739d88e9b4f6164230375532c637/ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345", size = 14386942, upload-time = "2025-10-16T18:05:07.102Z" }, - { url = 
"https://files.pythonhosted.org/packages/48/44/cd7abb9c776b66d332119d67f96acf15830d120f5b884598a36d9d3f4d83/ruff-0.14.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7382fa12a26cce1f95070ce450946bec357727aaa428983036362579eadcc5cf", size = 13990622, upload-time = "2025-10-16T18:05:09.882Z" }, - { url = "https://files.pythonhosted.org/packages/eb/56/4259b696db12ac152fe472764b4f78bbdd9b477afd9bc3a6d53c01300b37/ruff-0.14.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd0bf2be3ae8521e1093a487c4aa3b455882f139787770698530d28ed3fbb37c", size = 13431143, upload-time = "2025-10-16T18:05:13.46Z" }, - { url = "https://files.pythonhosted.org/packages/e0/35/266a80d0eb97bd224b3265b9437bd89dde0dcf4faf299db1212e81824e7e/ruff-0.14.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabcaa9ccf8089fb4fdb78d17cc0e28241520f50f4c2e88cb6261ed083d85151", size = 13132844, upload-time = "2025-10-16T18:05:16.1Z" }, - { url = "https://files.pythonhosted.org/packages/65/6e/d31ce218acc11a8d91ef208e002a31acf315061a85132f94f3df7a252b18/ruff-0.14.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:747d583400f6125ec11a4c14d1c8474bf75d8b419ad22a111a537ec1a952d192", size = 13401241, upload-time = "2025-10-16T18:05:19.395Z" }, - { url = "https://files.pythonhosted.org/packages/9f/b5/dbc4221bf0b03774b3b2f0d47f39e848d30664157c15b965a14d890637d2/ruff-0.14.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5a6e74c0efd78515a1d13acbfe6c90f0f5bd822aa56b4a6d43a9ffb2ae6e56cd", size = 12132476, upload-time = "2025-10-16T18:05:22.163Z" }, - { url = "https://files.pythonhosted.org/packages/98/4b/ac99194e790ccd092d6a8b5f341f34b6e597d698e3077c032c502d75ea84/ruff-0.14.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0ea6a864d2fb41a4b6d5b456ed164302a0d96f4daac630aeba829abfb059d020", size = 12139749, upload-time = "2025-10-16T18:05:25.162Z" }, - { url = 
"https://files.pythonhosted.org/packages/47/26/7df917462c3bb5004e6fdfcc505a49e90bcd8a34c54a051953118c00b53a/ruff-0.14.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0826b8764f94229604fa255918d1cc45e583e38c21c203248b0bfc9a0e930be5", size = 12544758, upload-time = "2025-10-16T18:05:28.018Z" }, - { url = "https://files.pythonhosted.org/packages/64/d0/81e7f0648e9764ad9b51dd4be5e5dac3fcfff9602428ccbae288a39c2c22/ruff-0.14.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cbc52160465913a1a3f424c81c62ac8096b6a491468e7d872cb9444a860bc33d", size = 13221811, upload-time = "2025-10-16T18:05:30.707Z" }, - { url = "https://files.pythonhosted.org/packages/c3/07/3c45562c67933cc35f6d5df4ca77dabbcd88fddaca0d6b8371693d29fd56/ruff-0.14.1-py3-none-win32.whl", hash = "sha256:e037ea374aaaff4103240ae79168c0945ae3d5ae8db190603de3b4012bd1def6", size = 12319467, upload-time = "2025-10-16T18:05:33.261Z" }, - { url = "https://files.pythonhosted.org/packages/02/88/0ee4ca507d4aa05f67e292d2e5eb0b3e358fbcfe527554a2eda9ac422d6b/ruff-0.14.1-py3-none-win_amd64.whl", hash = "sha256:59d599cdff9c7f925a017f6f2c256c908b094e55967f93f2821b1439928746a1", size = 13401123, upload-time = "2025-10-16T18:05:35.984Z" }, - { url = "https://files.pythonhosted.org/packages/b8/81/4b6387be7014858d924b843530e1b2a8e531846807516e9bea2ee0936bf7/ruff-0.14.1-py3-none-win_arm64.whl", hash = "sha256:e3b443c4c9f16ae850906b8d0a707b2a4c16f8d2f0a7fe65c475c5886665ce44", size = 12436636, upload-time = "2025-10-16T18:05:38.995Z" }, +version = "0.14.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/34/8218a19b2055b80601e8fd201ec723c74c7fe1ca06d525a43ed07b6d8e85/ruff-0.14.2.tar.gz", hash = "sha256:98da787668f239313d9c902ca7c523fe11b8ec3f39345553a51b25abc4629c96", size = 5539663, upload-time = "2025-10-23T19:37:00.956Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/16/dd/23eb2db5ad9acae7c845700493b72d3ae214dce0b226f27df89216110f2b/ruff-0.14.2-py3-none-linux_armv6l.whl", hash = "sha256:7cbe4e593505bdec5884c2d0a4d791a90301bc23e49a6b1eb642dd85ef9c64f1", size = 12533390, upload-time = "2025-10-23T19:36:18.044Z" }, + { url = "https://files.pythonhosted.org/packages/5a/8c/5f9acff43ddcf3f85130d0146d0477e28ccecc495f9f684f8f7119b74c0d/ruff-0.14.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8d54b561729cee92f8d89c316ad7a3f9705533f5903b042399b6ae0ddfc62e11", size = 12887187, upload-time = "2025-10-23T19:36:22.664Z" }, + { url = "https://files.pythonhosted.org/packages/99/fa/047646491479074029665022e9f3dc6f0515797f40a4b6014ea8474c539d/ruff-0.14.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5c8753dfa44ebb2cde10ce5b4d2ef55a41fb9d9b16732a2c5df64620dbda44a3", size = 11925177, upload-time = "2025-10-23T19:36:24.778Z" }, + { url = "https://files.pythonhosted.org/packages/15/8b/c44cf7fe6e59ab24a9d939493a11030b503bdc2a16622cede8b7b1df0114/ruff-0.14.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d0bbeffb8d9f4fccf7b5198d566d0bad99a9cb622f1fc3467af96cb8773c9e3", size = 12358285, upload-time = "2025-10-23T19:36:26.979Z" }, + { url = "https://files.pythonhosted.org/packages/45/01/47701b26254267ef40369aea3acb62a7b23e921c27372d127e0f3af48092/ruff-0.14.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7047f0c5a713a401e43a88d36843d9c83a19c584e63d664474675620aaa634a8", size = 12303832, upload-time = "2025-10-23T19:36:29.192Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5c/ae7244ca4fbdf2bee9d6405dcd5bc6ae51ee1df66eb7a9884b77b8af856d/ruff-0.14.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bf8d2f9aa1602599217d82e8e0af7fd33e5878c4d98f37906b7c93f46f9a839", size = 13036995, upload-time = "2025-10-23T19:36:31.861Z" }, + { url = 
"https://files.pythonhosted.org/packages/27/4c/0860a79ce6fd4c709ac01173f76f929d53f59748d0dcdd662519835dae43/ruff-0.14.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1c505b389e19c57a317cf4b42db824e2fca96ffb3d86766c1c9f8b96d32048a7", size = 14512649, upload-time = "2025-10-23T19:36:33.915Z" }, + { url = "https://files.pythonhosted.org/packages/7f/7f/d365de998069720a3abfc250ddd876fc4b81a403a766c74ff9bde15b5378/ruff-0.14.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a307fc45ebd887b3f26b36d9326bb70bf69b01561950cdcc6c0bdf7bb8e0f7cc", size = 14088182, upload-time = "2025-10-23T19:36:36.983Z" }, + { url = "https://files.pythonhosted.org/packages/6c/ea/d8e3e6b209162000a7be1faa41b0a0c16a133010311edc3329753cc6596a/ruff-0.14.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61ae91a32c853172f832c2f40bd05fd69f491db7289fb85a9b941ebdd549781a", size = 13599516, upload-time = "2025-10-23T19:36:39.208Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ea/c7810322086db68989fb20a8d5221dd3b79e49e396b01badca07b433ab45/ruff-0.14.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1967e40286f63ee23c615e8e7e98098dedc7301568bd88991f6e544d8ae096", size = 13272690, upload-time = "2025-10-23T19:36:41.453Z" }, + { url = "https://files.pythonhosted.org/packages/a9/39/10b05acf8c45786ef501d454e00937e1b97964f846bf28883d1f9619928a/ruff-0.14.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:2877f02119cdebf52a632d743a2e302dea422bfae152ebe2f193d3285a3a65df", size = 13496497, upload-time = "2025-10-23T19:36:43.61Z" }, + { url = "https://files.pythonhosted.org/packages/59/a1/1f25f8301e13751c30895092485fada29076e5e14264bdacc37202e85d24/ruff-0.14.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e681c5bc777de5af898decdcb6ba3321d0d466f4cb43c3e7cc2c3b4e7b843a05", size = 12266116, upload-time = "2025-10-23T19:36:45.625Z" }, + { url = 
"https://files.pythonhosted.org/packages/5c/fa/0029bfc9ce16ae78164e6923ef392e5f173b793b26cc39aa1d8b366cf9dc/ruff-0.14.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e21be42d72e224736f0c992cdb9959a2fa53c7e943b97ef5d081e13170e3ffc5", size = 12281345, upload-time = "2025-10-23T19:36:47.618Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ab/ece7baa3c0f29b7683be868c024f0838770c16607bea6852e46b202f1ff6/ruff-0.14.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b8264016f6f209fac16262882dbebf3f8be1629777cf0f37e7aff071b3e9b92e", size = 12629296, upload-time = "2025-10-23T19:36:49.789Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7f/638f54b43f3d4e48c6a68062794e5b367ddac778051806b9e235dfb7aa81/ruff-0.14.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5ca36b4cb4db3067a3b24444463ceea5565ea78b95fe9a07ca7cb7fd16948770", size = 13371610, upload-time = "2025-10-23T19:36:51.882Z" }, + { url = "https://files.pythonhosted.org/packages/8d/35/3654a973ebe5b32e1fd4a08ed2d46755af7267da7ac710d97420d7b8657d/ruff-0.14.2-py3-none-win32.whl", hash = "sha256:41775927d287685e08f48d8eb3f765625ab0b7042cc9377e20e64f4eb0056ee9", size = 12415318, upload-time = "2025-10-23T19:36:53.961Z" }, + { url = "https://files.pythonhosted.org/packages/71/30/3758bcf9e0b6a4193a6f51abf84254aba00887dfa8c20aba18aa366c5f57/ruff-0.14.2-py3-none-win_amd64.whl", hash = "sha256:0df3424aa5c3c08b34ed8ce099df1021e3adaca6e90229273496b839e5a7e1af", size = 13565279, upload-time = "2025-10-23T19:36:56.578Z" }, + { url = "https://files.pythonhosted.org/packages/2e/5d/aa883766f8ef9ffbe6aa24f7192fb71632f31a30e77eb39aa2b0dc4290ac/ruff-0.14.2-py3-none-win_arm64.whl", hash = "sha256:ea9d635e83ba21569fbacda7e78afbfeb94911c9434aff06192d9bc23fd5495a", size = 12554956, upload-time = "2025-10-23T19:36:58.714Z" }, ] [[package]] From 39dda42e43d937bfc1c405777eb3615b83626a6e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 11:29:53 +0000 Subject: [PATCH 64/77] format --- 
spd/clustering/ci_dt/core.py | 6 +++--- spd/clustering/ci_dt/feature_names.py | 5 ++--- spd/clustering/ci_dt/pipeline.py | 5 +---- spd/clustering/ci_dt/plot.py | 23 +++++++++++++++++------ spd/clustering/ci_dt/run.py | 7 +++---- 5 files changed, 26 insertions(+), 20 deletions(-) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 8b537ce7a..3be97055b 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -1,9 +1,9 @@ """Core library functions for causal importance decision trees.""" +import warnings from collections.abc import Sequence from dataclasses import dataclass from typing import Literal -import warnings import numpy as np from jaxtyping import Bool, Float @@ -130,6 +130,7 @@ def predict_all( MetricKey = Literal["ap", "acc", "bacc", "prev", "tpr", "tnr", "precision", "npv", "f1"] + def layer_metrics( Y_true: Bool[np.ndarray, "n t"], Y_prob: Float[np.ndarray, "n t"], @@ -186,7 +187,7 @@ def layer_metrics( precision[j] = tp / (tp + fp) else: precision[j] = np.nan - warnings.warn(f"Precision failed: {tp=}, {fp=}, {tp+fp=}") + warnings.warn(f"Precision failed: {tp=}, {fp=}, {tp+fp=}", stacklevel=1) # Negative Predictive Value = TN / (TN + FN) - when we predict inactive, how often are we right? 
npv[j] = tn / (tn + fn) @@ -199,7 +200,6 @@ def layer_metrics( acc[j] = accuracy_score(y, yhat) bacc[j] = balanced_accuracy_score(y, yhat) - return { "ap": ap, "acc": acc, diff --git a/spd/clustering/ci_dt/feature_names.py b/spd/clustering/ci_dt/feature_names.py index 4f96332d7..fc43b33b8 100644 --- a/spd/clustering/ci_dt/feature_names.py +++ b/spd/clustering/ci_dt/feature_names.py @@ -1,8 +1,7 @@ """Generate feature names for decision tree visualization with activation and decoding info.""" -import numpy as np import torch -from jaxtyping import Bool, Float +from jaxtyping import Float from torch import Tensor from spd.models.component_model import ComponentModel @@ -123,4 +122,4 @@ def get_component_directions( read_direction = V[:, component_idx] # [d_in] write_direction = U[component_idx, :] # [d_out] - return read_direction, write_direction \ No newline at end of file + return read_direction, write_direction diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 4b0c10750..7747e41c8 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -173,10 +173,7 @@ def compute_tree_metrics( per_layer_stats.append( { **metrics, - **{ - f"mean_{key}": float(np.nanmean(values)) - for key, values in metrics.items() - } + **{f"mean_{key}": float(np.nanmean(values)) for key, values in metrics.items()}, } ) for j, apj in enumerate(metrics["ap"]): diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index fc7cfa6c0..5317f0ddf 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -4,9 +4,9 @@ import matplotlib.pyplot as plt import numpy as np +import torch from jaxtyping import Float, Int from sklearn.tree import plot_tree -import torch from spd.clustering.ci_dt.core import LayerModel, MetricKey, get_estimator_for @@ -455,8 +455,18 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La depth_arr: np.ndarray = np.array(depth_list) # Plot 
baseline: for uncorrelated variables, expected AP = prevalence - prev_range: np.ndarray = np.logspace(np.log10(prevalence_arr.min()), np.log10(prevalence_arr.max()), 100) - ax.plot(prev_range, prev_range, 'k--', alpha=0.5, linewidth=1.5, label='Random baseline (AP = prevalence)', zorder=1) + prev_range: np.ndarray = np.logspace( + np.log10(prevalence_arr.min()), np.log10(prevalence_arr.max()), 100 + ) + ax.plot( + prev_range, + prev_range, + "k--", + alpha=0.5, + linewidth=1.5, + label="Random baseline (AP = prevalence)", + zorder=1, + ) scatter = ax.scatter( prevalence_arr, @@ -472,7 +482,8 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La ax.set_title( r"Average Precision vs Component Prevalence" + "\n" - r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where $P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, $R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" + "\n" + r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$ where $P_n = \frac{\text{TP}}{\text{TP}+\text{FP}}$, $R_n = \frac{\text{TP}}{\text{TP}+\text{FN}}$" + + "\n" r"Colored by tree depth" ) ax.set_xlabel("Prevalence (log scale)") @@ -480,7 +491,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La ax.set_xscale("log") ax.set_ylim(-0.05, 1.05) ax.grid(True, alpha=0.3) - ax.legend(loc='lower right') + ax.legend(loc="lower right") cbar = plt.colorbar(scatter, ax=ax) cbar.set_label("Tree Depth") @@ -489,7 +500,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La def plot_component_activity_breakdown( - component_acts: dict[str, np.ndarray|torch.Tensor], + component_acts: dict[str, np.ndarray | torch.Tensor], module_keys: list[str], activation_threshold: float, logy: bool = False, diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index bd94ab3f6..3e4de124a 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -29,8 +29,8 @@ from spd.models.component_model import ComponentModel, 
SPDRunInfo # magic autoreload -%load_ext autoreload -%autoreload 2 +# %load_ext autoreload +# %autoreload 2 # %% # ----------------------- configuration ----------------------- @@ -40,7 +40,7 @@ # n_batches=8, batch_size=2, n_batches=2, - n_ctx=16, + n_ctx=16, activation_threshold=0.01, max_depth=3, random_state=42, @@ -166,7 +166,6 @@ plot_ap_vs_prevalence(per_layer_stats, models) - # %% # ----------------------- plots: tree statistics ----------------------- From 5e764da9c829671e8f157c6cd5681f6565d91b65 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 11:36:33 +0000 Subject: [PATCH 65/77] some type fixes --- spd/clustering/ci_dt/attn.py | 4 +- spd/clustering/ci_dt/ci_dt_old.py | 9 +- spd/clustering/ci_dt/feature_names.py | 125 -------------------------- spd/clustering/ci_dt/plot.py | 23 ++--- 4 files changed, 16 insertions(+), 145 deletions(-) delete mode 100644 spd/clustering/ci_dt/feature_names.py diff --git a/spd/clustering/ci_dt/attn.py b/spd/clustering/ci_dt/attn.py index 2ff5a7533..31b3a38c2 100644 --- a/spd/clustering/ci_dt/attn.py +++ b/spd/clustering/ci_dt/attn.py @@ -1,6 +1,8 @@ # %% """Attention pattern visualization for CI decision tree analysis.""" +from typing import Any + import matplotlib.pyplot as plt import numpy as np import torch @@ -78,7 +80,7 @@ def extract_attention_patterns_multibatch( model: ComponentModel, device: torch.device | str, - dataloader: DataLoader, + dataloader: DataLoader[Any], n_batches: int, ) -> dict[str, Float[Tensor, "total_samples n_heads seq_len seq_len"]]: """Extract attention patterns over multiple batches. 
diff --git a/spd/clustering/ci_dt/ci_dt_old.py b/spd/clustering/ci_dt/ci_dt_old.py index 3ae33aa07..40af0ddf9 100644 --- a/spd/clustering/ci_dt/ci_dt_old.py +++ b/spd/clustering/ci_dt/ci_dt_old.py @@ -115,11 +115,12 @@ def predict_k( lm: LayerModel = next(m for m in models if m.layer_index == k) X: np.ndarray = concat_cols(prefix_layers) proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - if isinstance(proba, list): - P: np.ndarray = np.stack([p[:, 1] for p in proba], axis=1) + P_: np.ndarray + if isinstance(proba, list): # noqa: SIM108 + P_ = np.stack([p[:, 1] for p in proba], axis=1) else: - P = proba[..., 1] # type: ignore - Y_hat: np.ndarray = (float(threshold) <= P).astype(bool) + P_ = proba[..., 1] # type: ignore + Y_hat: np.ndarray = (float(threshold) <= P_).astype(bool) return Y_hat diff --git a/spd/clustering/ci_dt/feature_names.py b/spd/clustering/ci_dt/feature_names.py deleted file mode 100644 index fc43b33b8..000000000 --- a/spd/clustering/ci_dt/feature_names.py +++ /dev/null @@ -1,125 +0,0 @@ -"""Generate feature names for decision tree visualization with activation and decoding info.""" - -import torch -from jaxtyping import Float -from torch import Tensor - -from spd.models.component_model import ComponentModel -from spd.models.components import EmbeddingComponents, LinearComponents - - -def get_embed_unembed_matrices( - model: ComponentModel, -) -> tuple[Float[Tensor, "vocab d_model"], Float[Tensor, "d_model vocab"]]: - """Extract embedding and unembedding matrices from the target model. 
- - For GPT-2 style models, returns (wte.weight, lm_head.weight or wte.weight.T) - For LLaMA style models, returns (embed_tokens.weight, lm_head.weight) - - Returns: - embed: Embedding matrix [vocab_size, d_model] - unembed: Unembedding matrix [d_model, vocab_size] - """ - target_model = model.target_model - - # Try to find embedding layer (GPT-2 style) - if hasattr(target_model, "transformer") and hasattr(target_model.transformer, "wte"): - embed = target_model.transformer.wte.weight # [vocab, d_model] - # Try LLaMA style - elif hasattr(target_model, "model") and hasattr(target_model.model, "embed_tokens"): - embed = target_model.model.embed_tokens.weight # [vocab, d_model] - else: - raise ValueError( - "Could not find embedding layer. Expected transformer.wte or model.embed_tokens" - ) - - # Try to find unembedding layer - if hasattr(target_model, "lm_head"): - unembed = target_model.lm_head.weight.T # [d_model, vocab] - elif hasattr(target_model, "transformer") and hasattr(target_model.transformer, "wte"): - # For tied embeddings, unembed is transpose of embed - unembed = target_model.transformer.wte.weight.T # [d_model, vocab] - else: - raise ValueError("Could not find unembedding layer (lm_head)") - - return embed, unembed - - -def decode_direction_top_k( - direction: Float[Tensor, " d_model"], - embed: Float[Tensor, "vocab d_model"], - unembed: Float[Tensor, "d_model vocab"], - tokenizer, - k: int = 3, - use_embed: bool = True, -) -> str: - """Decode a direction vector to top-k tokens. 
- - Args: - direction: Direction vector in d_model space - embed: Embedding matrix [vocab, d_model] - unembed: Unembedding matrix [d_model, vocab] - tokenizer: Tokenizer for converting token IDs to strings - k: Number of top tokens to return - use_embed: If True, use embed matrix; if False, use unembed matrix - - Returns: - String representation of top-k tokens - """ - if use_embed: - # Project direction onto embedding space: compute cosine similarity - # direction: [d_model], embed: [vocab, d_model] - direction_norm = direction / (direction.norm() + 1e-8) - embed_norm = embed / (embed.norm(dim=1, keepdim=True) + 1e-8) - similarities = torch.matmul(embed_norm, direction_norm) # [vocab] - else: - # Project direction onto unembedding space - # direction: [d_model], unembed: [d_model, vocab] - logits = torch.matmul(direction, unembed) # [vocab] - similarities = logits - - # Get top-k tokens - top_k_values, top_k_indices = torch.topk(similarities, k) - - # Decode tokens - tokens = [] - for idx, val in zip(top_k_indices.tolist(), top_k_values.tolist(), strict=False): - token_str = tokenizer.decode([idx]) - # Clean up token string for display - token_str = repr(token_str)[1:-1] # Remove quotes and escape special chars - tokens.append(f"{token_str}({val:.2f})") - - return ",".join(tokens) - - -def get_component_directions( - component_model: ComponentModel, - module_key: str, - component_idx: int, -) -> tuple[Float[Tensor, " d_in"], Float[Tensor, " d_out"]]: - """Get read (V) and write (U) direction vectors for a component. 
- - Args: - component_model: The ComponentModel containing components - module_key: Key identifying the module (e.g., "transformer.h.0.attn.c_attn") - component_idx: Index of the component - - Returns: - read_direction: V[:, component_idx] - the read direction [d_in] - write_direction: U[component_idx, :] - the write direction [d_out] - """ - # Get the component module - component = component_model.components[module_key] - - assert isinstance(component, LinearComponents | EmbeddingComponents), ( - f"Expected LinearComponents or EmbeddingComponents, got {type(component)}" - ) - - # Extract V and U - V = component.V # [d_in, C] or [vocab, C] for embedding - U = component.U # [C, d_out] - - read_direction = V[:, component_idx] # [d_in] - write_direction = U[component_idx, :] # [d_out] - - return read_direction, write_direction diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 5317f0ddf..fe9fa5f48 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -329,6 +329,7 @@ def plot_activations( # Add diff plot if sorted if sample_order is not None: A_diff: Float[np.ndarray, "n_samples n_components"] = A_pred - A_true + assert isinstance(ax3, plt.Axes) im3 = ax3.imshow( A_diff.T, aspect="auto", interpolation="nearest", cmap="RdBu_r", vmin=-1, vmax=1 ) @@ -602,7 +603,6 @@ def plot_selected_trees( picks: list[tuple[int, int, float]], title_prefix: str, models: list[LayerModel], - feature_names: list[list[str]] | None = None, ) -> None: """Plot a list of selected trees by (layer, target_idx, score). @@ -610,8 +610,6 @@ def plot_selected_trees( picks: List of (layer_idx, target_idx, score) tuples identifying trees to plot title_prefix: Prefix for plot titles (e.g. "Best" or "Worst") models: Trained LayerModel objects - feature_names: Optional list of feature name lists, one per layer. - feature_names[k] contains names for all features used to predict layer k. 
""" for layer_idx, target_idx, score in picks: est = get_estimator_for(models, layer_idx, target_idx) @@ -619,12 +617,7 @@ def plot_selected_trees( ax = fig.add_subplot(1, 1, 1) ax.set_title(f"{title_prefix}: layer {layer_idx}, target {target_idx}, AP={score:.3f}") - # Get feature names for this layer if available - feat_names = None - if feature_names is not None and 0 <= layer_idx < len(feature_names): - feat_names = feature_names[layer_idx] - - plot_tree(est, ax=ax, filled=False, feature_names=feat_names) + plot_tree(est, ax=ax, filled=False) fig.tight_layout() @@ -661,21 +654,21 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st stats = extract_tree_stats(models, per_layer_stats) # Distribution of tree depths - fig1, ax1 = plt.subplots() + _fig1, ax1 = plt.subplots() ax1.hist(stats["depth"], bins=range(int(stats["depth"].max()) + 2)) ax1.set_yscale("log") ax1.set_xlabel("Tree depth") ax1.set_ylabel("Count (log scale)") # Distribution of leaf counts - fig2, ax2 = plt.subplots() + _fig2, ax2 = plt.subplots() ax2.hist(stats["n_leaves"], bins=50) ax2.set_yscale("log") ax2.set_xlabel("Number of leaves") ax2.set_ylabel("Count (log scale)") # Distribution of accuracies - fig3, ax3 = plt.subplots() + _fig3, ax3 = plt.subplots() ax3.hist(stats["accuracy"][~np.isnan(stats["accuracy"])], bins=30) ax3.set_yscale("log") ax3.set_xlabel("Accuracy") @@ -692,7 +685,7 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st stats["depth"][valid_mask], stats["accuracy"][valid_mask], bins=[depth_bins, acc_bins] ) - fig4, ax4 = plt.subplots() + _fig4, ax4 = plt.subplots() heatmap_log: Float[np.ndarray, "depth_bins acc_bins"] = np.log10( heatmap_depth_acc.T + 1 ) # +1 to avoid log(0) @@ -719,7 +712,7 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st stats["n_leaves"][valid_mask], stats["accuracy"][valid_mask], bins=[leaf_bins, acc_bins] ) - fig5, ax5 = plt.subplots() + _fig5, ax5 = 
plt.subplots() heatmap_log = np.log10(heatmap_leaf_acc.T + 1) im = ax5.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") ax5.set_xticks(range(len(leaf_bins) - 1)) @@ -741,7 +734,7 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st stats["depth"][valid_mask], stats["n_leaves"][valid_mask], bins=[depth_bins, leaf_bins] ) - fig6, ax6 = plt.subplots() + _fig6, ax6 = plt.subplots() heatmap_log = np.log10(heatmap_depth_leaf.T + 1) im = ax6.imshow(heatmap_log, origin="lower", aspect="auto", cmap="Blues") ax6.set_xticks(range(len(depth_bins) - 1)) From 4b52539f5d668dc09dbbff8530c3ce2d9f475334 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 11:38:22 +0000 Subject: [PATCH 66/77] delete old script --- spd/clustering/ci_dt/ci_dt_old.py | 423 ------------------------------ 1 file changed, 423 deletions(-) delete mode 100644 spd/clustering/ci_dt/ci_dt_old.py diff --git a/spd/clustering/ci_dt/ci_dt_old.py b/spd/clustering/ci_dt/ci_dt_old.py deleted file mode 100644 index 40af0ddf9..000000000 --- a/spd/clustering/ci_dt/ci_dt_old.py +++ /dev/null @@ -1,423 +0,0 @@ -# %% - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Literal - -import matplotlib.pyplot as plt -import numpy as np -import torch -from jaxtyping import Bool, Float -from sklearn.base import ClassifierMixin -from sklearn.metrics import ( - accuracy_score, - average_precision_score, - balanced_accuracy_score, -) -from sklearn.multioutput import MultiOutputClassifier -from sklearn.tree import DecisionTreeClassifier, plot_tree -from torch import Tensor - -from spd.clustering.activations import ( - ProcessedActivations, - component_activations, - process_activations, -) -from spd.configs import Config -from spd.data import DatasetConfig, create_data_loader -from spd.experiments.lm.configs import LMTaskConfig -from spd.models.component_model import ComponentModel, SPDRunInfo - -# ----------------------- 
config ----------------------- - - -@dataclass -class CIDTConfig: - """Configuration for causal importance decision tree training.""" - - experiment_key: str = "ss_emb" # Key from EXPERIMENT_REGISTRY - n_samples: int = 250 - activation_threshold: float = 0.01 # Threshold for boolean conversion - filter_dead_threshold: float = 0.001 # Threshold for filtering dead components - max_depth: int = 8 # Maximum depth for decision trees - random_state: int = 7 # Random state for reproducibility - - -# ----------------------- library code ----------------------- - - -@dataclass -class LayerModel: - """Holds a trained per-layer model.""" - - layer_index: int - model: ClassifierMixin - feature_dim: int - target_dim: int - - -def concat_cols( - Xs: Sequence[Bool[np.ndarray, "n_samples n_features"]], -) -> Bool[np.ndarray, "n_samples n_concat"]: - """Column-concat a sequence or return empty (n,0).""" - n_samples: int = Xs[0].shape[0] if len(Xs) else 0 - return np.concatenate(Xs, axis=1) if len(Xs) else np.zeros((n_samples, 0), bool) - - -def build_xy( - layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], -) -> list[ - tuple[ - Bool[np.ndarray, "n_samples n_features"], - Bool[np.ndarray, "n_samples n_targets"], - ] -]: - """Return (X_k,Y_k) for k=1..L-1 with X_k=concat(layers[:k]).""" - XYs: list[tuple[np.ndarray, np.ndarray]] = [] - for k in range(1, len(layers)): - X_k: np.ndarray = concat_cols(layers[:k]) - Y_k: np.ndarray = layers[k] - XYs.append((X_k, Y_k)) - return XYs - - -def train_trees( - layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], - *, - strategy: Literal["one_vs_all", "single_tree"] = "one_vs_all", - max_depth: int | None = None, - min_samples_leaf: int = 1, - random_state: int | None = 0, -) -> list[LayerModel]: - """Train one model per target layer using previous layers as features.""" - XYs = build_xy(layers) - models: list[LayerModel] = [] - for k, (X_k, Y_k) in enumerate(XYs, start=1): - base = DecisionTreeClassifier( - 
max_depth=max_depth, - min_samples_leaf=min_samples_leaf, - random_state=random_state, - ) - model: ClassifierMixin = MultiOutputClassifier(base) if strategy == "one_vs_all" else base - _ = model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) - models.append(LayerModel(k, model, int(X_k.shape[1]), int(Y_k.shape[1]))) - return models - - -def predict_k( - models: Sequence[LayerModel], - prefix_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], - k: int, - *, - threshold: float = 0.5, -) -> Bool[np.ndarray, "n_samples n_components_k"]: - """Predict layer k activations from layers[:k].""" - lm: LayerModel = next(m for m in models if m.layer_index == k) - X: np.ndarray = concat_cols(prefix_layers) - proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - P_: np.ndarray - if isinstance(proba, list): # noqa: SIM108 - P_ = np.stack([p[:, 1] for p in proba], axis=1) - else: - P_ = proba[..., 1] # type: ignore - Y_hat: np.ndarray = (float(threshold) <= P_).astype(bool) - return Y_hat - - -def predict_all( - models: Sequence[LayerModel], - seed_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], - *, - thresholds: Sequence[float] | None = None, -) -> list[Bool[np.ndarray, "n_samples n_components"]]: - """Sequentially predict layers 1.. 
using layer 0 as seed.""" - out: list[np.ndarray] = [seed_layers[0].copy()] - ths: list[float] = list(thresholds) if thresholds is not None else [] - for i, lm in enumerate(sorted(models, key=lambda m: m.layer_index)): - thr: float = ths[i] if i < len(ths) else 0.5 - out.append(predict_k(models, out, lm.layer_index, threshold=thr)) - return out - - -# ----------------------- configuration ----------------------- - -config = CIDTConfig() -device: str = "cuda" if torch.cuda.is_available() else "cpu" - -# ----------------------- load model ----------------------- - -wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" - -spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) -model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) -model.to(device) -cfg: Config = spd_run.config - -print(f"Loaded model from {wandb_run_path}") - -# ----------------------- load dataset ----------------------- - -# Create LM dataset and dataloader -assert isinstance(cfg.task_config, LMTaskConfig) -pretrained_model_name = cfg.pretrained_model_name -assert pretrained_model_name is not None - -dataset_config = DatasetConfig( - name=cfg.task_config.dataset_name, - hf_tokenizer_path=pretrained_model_name, - split=cfg.task_config.train_data_split, - n_ctx=cfg.task_config.max_seq_len, - column_name=cfg.task_config.column_name, - is_tokenized=False, - streaming=False, - seed=0, -) -dataloader, _ = create_data_loader( - dataset_config=dataset_config, - batch_size=config.n_samples, - buffer_size=cfg.task_config.buffer_size, - global_seed=cfg.seed, - ddp_rank=0, - ddp_world_size=1, -) -batch_data = next(iter(dataloader)) -batch: Tensor = batch_data["input_ids"] -print(f"Created LM dataset with {cfg.task_config.dataset_name}, batch shape: {batch.shape}") - -# ----------------------- get activations ----------------------- - -# Get component activations (on device) -print("Computing component activations...") -component_acts: dict[str, Tensor] = component_activations( - 
model=model, - device=device, - batch=batch, -) - -# Process activations (filter dead components, concatenate) -print("Processing activations...") -processed_acts: ProcessedActivations = process_activations( - component_acts, - filter_dead_threshold=config.filter_dead_threshold, - seq_mode="seq_mean", # LM task needs seq_mean -) - -print(f"Total components (before filtering): {processed_acts.n_components_original}") -print(f"Alive components: {processed_acts.n_components_alive}") -print(f"Dead components: {processed_acts.n_components_dead}") -print(f"Module keys: {processed_acts.module_keys}") - -# ----------------------- convert to layers ----------------------- - -# Move to CPU and convert to numpy for sklearn -# Group by module to create "layers" for decision trees -print("\nConverting to boolean layers...") -layers_true: list[np.ndarray] = [] -for module_key in processed_acts.module_keys: - # Get the activations for this module from activations_raw, move to CPU - module_acts_cpu = processed_acts.activations_raw[module_key].cpu().numpy() - module_acts_bool = (module_acts_cpu >= config.activation_threshold).astype(bool) - layers_true.append(module_acts_bool) - print(f"Layer {len(layers_true) - 1} ({module_key}): {module_acts_bool.shape[1]} components") - -print(f"\nCreated {len(layers_true)} layers for decision tree training") - -# ----------------------- fit and predict ----------------------- - -print("\nTraining decision trees...") -models: list[LayerModel] = train_trees( - layers_true, max_depth=config.max_depth, random_state=config.random_state -) -layers_pred: list[np.ndarray] = predict_all(models, [layers_true[0]]) - -# ----------------------- metrics ----------------------- - - -def layer_metrics( - Y_true: Bool[np.ndarray, "n t"], - Y_prob: Float[np.ndarray, "n t"], - Y_pred: Bool[np.ndarray, "n t"], -) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Return per-target AP, acc, bacc, prevalence.""" - T: int = Y_true.shape[1] - ap: 
np.ndarray = np.zeros(T) - acc: np.ndarray = np.zeros(T) - bacc: np.ndarray = np.zeros(T) - prev: np.ndarray = np.zeros(T) - for j in range(T): - y: np.ndarray = Y_true[:, j].astype(int) - p: np.ndarray = Y_prob[:, j] - yhat: np.ndarray = Y_pred[:, j].astype(int) - prev[j] = float(y.mean()) - try: - ap[j] = average_precision_score(y, p) - except Exception: - ap[j] = np.nan - try: - acc[j] = accuracy_score(y, yhat) - except Exception: - acc[j] = np.nan - try: - bacc[j] = balanced_accuracy_score(y, yhat) - except Exception: - bacc[j] = np.nan - return ap, acc, bacc, prev - - -# get probabilities for each layer -def proba_for_layer(lm: LayerModel, X: np.ndarray) -> np.ndarray: - """Return P(y=1) per target column.""" - pr = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore - if isinstance(pr, list): - return np.stack([p[:, 1] for p in pr], axis=1) - return pr[..., 1] # type: ignore - - -XYs_demo = build_xy(layers_true) -per_layer_stats: list[dict[str, Any]] = [] -all_triplets: list[tuple[int, int, float]] = [] # (layer, target_idx, AP) - -for lm, (Xk, Yk) in zip(models, XYs_demo, strict=True): - Pk: np.ndarray = proba_for_layer(lm, Xk) - Yhat_k: np.ndarray = Pk >= 0.5 - ap, acc, bacc, prev = layer_metrics(Yk, Pk, Yhat_k) - per_layer_stats.append( - { - "ap": ap, - "acc": acc, - "bacc": bacc, - "prev": prev, - "mean_ap": float(np.nanmean(ap)), - "mean_acc": float(np.nanmean(acc)), - "mean_bacc": float(np.nanmean(bacc)), - } - ) - for j, apj in enumerate(ap): - all_triplets.append((lm.layer_index, j, float(apj))) - -# identify best and worst trees across all outputs by AP -sorted_triplets = sorted(all_triplets, key=lambda t: (np.isnan(t[2]), t[2])) -worst_list = [t for t in sorted_triplets if not np.isnan(t[2])][:2] -best_list = [t for t in sorted_triplets if not np.isnan(t[2])][-2:] - - -# pull corresponding estimators (MultiOutputClassifier -> estimators_ list) -def get_estimator_for( - models: list[LayerModel], layer_idx: int, target_idx: int -) -> 
DecisionTreeClassifier: - """Fetch the per-output estimator for a given layer and column.""" - lm = next(m for m in models if m.layer_index == layer_idx) - if isinstance(lm.model, MultiOutputClassifier): - return lm.model.estimators_[target_idx] # type: ignore - return lm.model # type: ignore - - -# ----------------------- plotting ----------------------- - - -# 1) Single fig showing activations across all layers (true vs predicted stacked) -def plot_activations(layers_true: list[np.ndarray], layers_pred: list[np.ndarray]) -> None: - """Show true and predicted activations as heatmaps.""" - A_true: np.ndarray = np.concatenate(layers_true, axis=1) - A_pred: np.ndarray = np.concatenate([layers_pred[0]] + layers_pred[1:], axis=1) - fig1 = plt.figure(figsize=(10, 6)) - ax1 = fig1.add_subplot(2, 1, 1) - ax1.set_title("Activations (True)") - ax1.imshow(A_true, aspect="auto", interpolation="nearest") - ax1.set_xlabel("components (all layers concatenated)") - ax1.set_ylabel("samples") - ax2 = fig1.add_subplot(2, 1, 2) - ax2.set_title("Activations (Predicted)") - ax2.imshow(A_pred, aspect="auto", interpolation="nearest") - ax2.set_xlabel("components (all layers concatenated)") - ax2.set_ylabel("samples") - fig1.tight_layout() - - -# 2) Covariance matrix of all components -def plot_covariance(layers_true: list[np.ndarray]) -> None: - """Plot covariance between all components across layers.""" - A: np.ndarray = np.concatenate(layers_true, axis=1).astype(float) - C: np.ndarray = np.cov(A, rowvar=False) - fig2 = plt.figure(figsize=(6, 6)) - ax = fig2.add_subplot(1, 1, 1) - ax.set_title("Covariance of components (all layers)") - _im = ax.imshow(C, aspect="auto", interpolation="nearest") - ax.set_xlabel("component index") - ax.set_ylabel("component index") - fig2.tight_layout() - - -# 3) Accuracy ideas: bar of mean metrics per layer; scatter of prevalence vs AP -def plot_layer_metrics(per_layer_stats: list[dict[str, Any]]) -> None: - """Plot summary metrics per layer and 
per-target AP vs prevalence.""" - L: int = len(per_layer_stats) - mean_ap: np.ndarray = np.array([d["mean_ap"] for d in per_layer_stats]) - mean_acc: np.ndarray = np.array([d["mean_acc"] for d in per_layer_stats]) - mean_bacc: np.ndarray = np.array([d["mean_bacc"] for d in per_layer_stats]) - - # bar: mean AP, ACC, BACC per layer (three separate figures to respect one-plot rule) - fig3 = plt.figure(figsize=(8, 3)) - ax3 = fig3.add_subplot(1, 1, 1) - ax3.set_title("Mean Average Precision per layer") - ax3.bar(np.arange(1, L + 1), mean_ap) - ax3.set_xlabel("layer index (target)") - ax3.set_ylabel("mean AP") - fig3.tight_layout() - - fig4 = plt.figure(figsize=(8, 3)) - ax4 = fig4.add_subplot(1, 1, 1) - ax4.set_title("Mean Accuracy per layer") - ax4.bar(np.arange(1, L + 1), mean_acc) - ax4.set_xlabel("layer index (target)") - ax4.set_ylabel("mean accuracy") - fig4.tight_layout() - - fig5 = plt.figure(figsize=(8, 3)) - ax5 = fig5.add_subplot(1, 1, 1) - ax5.set_title("Mean Balanced Accuracy per layer") - ax5.bar(np.arange(1, L + 1), mean_bacc) - ax5.set_xlabel("layer index (target)") - ax5.set_ylabel("mean balanced accuracy") - fig5.tight_layout() - - # scatter: prevalence vs AP for all targets across layers - fig6 = plt.figure(figsize=(6, 5)) - ax6 = fig6.add_subplot(1, 1, 1) - ax6.set_title("Per-target AP vs prevalence") - x_list: list[float] = [] - y_list: list[float] = [] - for d in per_layer_stats: - x_list.extend(list(d["prev"])) - y_list.extend(list(d["ap"])) - ax6.scatter(x_list, y_list, alpha=0.6) - ax6.set_xlabel("prevalence") - ax6.set_ylabel("average precision") - fig6.tight_layout() - - -# 4) Display a couple decision trees (worst and best by AP) -def plot_selected_trees( - picks: list[tuple[int, int, float]], - title_prefix: str, - models: list[LayerModel], - feature_dims_prefix: list[int], -) -> None: - """Plot a list of selected trees by (layer, target_idx, score).""" - for layer_idx, target_idx, score in picks: - est = get_estimator_for(models, 
layer_idx, target_idx) - fig = plt.figure(figsize=(10, 6)) - ax = fig.add_subplot(1, 1, 1) - ax.set_title(f"{title_prefix}: layer {layer_idx}, target {target_idx}, AP={score:.3f}") - plot_tree(est, ax=ax, filled=False) # default styling - fig.tight_layout() - - -# Run the plots -plot_activations(layers_true, layers_pred) -plot_covariance(layers_true) -plot_layer_metrics(per_layer_stats) -plot_selected_trees(worst_list, "Worst", models, []) -plot_selected_trees(best_list, "Best", models, []) - -print("Plots generated.") From bee64e3dcbfb0bca1fd61518a6be67c6fe2d9fed Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 11:45:52 +0000 Subject: [PATCH 67/77] more type fixes --- spd/clustering/ci_dt/pipeline.py | 4 +-- spd/clustering/ci_dt/plot.py | 54 ++++++++++++++++++-------------- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 7747e41c8..8126764f1 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -18,7 +18,7 @@ def compute_activations_multibatch( model: ComponentModel, device: torch.device | str, - dataloader: DataLoader, + dataloader: DataLoader[dict[str, Any]], n_batches: int, ) -> dict[str, Tensor]: """Compute activations over multiple batches, concatenate on CPU. 
@@ -100,12 +100,10 @@ def convert_to_boolean_layers( # Flatten if 3D (batch, seq_len, n_components) -> (batch*seq_len, n_components) if module_acts_tensor.ndim == 3: - print(f" {module_key}: original shape = {module_acts_tensor.shape}") # Keep last dimension (n_components) intact, flatten first two dimensions module_acts_np: Float[np.ndarray, "n_samples n_components"] = ( module_acts_tensor.reshape(-1, module_acts_tensor.shape[-1]).numpy() ) - print(f" {module_key}: flattened shape = {module_acts_np.shape}") else: module_acts_np = module_acts_tensor.numpy() diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index fe9fa5f48..9e215546d 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -1,11 +1,12 @@ """Plotting functions for causal importance decision trees.""" +from collections.abc import Mapping from typing import Any import matplotlib.pyplot as plt import numpy as np import torch -from jaxtyping import Float, Int +from jaxtyping import Bool, Float, Int from sklearn.tree import plot_tree from spd.clustering.ci_dt.core import LayerModel, MetricKey, get_estimator_for @@ -163,14 +164,13 @@ def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: Array of indices in sorted order """ # Transpose if sorting columns - if axis == 1: - A = A.T + activations_normalized: Float[np.ndarray, "n d"] = A.T if axis == 1 else A # Compute cosine similarity - norms: Float[np.ndarray, "n 1"] = np.linalg.norm(A, axis=1, keepdims=True) + norms: Float[np.ndarray, "n 1"] = np.linalg.norm(activations_normalized, axis=1, keepdims=True) norms = np.where(norms > 1e-8, norms, 1.0) # Avoid division by zero - A_normalized: Float[np.ndarray, "n d"] = A / norms - similarity: Float[np.ndarray, "n n"] = A_normalized @ A_normalized.T + activations_normalized = activations_normalized / norms + similarity: Float[np.ndarray, "n n"] = activations_normalized @ activations_normalized.T # Start from most central item (highest average similarity) n: int = 
similarity.shape[0] @@ -298,6 +298,7 @@ def plot_activations( # Determine number of subplots n_plots: int = 3 if sample_order is not None else 2 fig, axes = plt.subplots(n_plots, 1, figsize=(12, 6 * n_plots)) + ax3: plt.Axes | None = None if n_plots == 2: ax1, ax2 = axes else: @@ -361,18 +362,20 @@ def plot_covariance( ) # Apply component ordering if provided + activations_concat: Float[np.ndarray, "n_samples n_components"] if component_order is not None: - A = A[:, component_order] + activations_concat = A[:, component_order] sorted_label: str = " (Sorted by Component Similarity)" xlabel: str = "Component index (sorted)" ylabel: str = "Component index (sorted)" else: + activations_concat = A sorted_label = "" xlabel = "Component index" ylabel = "Component index" # Compute covariance - C: Float[np.ndarray, "n_components n_components"] = np.cov(A, rowvar=False) + C: Float[np.ndarray, "n_components n_components"] = np.cov(activations_concat, rowvar=False) # Center colormap on 0 vmax: float = float(np.abs(C).max()) @@ -501,7 +504,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La def plot_component_activity_breakdown( - component_acts: dict[str, np.ndarray | torch.Tensor], + component_acts: Mapping[str, np.ndarray | torch.Tensor], module_keys: list[str], activation_threshold: float, logy: bool = False, @@ -521,23 +524,26 @@ def plot_component_activity_breakdown( n_always_alive_list: list[int] = [] for module_key in module_keys: - acts: np.ndarray = component_acts[module_key] + acts_raw: np.ndarray | torch.Tensor = component_acts[module_key] # Convert to numpy if needed - if hasattr(acts, "cpu"): - acts = acts.cpu().numpy() + acts_np: Float[np.ndarray, "n_samples n_components"] = ( + acts_raw.cpu().numpy() if isinstance(acts_raw, torch.Tensor) else acts_raw + ) # Flatten if 3D (batch, seq_len, n_components) -> (batch*seq_len, n_components) # This treats each token position as a separate sample, consistent with decision tree training - 
if acts.ndim == 3: - acts = acts.reshape(-1, acts.shape[-1]) + if acts_np.ndim == 3: + acts_np = acts_np.reshape(-1, acts_np.shape[-1]) # Convert to boolean - acts_bool: np.ndarray = (acts >= activation_threshold).astype(bool) + acts_bool: Bool[np.ndarray, "n_samples n_components"] = ( + acts_np >= activation_threshold + ).astype(bool) # Count each category - always_dead: np.ndarray = ~acts_bool.any(axis=0) - always_alive: np.ndarray = acts_bool.all(axis=0) - varying: np.ndarray = ~(always_dead | always_alive) + always_dead: Bool[np.ndarray, " n_components"] = ~acts_bool.any(axis=0) + always_alive: Bool[np.ndarray, " n_components"] = acts_bool.all(axis=0) + varying: Bool[np.ndarray, " n_components"] = ~(always_dead | always_alive) n_always_dead_list.append(int(always_dead.sum())) n_always_alive_list.append(int(always_alive.sum())) @@ -698,9 +704,9 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st ax4.set_ylabel("Accuracy") for i in range(len(depth_bins) - 1): for j in range(len(acc_bins) - 1): - count: int = int(heatmap_depth_acc[i, j]) - if count > 0: - ax4.text(i, j, str(count), ha="center", va="center") + tree_count: int = int(heatmap_depth_acc[i, j]) + if tree_count > 0: + ax4.text(i, j, str(tree_count), ha="center", va="center") plt.colorbar(im, ax=ax4, label="log10(count+1)") # Heatmap: leaf count vs accuracy @@ -723,9 +729,9 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st ax5.set_ylabel("Accuracy") for i in range(len(leaf_bins) - 1): for j in range(len(acc_bins) - 1): - count: int = int(heatmap_leaf_acc[i, j]) - if count > 0: - ax5.text(i, j, str(count), ha="center", va="center") + tree_count: int = int(heatmap_leaf_acc[i, j]) + if tree_count > 0: + ax5.text(i, j, str(tree_count), ha="center", va="center") plt.colorbar(im, ax=ax5, label="log10(count+1)") # Heatmap: depth vs leaf count From fd27dbc58163c50c4d7531f923274dd80af5710d Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 
27 Oct 2025 11:49:20 +0000 Subject: [PATCH 68/77] wip --- spd/clustering/ci_dt/plot.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 9e215546d..690d55f43 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -729,9 +729,9 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st ax5.set_ylabel("Accuracy") for i in range(len(leaf_bins) - 1): for j in range(len(acc_bins) - 1): - tree_count: int = int(heatmap_leaf_acc[i, j]) - if tree_count > 0: - ax5.text(i, j, str(tree_count), ha="center", va="center") + leaf_tree_count: int = int(heatmap_leaf_acc[i, j]) + if leaf_tree_count > 0: + ax5.text(i, j, str(leaf_tree_count), ha="center", va="center") plt.colorbar(im, ax=ax5, label="log10(count+1)") # Heatmap: depth vs leaf count @@ -751,9 +751,9 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st ax6.set_ylabel("Number of leaves") for i in range(len(depth_bins) - 1): for j in range(len(leaf_bins) - 1): - count: int = int(heatmap_depth_leaf[i, j]) - if count > 0: - ax6.text(i, j, str(count), ha="center", va="center") + cell_count: int = int(heatmap_depth_leaf[i, j]) + if cell_count > 0: + ax6.text(i, j, str(cell_count), ha="center", va="center") plt.colorbar(im, ax=ax6, label="log10(count+1)") # Heatmap: AP vs prevalence @@ -807,9 +807,9 @@ def plot_tree_statistics(models: list[LayerModel], per_layer_stats: list[dict[st # Add counts to cells for i in range(len(prev_bins) - 1): for j in range(len(ap_bins_heatmap) - 1): - count = int(heatmap_prev_ap[i, j]) - if count > 0: - ax7.text(i, j, str(count), ha="center", va="center", fontsize=8) + heatmap_count = int(heatmap_prev_ap[i, j]) + if heatmap_count > 0: + ax7.text(i, j, str(heatmap_count), ha="center", va="center", fontsize=8) plt.colorbar(im, ax=ax7, label="log10(count+1)") fig7.tight_layout() From 
010fbaad591e95b34c6e28ef7506ba1a26db7c3b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 12:04:37 +0000 Subject: [PATCH 69/77] wip type stuff --- spd/clustering/ci_dt/core.py | 5 ++++- spd/clustering/ci_dt/plot.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index 3be97055b..e9b3a65b0 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -85,6 +85,7 @@ def extract_prob_class_1( result: list[np.ndarray] = [] for i, p in enumerate(proba_list): estimator = model.estimators_[i] + assert isinstance(estimator, DecisionTreeClassifier) classes = estimator.classes_ assert len(classes) == 2, f"Expected 2 classes but got {len(classes)} for output {i}" # Extract P(y=1) from second column @@ -224,4 +225,6 @@ def get_estimator_for( ) -> DecisionTreeClassifier: """Fetch the per-output estimator for a given layer and column.""" lm = next(m for m in models if m.layer_index == layer_idx) - return lm.model.estimators_[target_idx] # type: ignore + estimator = lm.model.estimators_[target_idx] + assert isinstance(estimator, DecisionTreeClassifier) + return estimator diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 690d55f43..7e4e66e3e 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -7,7 +7,7 @@ import numpy as np import torch from jaxtyping import Bool, Float, Int -from sklearn.tree import plot_tree +from sklearn.tree import DecisionTreeClassifier, plot_tree from spd.clustering.ci_dt.core import LayerModel, MetricKey, get_estimator_for @@ -452,6 +452,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La ap_list.append(ap) # Get tree depth for this target estimator = model.model.estimators_[target_idx] + assert isinstance(estimator, DecisionTreeClassifier) depth_list.append(int(estimator.tree_.max_depth)) prevalence_arr: np.ndarray = np.array(prevalence_list) 
@@ -640,6 +641,7 @@ def extract_tree_stats( for lm, stats in zip(models, per_layer_stats, strict=True): for i, estimator in enumerate(lm.model.estimators_): + assert isinstance(estimator, DecisionTreeClassifier) depths.append(int(estimator.tree_.max_depth)) leaf_counts.append(int(estimator.tree_.n_leaves)) accuracies.append(float(stats["acc"][i])) From 2f56924d4eed92d6844cebf5e4538b1c25ffbb02 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:06:01 +0000 Subject: [PATCH 70/77] type fixes --- spd/clustering/ci_dt/core.py | 4 ++-- spd/clustering/ci_dt/pipeline.py | 8 ++++++-- spd/clustering/ci_dt/plot.py | 10 +++++++--- spd/clustering/ci_dt/run.py | 8 ++++---- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py index e9b3a65b0..ec9ba585b 100644 --- a/spd/clustering/ci_dt/core.py +++ b/spd/clustering/ci_dt/core.py @@ -84,7 +84,7 @@ def extract_prob_class_1( """ result: list[np.ndarray] = [] for i, p in enumerate(proba_list): - estimator = model.estimators_[i] + estimator = model.estimators_[i] # pyright: ignore[reportIndexIssue] assert isinstance(estimator, DecisionTreeClassifier) classes = estimator.classes_ assert len(classes) == 2, f"Expected 2 classes but got {len(classes)} for output {i}" @@ -225,6 +225,6 @@ def get_estimator_for( ) -> DecisionTreeClassifier: """Fetch the per-output estimator for a given layer and column.""" lm = next(m for m in models if m.layer_index == layer_idx) - estimator = lm.model.estimators_[target_idx] + estimator = lm.model.estimators_[target_idx] # pyright: ignore[reportIndexIssue] assert isinstance(estimator, DecisionTreeClassifier) return estimator diff --git a/spd/clustering/ci_dt/pipeline.py b/spd/clustering/ci_dt/pipeline.py index 8126764f1..ab12d6081 100644 --- a/spd/clustering/ci_dt/pipeline.py +++ b/spd/clustering/ci_dt/pipeline.py @@ -117,8 +117,12 @@ def convert_to_boolean_layers( varying_mask: Bool[np.ndarray, " n_components"] = 
component_variance > 0 # Count always-dead and always-alive components for diagnostics - always_dead_mask: Bool[np.ndarray, " n_components"] = ~module_acts_bool.any(axis=0) - always_alive_mask: Bool[np.ndarray, " n_components"] = module_acts_bool.all(axis=0) + # NOTE: any(axis=0) and all(axis=0) are typed as returning numpy.bool_ | NDArray + # because they could return a scalar for 0-d arrays. We know these are always 1-d + # arrays at runtime, so we use type: ignore. Can't use assert isinstance() because + # pyright doesn't narrow union types with ndarray checks. + always_dead_mask: Bool[np.ndarray, " n_components"] = ~module_acts_bool.any(axis=0) # pyright: ignore[reportAssignmentType] + always_alive_mask: Bool[np.ndarray, " n_components"] = module_acts_bool.all(axis=0) # pyright: ignore[reportAssignmentType] n_always_dead: int = int(always_dead_mask.sum()) n_always_alive: int = int(always_alive_mask.sum()) diff --git a/spd/clustering/ci_dt/plot.py b/spd/clustering/ci_dt/plot.py index 7e4e66e3e..e5808b42c 100644 --- a/spd/clustering/ci_dt/plot.py +++ b/spd/clustering/ci_dt/plot.py @@ -451,7 +451,7 @@ def plot_ap_vs_prevalence(per_layer_stats: list[dict[str, Any]], models: list[La prevalence_list.append(prev) ap_list.append(ap) # Get tree depth for this target - estimator = model.model.estimators_[target_idx] + estimator = model.model.estimators_[target_idx] # pyright: ignore[reportIndexIssue] assert isinstance(estimator, DecisionTreeClassifier) depth_list.append(int(estimator.tree_.max_depth)) @@ -542,8 +542,12 @@ def plot_component_activity_breakdown( ).astype(bool) # Count each category - always_dead: Bool[np.ndarray, " n_components"] = ~acts_bool.any(axis=0) - always_alive: Bool[np.ndarray, " n_components"] = acts_bool.all(axis=0) + # NOTE: any(axis=0) and all(axis=0) are typed as returning numpy.bool_ | NDArray + # because they could return a scalar for 0-d arrays. We know these are always 1-d + # arrays at runtime, so we use type: ignore. 
Can't use assert isinstance() because + # pyright doesn't narrow union types with ndarray checks. + always_dead: Bool[np.ndarray, " n_components"] = ~acts_bool.any(axis=0) # pyright: ignore[reportAssignmentType] + always_alive: Bool[np.ndarray, " n_components"] = acts_bool.all(axis=0) # pyright: ignore[reportAssignmentType] varying: Bool[np.ndarray, " n_components"] = ~(always_dead | always_alive) n_always_dead_list.append(int(always_dead.sum())) diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 3e4de124a..526c66a8b 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -36,10 +36,10 @@ # ----------------------- configuration ----------------------- config = CIDTConfig( - # batch_size=50, # 50 ~~ 16GB VRAM max - # n_batches=8, - batch_size=2, - n_batches=2, + batch_size=32, # 50 ~~ 16GB VRAM max + n_batches=8, + # batch_size=2, + # n_batches=2, n_ctx=16, activation_threshold=0.01, max_depth=3, From 964f42e273597e7e7e19aa2555d3cf916b284626 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:06:49 +0000 Subject: [PATCH 71/77] print device --- spd/clustering/ci_dt/run.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index 526c66a8b..e129d1644 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -46,7 +46,7 @@ random_state=42, ) device: str = "cuda" if torch.cuda.is_available() else "cpu" - +print(f"Using {device=}") # %% # ----------------------- load model ----------------------- From b6620bacdbed3c0343656b39e9184f7da004228b Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 15:33:12 +0000 Subject: [PATCH 72/77] wip serialization of trees --- spd/clustering/ci_dt/config.py | 1 + spd/clustering/ci_dt/run.py | 38 ++++- spd/clustering/ci_dt/serialize.py | 243 ++++++++++++++++++++++++++++++ 3 files changed, 275 insertions(+), 7 deletions(-) create mode 100644 
spd/clustering/ci_dt/serialize.py diff --git a/spd/clustering/ci_dt/config.py b/spd/clustering/ci_dt/config.py index 5980d654b..de0f95cee 100644 --- a/spd/clustering/ci_dt/config.py +++ b/spd/clustering/ci_dt/config.py @@ -7,6 +7,7 @@ class CIDTConfig: """Configuration for causal importance decision tree training.""" + wandb_run_path: str # WandB run path for the SPD model batch_size: int = 10 # Number of samples per batch for GPU inference n_batches: int = 25 # Number of batches to process (total samples = batch_size * n_batches) n_ctx: int = 64 # Context length (sequence length) for tokenization diff --git a/spd/clustering/ci_dt/run.py b/spd/clustering/ci_dt/run.py index e129d1644..94a80d008 100644 --- a/spd/clustering/ci_dt/run.py +++ b/spd/clustering/ci_dt/run.py @@ -4,6 +4,7 @@ import numpy as np import torch from jaxtyping import Bool +from matplotlib import pyplot as plt from torch import Tensor from spd.clustering.ci_dt.config import CIDTConfig @@ -23,6 +24,7 @@ plot_selected_trees, plot_tree_statistics, ) +from spd.clustering.ci_dt.serialize import TreeCollection from spd.configs import Config from spd.data import DatasetConfig, create_data_loader from spd.experiments.lm.configs import LMTaskConfig @@ -35,11 +37,14 @@ # %% # ----------------------- configuration ----------------------- +wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" + config = CIDTConfig( - batch_size=32, # 50 ~~ 16GB VRAM max - n_batches=8, - # batch_size=2, - # n_batches=2, + wandb_run_path=wandb_run_path, + # batch_size=32, # 50 ~~ 16GB VRAM max + # n_batches=8, + batch_size=2, + n_batches=2, n_ctx=16, activation_threshold=0.01, max_depth=3, @@ -50,8 +55,6 @@ # %% # ----------------------- load model ----------------------- -wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" - spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) model.to(device) @@ -140,6 +143,23 @@ 
layers_true=layers_true, ) +# %% +# ----------------------- save trees ----------------------- + +n_samples: int = layers_true[0].shape[0] +tree_collection = TreeCollection.from_models( + models=models, + per_layer_stats=per_layer_stats, + config=config, + module_keys=module_keys, + device=device, + n_samples=n_samples, +) + +output_path = "trees.json" +tree_collection.save_json(output_path) +print(f"\nSaved {len(tree_collection.trees)} trees to {output_path}") + # %% # ----------------------- compute orderings ----------------------- # Generate sample ordering once for use in multiple plots @@ -209,5 +229,9 @@ # %% # ----------------------- plots: decision trees ----------------------- -plot_selected_trees(worst_list, "Worst", models) +print("Best") plot_selected_trees(best_list, "Best", models) +plt.show() +print("Worst") +plot_selected_trees(worst_list, "Worst", models) +plt.show() diff --git a/spd/clustering/ci_dt/serialize.py b/spd/clustering/ci_dt/serialize.py new file mode 100644 index 000000000..018310755 --- /dev/null +++ b/spd/clustering/ci_dt/serialize.py @@ -0,0 +1,243 @@ +"""Serialization for causal importance decision trees.""" + +import json +from dataclasses import asdict, dataclass +from datetime import datetime +from pathlib import Path +from typing import Any + +import numpy as np +from sklearn.tree import DecisionTreeClassifier + +from spd.clustering.ci_dt.config import CIDTConfig +from spd.clustering.ci_dt.core import LayerModel + + +@dataclass +class TreeNode: + """A single node in the decision tree (nested structure).""" + + is_leaf: bool + n_samples: int + value: list[float] # Prediction probabilities [P(class=0), P(class=1)] + + # Only for non-leaf nodes: + feature: int | None = None # Which component to check + left: "TreeNode | None" = None # Left child (feature is False/0) + right: "TreeNode | None" = None # Right child (feature is True/1) + + def serialize(self) -> dict[str, Any]: + """Recursively serialize to nested dict.""" + result: 
dict[str, Any] = { + "is_leaf": self.is_leaf, + "n_samples": self.n_samples, + "value": self.value, + } + + if not self.is_leaf: + assert self.feature is not None + assert self.left is not None + assert self.right is not None + result["feature"] = self.feature + result["left"] = self.left.serialize() + result["right"] = self.right.serialize() + + return result + + @classmethod + def from_sklearn(cls, tree: DecisionTreeClassifier, node_id: int = 0) -> "TreeNode": + """Recursively build nested structure from sklearn tree.""" + sklearn_tree = tree.tree_ + + # Extract node info + n_samples = int(sklearn_tree.n_node_samples[node_id]) + value = sklearn_tree.value[node_id][0].tolist() # Extract [n_samples, n_classes] + + # Check if leaf + left_child = int(sklearn_tree.children_left[node_id]) + right_child = int(sklearn_tree.children_right[node_id]) + is_leaf = left_child == right_child # Both -1 for leaves + + if is_leaf: + return cls(is_leaf=True, n_samples=n_samples, value=value) + + # Non-leaf: recursively build children + feature = int(sklearn_tree.feature[node_id]) + return cls( + is_leaf=False, + n_samples=n_samples, + value=value, + feature=feature, + left=cls.from_sklearn(tree, left_child), + right=cls.from_sklearn(tree, right_child), + ) + + +@dataclass +class SavedTree: + """A single decision tree with metadata.""" + + layer_index: int # Which layer this tree predicts + target_index: int # Which component within that layer + metrics: dict[str, float] # Performance metrics and tree stats + structure: TreeNode + + def serialize(self) -> dict[str, Any]: + """Convert to JSON-serializable dict.""" + return { + "layer_index": self.layer_index, + "target_index": self.target_index, + "metrics": self.metrics, + "structure": self.structure.serialize(), + } + + @classmethod + def from_sklearn( + cls, + layer_index: int, + target_index: int, + tree: DecisionTreeClassifier, + metrics_dict: dict[str, np.ndarray], + ) -> "SavedTree": + """Create from sklearn tree and 
metrics.""" + # Extract metrics for this specific target + metrics = { + "ap": float(metrics_dict["ap"][target_index]), + "acc": float(metrics_dict["acc"][target_index]), + "bacc": float(metrics_dict["bacc"][target_index]), + "f1": float(metrics_dict["f1"][target_index]), + "precision": float(metrics_dict["precision"][target_index]), + "recall": float(metrics_dict["tpr"][target_index]), # recall = TPR + "tpr": float(metrics_dict["tpr"][target_index]), + "tnr": float(metrics_dict["tnr"][target_index]), + "npv": float(metrics_dict["npv"][target_index]), + "prev": float(metrics_dict["prev"][target_index]), + "depth": int(tree.get_depth()), + "n_nodes": int(tree.tree_.node_count), + "n_leaves": int(tree.get_n_leaves()), + } + + return cls( + layer_index=layer_index, + target_index=target_index, + metrics=metrics, + structure=TreeNode.from_sklearn(tree), + ) + + +@dataclass +class TreeCollection: + """Collection of all trees with configuration and metadata.""" + + config: CIDTConfig + module_keys: list[str] # Layer names + training_info: dict[str, Any] # timestamp, device, n_samples, etc. 
+ trees: list[SavedTree] + + def serialize(self) -> dict[str, Any]: + """Convert to JSON-serializable dict.""" + return { + "config": asdict(self.config), + "module_keys": self.module_keys, + "training_info": self.training_info, + "trees": [tree.serialize() for tree in self.trees], + } + + def save_json(self, path: str | Path) -> None: + """Save to JSON file.""" + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w") as f: + json.dump(self.serialize(), f, indent=2) + + @classmethod + def from_models( + cls, + models: list[LayerModel], + per_layer_stats: list[dict[str, Any]], + config: CIDTConfig, + module_keys: list[str], + device: str, + n_samples: int, + ) -> "TreeCollection": + """Create from trained models and metrics.""" + trees: list[SavedTree] = [] + + for lm, metrics_dict in zip(models, per_layer_stats, strict=True): + for target_idx in range(lm.target_dim): + # Get the individual tree for this output + estimator = lm.model.estimators_[target_idx] # pyright: ignore[reportIndexIssue] + assert isinstance(estimator, DecisionTreeClassifier) + + trees.append( + SavedTree.from_sklearn( + layer_index=lm.layer_index, + target_index=target_idx, + tree=estimator, + metrics_dict=metrics_dict, + ) + ) + + training_info = { + "timestamp": datetime.now().isoformat(), + "device": device, + "n_samples": n_samples, + "n_layers": len(models), + "total_trees": len(trees), + } + + return cls( + config=config, + module_keys=module_keys, + training_info=training_info, + trees=trees, + ) + + @classmethod + def load_json(cls, path: str | Path) -> "TreeCollection": + """Load from JSON file.""" + path = Path(path) + with path.open() as f: + data = json.load(f) + + # Reconstruct config + config = CIDTConfig(**data["config"]) + + # Reconstruct trees + trees = [] + for tree_data in data["trees"]: + structure = cls._deserialize_tree_node(tree_data["structure"]) + trees.append( + SavedTree( + layer_index=tree_data["layer_index"], + 
target_index=tree_data["target_index"], + metrics=tree_data["metrics"], + structure=structure, + ) + ) + + return cls( + config=config, + module_keys=data["module_keys"], + training_info=data["training_info"], + trees=trees, + ) + + @staticmethod + def _deserialize_tree_node(data: dict[str, Any]) -> TreeNode: + """Recursively reconstruct TreeNode from dict.""" + if data["is_leaf"]: + return TreeNode( + is_leaf=True, + n_samples=data["n_samples"], + value=data["value"], + ) + + return TreeNode( + is_leaf=False, + n_samples=data["n_samples"], + value=data["value"], + feature=data["feature"], + left=TreeCollection._deserialize_tree_node(data["left"]), + right=TreeCollection._deserialize_tree_node(data["right"]), + ) From cd2df154aa7d1a674eabdf5c5f6bc752dea7c776 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 16:37:47 +0000 Subject: [PATCH 73/77] wip --- spd/clustering/ci_dt/matshow_sort.py | 82 +++++ spd/clustering/ci_dt/minimal_run.py | 217 ++++++++++++++ spd/clustering/ci_dt/viewer.html | 431 +++++++++++++++++++++++++++ 3 files changed, 730 insertions(+) create mode 100644 spd/clustering/ci_dt/matshow_sort.py create mode 100644 spd/clustering/ci_dt/minimal_run.py create mode 100644 spd/clustering/ci_dt/viewer.html diff --git a/spd/clustering/ci_dt/matshow_sort.py b/spd/clustering/ci_dt/matshow_sort.py new file mode 100644 index 000000000..ba27ca9a4 --- /dev/null +++ b/spd/clustering/ci_dt/matshow_sort.py @@ -0,0 +1,82 @@ +from typing import Literal +import numpy as np +from jaxtyping import Float + +MatrixSortMetric = Literal["cosine", "dot"] + + +def sort_by_similarity( + arr: Float[np.ndarray, "m n"], + axis: int = 0, + metric: MatrixSortMetric = "cosine", +) -> Float[np.ndarray, "m n"]: + """Sort a 2D array by similarity of rows or columns using a greedy heuristic, + with a secondary tie-breaker by absolute sum magnitude. 
+ + Starts from the first row (or column) and iteratively adds the most similar + remaining vector until all are ordered. Uses cosine or dot-product similarity. + When multiple candidates have equal similarity to the last selected vector, + the one with the largest absolute sum is chosen. + + # Parameters: + - `arr : Float[np.ndarray, "m n"]` + Input 2D array to sort. + - `axis : int` + 0 to sort rows, 1 to sort columns. + (defaults to 0) + - `metric : str` + 'cosine' or 'dot' similarity. + (defaults to 'cosine') + - `seed : int | None` + Random seed for reproducibility. + (defaults to None) + + # Returns: + - `Float[np.ndarray, "m n"]` + Array reordered by similarity, tie-broken by abs-sum. + + # Usage: + ```python + >>> a = np.random.rand(5, 4) + >>> sort_by_similarity(a, axis=0) + >>> sort_by_similarity(a, axis=1) + ``` + """ + if arr.ndim != 2: + raise ValueError(f"Input must be 2D, got shape {arr.shape}") + if axis not in (0, 1): + raise ValueError(f"axis must be 0 or 1, got {axis}") + if metric not in MatrixSortMetric.__args__: + raise ValueError(f"metric must be in {MatrixSortMetric.__args__ = }") + + data: Float[np.ndarray, "n d"] = arr if axis == 0 else arr.T + n: int = data.shape[0] + + # normalize for cosine similarity + if metric == "cosine": + norms: Float[np.ndarray, "n 1"] = np.linalg.norm(data, axis=1, keepdims=True) + norms[norms == 0] = 1 + data = data / norms + + # similarity matrix (dot product or cosine) + sim: Float[np.ndarray, "n n"] = data @ data.T + abs_sums: Float[np.ndarray, "n"] = np.sum(np.abs(data), axis=1) + + # greedy ordering + remaining: list[int] = list(range(n)) + order: list[int] = [remaining.pop(0)] # start from first row/col + + while remaining: + last: int = order[-1] + + # choose next by (similarity, abs_sum) + next_idx: int = max( + remaining, + key=lambda i: (sim[last, i], abs_sums[i]), + ) + + order.append(next_idx) + remaining.remove(next_idx) + + sorted_arr: Float[np.ndarray, "m n"] = arr[order, :] if axis == 0 else 
arr[:, order] + return sorted_arr \ No newline at end of file diff --git a/spd/clustering/ci_dt/minimal_run.py b/spd/clustering/ci_dt/minimal_run.py new file mode 100644 index 000000000..c6c02990f --- /dev/null +++ b/spd/clustering/ci_dt/minimal_run.py @@ -0,0 +1,217 @@ +#%% +"""Minimal single-script version of causal importance decision tree training.""" + +import json +from typing import Any + +from matplotlib import pyplot as plt +import numpy as np +import torch +from jaxtyping import Bool, Float +from sklearn.metrics import accuracy_score, average_precision_score, balanced_accuracy_score +from sklearn.multioutput import MultiOutputClassifier +from sklearn.tree import DecisionTreeClassifier +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.clustering.ci_dt.matshow_sort import sort_by_similarity +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.models.component_model import ComponentModel, OutputWithCache, SPDRunInfo + +# %% ----------------------- Configuration ----------------------- +WANDB_RUN_PATH = "wandb:goodfire/spd/runs/lxs77xye" +BATCH_SIZE = 8 +N_BATCHES = 4 +N_CTX = 16 +ACTIVATION_THRESHOLD = 0.01 +MAX_DEPTH = 3 +RANDOM_STATE = 42 +device = "cuda" if torch.cuda.is_available() else "cpu" + +# %% ----------------------- Load Model ----------------------- +print(f"Loading model from {WANDB_RUN_PATH}...") +spd_run: SPDRunInfo = SPDRunInfo.from_path(WANDB_RUN_PATH) +model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) +model.to(device) +cfg: Config = spd_run.config + +# %% ----------------------- Load Dataset ----------------------- +assert isinstance(cfg.task_config, LMTaskConfig) +assert cfg.pretrained_model_name is not None + +dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=cfg.pretrained_model_name, + 
split=cfg.task_config.train_data_split, + n_ctx=N_CTX, + column_name=cfg.task_config.column_name, + is_tokenized=False, + streaming=False, + seed=0, +) +dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=BATCH_SIZE, + buffer_size=cfg.task_config.buffer_size, + global_seed=cfg.seed, + ddp_rank=0, + ddp_world_size=1, +) + +# %% ----------------------- Compute Activations ----------------------- +print(f"\nComputing activations for {N_BATCHES} batches...") +all_acts: list[dict[str, Tensor]] = [] + +for _ in tqdm(range(N_BATCHES), desc="Batches"): + batch: Tensor = next(iter(dataloader))["input_ids"] + with torch.no_grad(): + output: OutputWithCache = model(batch.to(device), cache_type="input") + acts: dict[str, Tensor] = model.calc_causal_importances( + pre_weight_acts=output.cache, + sampling="continuous", + detach_inputs=False, + ).upper_leaky + all_acts.append({k: v.cpu() for k, v in acts.items()}) + +# Concatenate batches +module_keys = list(all_acts[0].keys()) +acts_concat: dict[str, Tensor] = { + k: torch.cat([b[k] for b in all_acts], dim=0) for k in module_keys +} + +# %% ----------------------- Convert to Boolean Layers ----------------------- +print("\nConverting to boolean and filtering constant components...") +layers: list[Bool[np.ndarray, "n_samples n_components"]] = [] + +for k in module_keys: + # Flatten if 3D (batch, seq, components) -> (batch*seq, components) + acts_tensor = acts_concat[k] + if acts_tensor.ndim == 3: + acts_np: Float[np.ndarray, "n_samples n_components"] = ( + acts_tensor.reshape(-1, acts_tensor.shape[-1]).numpy() + ) + else: + acts_np = acts_tensor.numpy() + + # Threshold to boolean + acts_bool: Bool[np.ndarray, "n_samples n_components"] = ( + acts_np >= ACTIVATION_THRESHOLD + ).astype(bool) + + # plt.title(f"{k}") + # sort by column similarity + acts_sorted = sort_by_similarity(sort_by_similarity(acts_bool.astype(float), axis=0), axis=1) + plt.matshow(acts_sorted[:,:600], aspect="auto") + plt.show() + 
+ # Filter constant components (always 0 or always 1) + varying_mask: Bool[np.ndarray, " n_components"] = acts_bool.var(axis=0) > 0 + acts_varying = acts_bool[:, varying_mask] + layers.append(acts_varying) + print(f" {k}: {acts_varying.shape[1]} varying components") + +# %% ----------------------- Train Decision Trees ----------------------- +print("\nTraining decision trees...") +# Build (X, Y) pairs: X_k = concat(layers[:k]), Y_k = layers[k] +models: list[tuple[int, MultiOutputClassifier]] = [] + +for k in tqdm(range(1, len(layers)), desc="Training"): + X = np.concatenate(layers[:k], axis=1) if k > 0 else np.zeros((layers[0].shape[0], 0), bool) + Y = layers[k] + + clf = MultiOutputClassifier( + DecisionTreeClassifier( + max_depth=MAX_DEPTH, + min_samples_leaf=1, + random_state=RANDOM_STATE, + ) + ) + clf.fit(X.astype(np.uint8), Y.astype(np.uint8)) + models.append((k, clf)) + +# %% ----------------------- Compute Metrics ----------------------- +print("\nComputing metrics...") + + +def extract_prob_class_1(proba_list: list[np.ndarray], clf: MultiOutputClassifier) -> np.ndarray: + """Extract P(y=1) for each output.""" + result: list[np.ndarray] = [] + for i, p in enumerate(proba_list): + estimator = clf.estimators_[i] # type: ignore + assert isinstance(estimator, DecisionTreeClassifier) + assert len(estimator.classes_) == 2 + result.append(p[:, 1]) # P(y=1) + return np.stack(result, axis=1) + + +def tree_to_dict(tree: DecisionTreeClassifier) -> dict[str, Any]: + """Convert sklearn DecisionTree to JSON-serializable dict.""" + tree_ = tree.tree_ + return { + "feature": tree_.feature.tolist(), + "threshold": tree_.threshold.tolist(), + "children_left": tree_.children_left.tolist(), + "children_right": tree_.children_right.tolist(), + "value": tree_.value.tolist(), + "n_node_samples": tree_.n_node_samples.tolist(), + } + + +# Collect all results for saving +results: list[dict[str, Any]] = [] + +print("\nPer-layer metrics:") +for layer_idx, clf in models: + # Prepare X, 
Y for this layer + X = np.concatenate(layers[:layer_idx], axis=1) + Y = layers[layer_idx] + + # Predict + proba_list = clf.predict_proba(X.astype(np.uint8)) # type: ignore + P = extract_prob_class_1(proba_list, clf) + Y_pred = P >= 0.5 + + # Compute metrics per component + ap_scores: list[float] = [] + acc_scores: list[float] = [] + bacc_scores: list[float] = [] + + for j in range(Y.shape[1]): + y_true = Y[:, j].astype(int) + y_prob = P[:, j] + y_pred = Y_pred[:, j].astype(int) + + ap_scores.append(average_precision_score(y_true, y_prob)) + acc_scores.append(accuracy_score(y_true, y_pred)) + bacc_scores.append(balanced_accuracy_score(y_true, y_pred)) + + # Print summary + print(f" Layer {layer_idx} ({module_keys[layer_idx]}):") + print(f" Mean AP: {np.mean(ap_scores):.3f}") + print(f" Mean Acc: {np.mean(acc_scores):.3f}") + print(f" Mean BAcc: {np.mean(bacc_scores):.3f}") + + # Store results with tree structures + trees_data = [tree_to_dict(est) for est in clf.estimators_] # type: ignore + + results.append({ + "layer_idx": layer_idx, + "module_key": module_keys[layer_idx], + "trees": trees_data, + "ap_scores": ap_scores, + "acc_scores": acc_scores, + "bacc_scores": bacc_scores, + "mean_ap": float(np.mean(ap_scores)), + "mean_acc": float(np.mean(acc_scores)), + "mean_bacc": float(np.mean(bacc_scores)), + }) + +# %% ----------------------- Save Trees ----------------------- +output_path = "trees.json" +with open(output_path, "w") as f: + json.dump(results, f, indent=2) +print(f"\nSaved {len(results)} layers with trees and metrics to {output_path}") +print("Done!") diff --git a/spd/clustering/ci_dt/viewer.html b/spd/clustering/ci_dt/viewer.html new file mode 100644 index 000000000..e8e7686fd --- /dev/null +++ b/spd/clustering/ci_dt/viewer.html @@ -0,0 +1,431 @@ + + + + + + Decision Tree Viewer + + + +

Decision Tree Viewer

+
+ + + + From 59b299a0588fddaa68dda295e3c77a13592b8a4e Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Mon, 27 Oct 2025 16:37:54 +0000 Subject: [PATCH 74/77] make format --- spd/clustering/ci_dt/matshow_sort.py | 7 +++--- spd/clustering/ci_dt/minimal_run.py | 37 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/spd/clustering/ci_dt/matshow_sort.py b/spd/clustering/ci_dt/matshow_sort.py index ba27ca9a4..4998c7199 100644 --- a/spd/clustering/ci_dt/matshow_sort.py +++ b/spd/clustering/ci_dt/matshow_sort.py @@ -1,4 +1,5 @@ from typing import Literal + import numpy as np from jaxtyping import Float @@ -12,7 +13,7 @@ def sort_by_similarity( ) -> Float[np.ndarray, "m n"]: """Sort a 2D array by similarity of rows or columns using a greedy heuristic, with a secondary tie-breaker by absolute sum magnitude. - + Starts from the first row (or column) and iteratively adds the most similar remaining vector until all are ordered. Uses cosine or dot-product similarity. 
When multiple candidates have equal similarity to the last selected vector, @@ -60,7 +61,7 @@ def sort_by_similarity( # similarity matrix (dot product or cosine) sim: Float[np.ndarray, "n n"] = data @ data.T - abs_sums: Float[np.ndarray, "n"] = np.sum(np.abs(data), axis=1) + abs_sums: Float[np.ndarray, n] = np.sum(np.abs(data), axis=1) # greedy ordering remaining: list[int] = list(range(n)) @@ -79,4 +80,4 @@ def sort_by_similarity( remaining.remove(next_idx) sorted_arr: Float[np.ndarray, "m n"] = arr[order, :] if axis == 0 else arr[:, order] - return sorted_arr \ No newline at end of file + return sorted_arr diff --git a/spd/clustering/ci_dt/minimal_run.py b/spd/clustering/ci_dt/minimal_run.py index c6c02990f..9411cf7b7 100644 --- a/spd/clustering/ci_dt/minimal_run.py +++ b/spd/clustering/ci_dt/minimal_run.py @@ -1,18 +1,17 @@ -#%% +# %% """Minimal single-script version of causal importance decision tree training.""" import json from typing import Any -from matplotlib import pyplot as plt import numpy as np import torch from jaxtyping import Bool, Float +from matplotlib import pyplot as plt from sklearn.metrics import accuracy_score, average_precision_score, balanced_accuracy_score from sklearn.multioutput import MultiOutputClassifier from sklearn.tree import DecisionTreeClassifier from torch import Tensor -from torch.utils.data import DataLoader from tqdm import tqdm from spd.clustering.ci_dt.matshow_sort import sort_by_similarity @@ -90,9 +89,9 @@ # Flatten if 3D (batch, seq, components) -> (batch*seq, components) acts_tensor = acts_concat[k] if acts_tensor.ndim == 3: - acts_np: Float[np.ndarray, "n_samples n_components"] = ( - acts_tensor.reshape(-1, acts_tensor.shape[-1]).numpy() - ) + acts_np: Float[np.ndarray, "n_samples n_components"] = acts_tensor.reshape( + -1, acts_tensor.shape[-1] + ).numpy() else: acts_np = acts_tensor.numpy() @@ -104,7 +103,7 @@ # plt.title(f"{k}") # sort by column similarity acts_sorted = 
sort_by_similarity(sort_by_similarity(acts_bool.astype(float), axis=0), axis=1) - plt.matshow(acts_sorted[:,:600], aspect="auto") + plt.matshow(acts_sorted[:, :600], aspect="auto") plt.show() # Filter constant components (always 0 or always 1) @@ -197,17 +196,19 @@ def tree_to_dict(tree: DecisionTreeClassifier) -> dict[str, Any]: # Store results with tree structures trees_data = [tree_to_dict(est) for est in clf.estimators_] # type: ignore - results.append({ - "layer_idx": layer_idx, - "module_key": module_keys[layer_idx], - "trees": trees_data, - "ap_scores": ap_scores, - "acc_scores": acc_scores, - "bacc_scores": bacc_scores, - "mean_ap": float(np.mean(ap_scores)), - "mean_acc": float(np.mean(acc_scores)), - "mean_bacc": float(np.mean(bacc_scores)), - }) + results.append( + { + "layer_idx": layer_idx, + "module_key": module_keys[layer_idx], + "trees": trees_data, + "ap_scores": ap_scores, + "acc_scores": acc_scores, + "bacc_scores": bacc_scores, + "mean_ap": float(np.mean(ap_scores)), + "mean_acc": float(np.mean(acc_scores)), + "mean_bacc": float(np.mean(bacc_scores)), + } + ) # %% ----------------------- Save Trees ----------------------- output_path = "trees.json" From 06deb8c7909062fda6187eaa63f4e27c93402c34 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:01:27 +0000 Subject: [PATCH 75/77] uv lock --- uv.lock | 2511 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 2511 insertions(+) create mode 100644 uv.lock diff --git a/uv.lock b/uv.lock new file mode 100644 index 000000000..6aab291ef --- /dev/null +++ b/uv.lock @@ -0,0 +1,2511 @@ +version = 1 +revision = 3 +requires-python = "==3.13.*" +resolution-markers = [ + "sys_platform == 'linux'", + "sys_platform != 'linux'", +] + +[[package]] +name = "aiohappyeyeballs" +version = "2.6.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/26/30/f84a107a9c4331c14b2b586036f40965c128aa4fee4dda5d3d51cb14ad54/aiohappyeyeballs-2.6.1.tar.gz", hash = "sha256:c3f9d0113123803ccadfdf3f0faa505bc78e6a72d1cc4806cbd719826e943558", size = 22760, upload-time = "2025-03-12T01:42:48.764Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/15/5bf3b99495fb160b63f95972b81750f18f7f4e02ad051373b669d17d44f2/aiohappyeyeballs-2.6.1-py3-none-any.whl", hash = "sha256:f349ba8f4b75cb25c99c5c2d84e997e485204d2902a9597802b0371f09331fb8", size = 15265, upload-time = "2025-03-12T01:42:47.083Z" }, +] + +[[package]] +name = "aiohttp" +version = "3.13.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohappyeyeballs" }, + { name = "aiosignal" }, + { name = "attrs" }, + { name = "frozenlist" }, + { name = "multidict" }, + { name = "propcache" }, + { name = "yarl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1c/ce/3b83ebba6b3207a7135e5fcaba49706f8a4b6008153b4e30540c982fae26/aiohttp-3.13.2.tar.gz", hash = "sha256:40176a52c186aefef6eb3cad2cdd30cd06e3afbe88fe8ab2af9c0b90f228daca", size = 7837994, upload-time = "2025-10-28T20:59:39.937Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/78/7e90ca79e5aa39f9694dcfd74f4720782d3c6828113bb1f3197f7e7c4a56/aiohttp-3.13.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7519bdc7dfc1940d201651b52bf5e03f5503bda45ad6eacf64dda98be5b2b6be", size = 732139, upload-time = "2025-10-28T20:57:02.455Z" }, + { url = "https://files.pythonhosted.org/packages/db/ed/1f59215ab6853fbaa5c8495fa6cbc39edfc93553426152b75d82a5f32b76/aiohttp-3.13.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:088912a78b4d4f547a1f19c099d5a506df17eacec3c6f4375e2831ec1d995742", size = 490082, upload-time = "2025-10-28T20:57:04.784Z" }, + { url = "https://files.pythonhosted.org/packages/68/7b/fe0fe0f5e05e13629d893c760465173a15ad0039c0a5b0d0040995c8075e/aiohttp-3.13.2-cp313-cp313-macosx_11_0_arm64.whl", 
hash = "sha256:5276807b9de9092af38ed23ce120539ab0ac955547b38563a9ba4f5b07b95293", size = 489035, upload-time = "2025-10-28T20:57:06.894Z" }, + { url = "https://files.pythonhosted.org/packages/d2/04/db5279e38471b7ac801d7d36a57d1230feeee130bbe2a74f72731b23c2b1/aiohttp-3.13.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1237c1375eaef0db4dcd7c2559f42e8af7b87ea7d295b118c60c36a6e61cb811", size = 1720387, upload-time = "2025-10-28T20:57:08.685Z" }, + { url = "https://files.pythonhosted.org/packages/31/07/8ea4326bd7dae2bd59828f69d7fdc6e04523caa55e4a70f4a8725a7e4ed2/aiohttp-3.13.2-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:96581619c57419c3d7d78703d5b78c1e5e5fc0172d60f555bdebaced82ded19a", size = 1688314, upload-time = "2025-10-28T20:57:10.693Z" }, + { url = "https://files.pythonhosted.org/packages/48/ab/3d98007b5b87ffd519d065225438cc3b668b2f245572a8cb53da5dd2b1bc/aiohttp-3.13.2-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2713a95b47374169409d18103366de1050fe0ea73db358fc7a7acb2880422d4", size = 1756317, upload-time = "2025-10-28T20:57:12.563Z" }, + { url = "https://files.pythonhosted.org/packages/97/3d/801ca172b3d857fafb7b50c7c03f91b72b867a13abca982ed6b3081774ef/aiohttp-3.13.2-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:228a1cd556b3caca590e9511a89444925da87d35219a49ab5da0c36d2d943a6a", size = 1858539, upload-time = "2025-10-28T20:57:14.623Z" }, + { url = "https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac6cde5fba8d7d8c6ac963dbb0256a9854e9fafff52fbcc58fdf819357892c3e", size = 1739597, upload-time = "2025-10-28T20:57:16.399Z" }, + { url = 
"https://files.pythonhosted.org/packages/c4/52/7bd3c6693da58ba16e657eb904a5b6decfc48ecd06e9ac098591653b1566/aiohttp-3.13.2-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f2bef8237544f4e42878c61cef4e2839fee6346dc60f5739f876a9c50be7fcdb", size = 1555006, upload-time = "2025-10-28T20:57:18.288Z" }, + { url = "https://files.pythonhosted.org/packages/48/30/9586667acec5993b6f41d2ebcf96e97a1255a85f62f3c653110a5de4d346/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:16f15a4eac3bc2d76c45f7ebdd48a65d41b242eb6c31c2245463b40b34584ded", size = 1683220, upload-time = "2025-10-28T20:57:20.241Z" }, + { url = "https://files.pythonhosted.org/packages/71/01/3afe4c96854cfd7b30d78333852e8e851dceaec1c40fd00fec90c6402dd2/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:bb7fb776645af5cc58ab804c58d7eba545a97e047254a52ce89c157b5af6cd0b", size = 1712570, upload-time = "2025-10-28T20:57:22.253Z" }, + { url = "https://files.pythonhosted.org/packages/11/2c/22799d8e720f4697a9e66fd9c02479e40a49de3de2f0bbe7f9f78a987808/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:e1b4951125ec10c70802f2cb09736c895861cd39fd9dcb35107b4dc8ae6220b8", size = 1733407, upload-time = "2025-10-28T20:57:24.37Z" }, + { url = "https://files.pythonhosted.org/packages/34/cb/90f15dd029f07cebbd91f8238a8b363978b530cd128488085b5703683594/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:550bf765101ae721ee1d37d8095f47b1f220650f85fe1af37a90ce75bab89d04", size = 1550093, upload-time = "2025-10-28T20:57:26.257Z" }, + { url = "https://files.pythonhosted.org/packages/69/46/12dce9be9d3303ecbf4d30ad45a7683dc63d90733c2d9fe512be6716cd40/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:fe91b87fc295973096251e2d25a811388e7d8adf3bd2b97ef6ae78bc4ac6c476", size = 1758084, upload-time = "2025-10-28T20:57:28.349Z" }, + { url = 
"https://files.pythonhosted.org/packages/f9/c8/0932b558da0c302ffd639fc6362a313b98fdf235dc417bc2493da8394df7/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e0c8e31cfcc4592cb200160344b2fb6ae0f9e4effe06c644b5a125d4ae5ebe23", size = 1716987, upload-time = "2025-10-28T20:57:30.233Z" }, + { url = "https://files.pythonhosted.org/packages/5d/8b/f5bd1a75003daed099baec373aed678f2e9b34f2ad40d85baa1368556396/aiohttp-3.13.2-cp313-cp313-win32.whl", hash = "sha256:0740f31a60848d6edb296a0df827473eede90c689b8f9f2a4cdde74889eb2254", size = 425859, upload-time = "2025-10-28T20:57:32.105Z" }, + { url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" }, +] + +[[package]] +name = "aiosignal" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "frozenlist" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7", size = 25007, upload-time = "2025-07-03T22:54:43.528Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, +] + +[[package]] +name = "altair" +version = "5.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "jsonschema" }, + { name = "narwhals" }, + { name = "packaging" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/16/b1/f2969c7bdb8ad8bbdda031687defdce2c19afba2aa2c8e1d2a17f78376d8/altair-5.5.0.tar.gz", hash = "sha256:d960ebe6178c56de3855a68c47b516be38640b73fb3b5111c2a9ca90546dd73d", size = 705305, upload-time = "2024-11-23T23:39:58.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl", hash = "sha256:91a310b926508d560fe0148d02a194f38b824122641ef528113d029fcd129f8c", size = 731200, upload-time = "2024-11-23T23:39:56.4Z" }, +] + +[[package]] +name = "annotated-doc" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/a6/dc46877b911e40c00d395771ea710d5e77b6de7bacd5fdcd78d70cc5a48f/annotated_doc-0.0.3.tar.gz", hash = "sha256:e18370014c70187422c33e945053ff4c286f453a984eba84d0dbfa0c935adeda", size = 5535, upload-time = "2025-10-24T14:57:10.718Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/02/b7/cf592cb5de5cb3bade3357f8d2cf42bf103bbe39f459824b4939fd212911/annotated_doc-0.0.3-py3-none-any.whl", hash = "sha256:348ec6664a76f1fd3be81f43dffbee4c7e8ce931ba71ec67cc7f4ade7fbbb580", size = 5488, upload-time = "2025-10-24T14:57:09.462Z" }, +] + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" 
}, +] + +[[package]] +name = "anyio" +version = "4.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "sniffio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, +] + +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170, upload-time = "2024-02-06T09:43:11.258Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, +] + +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, +] + +[[package]] +name = "attrs" +version = "25.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6b/5c/685e6633917e101e5dcb62b9dd76946cbb57c26e133bae9e0cd36033c0a9/attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11", size = 934251, upload-time = "2025-10-06T13:54:44.725Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, +] + +[[package]] +name = "basedpyright" +version = "1.31.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodejs-wheel-binaries" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c6/ba/ed69e8df732a09c8ca469f592c8e08707fe29149735b834c276d94d4a3da/basedpyright-1.31.7.tar.gz", hash = "sha256:394f334c742a19bcc5905b2455c9f5858182866b7679a6f057a70b44b049bceb", size = 22710948, upload-time = "2025-10-11T05:12:48.3Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, +] + +[[package]] +name = "blinker" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/21/28/9b3f50ce0e048515135495f198351908d99540d69bfdc8c1d15b73dc55ce/blinker-1.9.0.tar.gz", hash = "sha256:b4ce2265a7abece45e7cc896e98dbebe6cead56bcf805a3d23136d145f5445bf", size = 22460, upload-time = "2024-11-08T17:25:47.436Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, +] + +[[package]] +name = "cachetools" +version = "6.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cc/7e/b975b5814bd36faf009faebe22c1072a1fa1168db34d285ef0ba071ad78c/cachetools-6.2.1.tar.gz", hash = "sha256:3f391e4bd8f8bf0931169baf7456cc822705f4e2a31f840d218f445b9a854201", size = 31325, upload-time = "2025-10-12T14:55:30.139Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/96/c5/1e741d26306c42e2bf6ab740b2202872727e0f606033c9dd713f8b93f5a8/cachetools-6.2.1-py3-none-any.whl", hash = "sha256:09868944b6dde876dfd44e1d47e18484541eaf12f26f29b7af91b26cc892d701", size = 11280, upload-time = "2025-10-12T14:55:28.382Z" }, +] + +[[package]] +name = "certifi" +version = "2025.10.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/5b/b6ce21586237c77ce67d01dc5507039d444b630dd76611bbca2d8e5dcd91/certifi-2025.10.5.tar.gz", hash = "sha256:47c09d31ccf2acf0be3f701ea53595ee7e0b8fa08801c6624be771df09ae7b43", size = 164519, upload-time = "2025-10-05T04:12:15.808Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/37/af0d2ef3967ac0d6113837b44a4f0bfe1328c2b9763bd5b1744520e5cfed/certifi-2025.10.5-py3-none-any.whl", hash = "sha256:0f212c2744a9bb6de0c56639a6f68afe01ecd92d91f14ae897c4fe7bbeeef0de", size = 163286, upload-time = "2025-10-05T04:12:14.03Z" }, +] + 
+[[package]] +name = "cffi" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pycparser", marker = "implementation_name != 'PyPy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/8d/a0a47a0c9e413a658623d014e91e74a50cdd2c423f7ccfd44086ef767f90/cffi-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb", size = 185230, upload-time = "2025-09-08T23:23:00.879Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d2/a6c0296814556c68ee32009d9c2ad4f85f2707cdecfd7727951ec228005d/cffi-2.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca", size = 181043, upload-time = "2025-09-08T23:23:02.231Z" }, + { url = "https://files.pythonhosted.org/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b", size = 212446, upload-time = "2025-09-08T23:23:03.472Z" }, + { url = "https://files.pythonhosted.org/packages/a9/f5/a2c23eb03b61a0b8747f211eb716446c826ad66818ddc7810cc2cc19b3f2/cffi-2.0.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d48a880098c96020b02d5a1f7d9251308510ce8858940e6fa99ece33f610838b", size = 220101, upload-time = "2025-09-08T23:23:04.792Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7f/e6647792fc5850d634695bc0e6ab4111ae88e89981d35ac269956605feba/cffi-2.0.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = 
"sha256:f93fd8e5c8c0a4aa1f424d6173f14a892044054871c771f8566e4008eaa359d2", size = 207948, upload-time = "2025-09-08T23:23:06.127Z" }, + { url = "https://files.pythonhosted.org/packages/cb/1e/a5a1bd6f1fb30f22573f76533de12a00bf274abcdc55c8edab639078abb6/cffi-2.0.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:dd4f05f54a52fb558f1ba9f528228066954fee3ebe629fc1660d874d040ae5a3", size = 206422, upload-time = "2025-09-08T23:23:07.753Z" }, + { url = "https://files.pythonhosted.org/packages/98/df/0a1755e750013a2081e863e7cd37e0cdd02664372c754e5560099eb7aa44/cffi-2.0.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c8d3b5532fc71b7a77c09192b4a5a200ea992702734a2e9279a37f2478236f26", size = 219499, upload-time = "2025-09-08T23:23:09.648Z" }, + { url = "https://files.pythonhosted.org/packages/50/e1/a969e687fcf9ea58e6e2a928ad5e2dd88cc12f6f0ab477e9971f2309b57c/cffi-2.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d9b29c1f0ae438d5ee9acb31cadee00a58c46cc9c0b2f9038c6b0b3470877a8c", size = 222928, upload-time = "2025-09-08T23:23:10.928Z" }, + { url = "https://files.pythonhosted.org/packages/36/54/0362578dd2c9e557a28ac77698ed67323ed5b9775ca9d3fe73fe191bb5d8/cffi-2.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6d50360be4546678fc1b79ffe7a66265e28667840010348dd69a314145807a1b", size = 221302, upload-time = "2025-09-08T23:23:12.42Z" }, + { url = "https://files.pythonhosted.org/packages/eb/6d/bf9bda840d5f1dfdbf0feca87fbdb64a918a69bca42cfa0ba7b137c48cb8/cffi-2.0.0-cp313-cp313-win32.whl", hash = "sha256:74a03b9698e198d47562765773b4a8309919089150a0bb17d829ad7b44b60d27", size = 172909, upload-time = "2025-09-08T23:23:14.32Z" }, + { url = "https://files.pythonhosted.org/packages/37/18/6519e1ee6f5a1e579e04b9ddb6f1676c17368a7aba48299c3759bbc3c8b3/cffi-2.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:19f705ada2530c1167abacb171925dd886168931e0a7b78f5bffcae5c6b5be75", size = 183402, upload-time = "2025-09-08T23:23:15.535Z" 
}, + { url = "https://files.pythonhosted.org/packages/cb/0e/02ceeec9a7d6ee63bb596121c2c8e9b3a9e150936f4fbef6ca1943e6137c/cffi-2.0.0-cp313-cp313-win_arm64.whl", hash = "sha256:256f80b80ca3853f90c21b23ee78cd008713787b1b1e93eae9f3d6a7134abd91", size = 177780, upload-time = "2025-09-08T23:23:16.761Z" }, +] + +[[package]] +name = "cfgv" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/11/74/539e56497d9bd1d484fd863dd69cbbfa653cd2aa27abfe35653494d85e94/cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560", size = 7114, upload-time = "2023-08-12T20:38:17.776Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/55/51844dd50c4fc7a33b653bfaba4c2456f06955289ca770a5dbd5fd267374/cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9", size = 7249, upload-time = "2023-08-12T20:38:16.269Z" }, +] + +[[package]] +name = "charset-normalizer" +version = "3.4.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/97/45/4b3a1239bbacd321068ea6e7ac28875b03ab8bc0aa0966452db17cd36714/charset_normalizer-3.4.4-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e1f185f86a6f3403aa2420e815904c67b2f9ebc443f045edd0de921108345794", size = 208091, upload-time = "2025-10-14T04:41:13.346Z" }, + { url = "https://files.pythonhosted.org/packages/7d/62/73a6d7450829655a35bb88a88fca7d736f9882a27eacdca2c6d505b57e2e/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:6b39f987ae8ccdf0d2642338faf2abb1862340facc796048b604ef14919e55ed", size = 147936, upload-time = "2025-10-14T04:41:14.461Z" }, + { url = "https://files.pythonhosted.org/packages/89/c5/adb8c8b3d6625bef6d88b251bbb0d95f8205831b987631ab0c8bb5d937c2/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3162d5d8ce1bb98dd51af660f2121c55d0fa541b46dff7bb9b9f86ea1d87de72", size = 144180, upload-time = "2025-10-14T04:41:15.588Z" }, + { url = "https://files.pythonhosted.org/packages/91/ed/9706e4070682d1cc219050b6048bfd293ccf67b3d4f5a4f39207453d4b99/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:81d5eb2a312700f4ecaa977a8235b634ce853200e828fbadf3a9c50bab278328", size = 161346, upload-time = "2025-10-14T04:41:16.738Z" }, + { url = "https://files.pythonhosted.org/packages/d5/0d/031f0d95e4972901a2f6f09ef055751805ff541511dc1252ba3ca1f80cf5/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5bd2293095d766545ec1a8f612559f6b40abc0eb18bb2f5d1171872d34036ede", size = 158874, upload-time = "2025-10-14T04:41:17.923Z" }, + { url = "https://files.pythonhosted.org/packages/f5/83/6ab5883f57c9c801ce5e5677242328aa45592be8a00644310a008d04f922/charset_normalizer-3.4.4-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a8a8b89589086a25749f471e6a900d3f662d1d3b6e2e59dcecf787b1cc3a1894", size = 153076, upload-time = "2025-10-14T04:41:19.106Z" }, + { url = "https://files.pythonhosted.org/packages/75/1e/5ff781ddf5260e387d6419959ee89ef13878229732732ee73cdae01800f2/charset_normalizer-3.4.4-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:bc7637e2f80d8530ee4a78e878bce464f70087ce73cf7c1caf142416923b98f1", size = 150601, upload-time = "2025-10-14T04:41:20.245Z" }, + { url = 
"https://files.pythonhosted.org/packages/d7/57/71be810965493d3510a6ca79b90c19e48696fb1ff964da319334b12677f0/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f8bf04158c6b607d747e93949aa60618b61312fe647a6369f88ce2ff16043490", size = 150376, upload-time = "2025-10-14T04:41:21.398Z" }, + { url = "https://files.pythonhosted.org/packages/e5/d5/c3d057a78c181d007014feb7e9f2e65905a6c4ef182c0ddf0de2924edd65/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:554af85e960429cf30784dd47447d5125aaa3b99a6f0683589dbd27e2f45da44", size = 144825, upload-time = "2025-10-14T04:41:22.583Z" }, + { url = "https://files.pythonhosted.org/packages/e6/8c/d0406294828d4976f275ffbe66f00266c4b3136b7506941d87c00cab5272/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:74018750915ee7ad843a774364e13a3db91682f26142baddf775342c3f5b1133", size = 162583, upload-time = "2025-10-14T04:41:23.754Z" }, + { url = "https://files.pythonhosted.org/packages/d7/24/e2aa1f18c8f15c4c0e932d9287b8609dd30ad56dbe41d926bd846e22fb8d/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:c0463276121fdee9c49b98908b3a89c39be45d86d1dbaa22957e38f6321d4ce3", size = 150366, upload-time = "2025-10-14T04:41:25.27Z" }, + { url = "https://files.pythonhosted.org/packages/e4/5b/1e6160c7739aad1e2df054300cc618b06bf784a7a164b0f238360721ab86/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:362d61fd13843997c1c446760ef36f240cf81d3ebf74ac62652aebaf7838561e", size = 160300, upload-time = "2025-10-14T04:41:26.725Z" }, + { url = "https://files.pythonhosted.org/packages/7a/10/f882167cd207fbdd743e55534d5d9620e095089d176d55cb22d5322f2afd/charset_normalizer-3.4.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a26f18905b8dd5d685d6d07b0cdf98a79f3c7a918906af7cc143ea2e164c8bc", size = 154465, upload-time = "2025-10-14T04:41:28.322Z" }, + { url = 
"https://files.pythonhosted.org/packages/89/66/c7a9e1b7429be72123441bfdbaf2bc13faab3f90b933f664db506dea5915/charset_normalizer-3.4.4-cp313-cp313-win32.whl", hash = "sha256:9b35f4c90079ff2e2edc5b26c0c77925e5d2d255c42c74fdb70fb49b172726ac", size = 99404, upload-time = "2025-10-14T04:41:29.95Z" }, + { url = "https://files.pythonhosted.org/packages/c4/26/b9924fa27db384bdcd97ab83b4f0a8058d96ad9626ead570674d5e737d90/charset_normalizer-3.4.4-cp313-cp313-win_amd64.whl", hash = "sha256:b435cba5f4f750aa6c0a0d92c541fb79f69a387c91e61f1795227e4ed9cece14", size = 107092, upload-time = "2025-10-14T04:41:31.188Z" }, + { url = "https://files.pythonhosted.org/packages/af/8f/3ed4bfa0c0c72a7ca17f0380cd9e4dd842b09f664e780c13cff1dcf2ef1b/charset_normalizer-3.4.4-cp313-cp313-win_arm64.whl", hash = "sha256:542d2cee80be6f80247095cc36c418f7bddd14f4a6de45af91dfad36d817bba2", size = 100408, upload-time = "2025-10-14T04:41:32.624Z" }, + { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, +] + +[[package]] +name = "click" +version = "8.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/46/61/de6cd827efad202d7057d93e0fed9294b96952e188f7384832791c7b2254/click-8.3.0.tar.gz", hash = "sha256:e7b8232224eba16f4ebe410c25ced9f7875cb5f3263ffc93cc3e8da705e229c4", size = 276943, upload-time = "2025-09-18T17:32:23.696Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/d3/9dcc0f5797f070ec8edf30fbadfb200e71d9db6b84d211e3b2085a7589a0/click-8.3.0-py3-none-any.whl", hash = "sha256:9b9f285302c6e3064f4330c05f05b81945b2a39544279343e6e7c5f27a9baddc", size = 107295, upload-time = "2025-09-18T17:32:22.42Z" }, +] + 
+[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "comm" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/13/7d740c5849255756bc17888787313b61fd38a0a8304fc4f073dfc46122aa/comm-0.2.3.tar.gz", hash = "sha256:2dc8048c10962d55d7ad693be1e7045d891b7ce8d999c97963a5e3e99c055971", size = 6319, upload-time = "2025-07-25T14:02:04.452Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, +] + +[[package]] +name = "contourpy" +version = "1.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880", size = 13466174, upload-time = "2025-07-26T12:03:12.549Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/68/35/0167aad910bbdb9599272bd96d01a9ec6852f36b9455cf2ca67bd4cc2d23/contourpy-1.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5", size = 293257, upload-time = "2025-07-26T12:01:39.367Z" }, + { url = "https://files.pythonhosted.org/packages/96/e4/7adcd9c8362745b2210728f209bfbcf7d91ba868a2c5f40d8b58f54c509b/contourpy-1.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1", size = 274034, upload-time = "2025-07-26T12:01:40.645Z" }, + { url = "https://files.pythonhosted.org/packages/73/23/90e31ceeed1de63058a02cb04b12f2de4b40e3bef5e082a7c18d9c8ae281/contourpy-1.3.3-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286", size = 334672, upload-time = "2025-07-26T12:01:41.942Z" }, + { url = "https://files.pythonhosted.org/packages/ed/93/b43d8acbe67392e659e1d984700e79eb67e2acb2bd7f62012b583a7f1b55/contourpy-1.3.3-cp313-cp313-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:655456777ff65c2c548b7c454af9c6f33f16c8884f11083244b5819cc214f1b5", size = 381234, upload-time = "2025-07-26T12:01:43.499Z" }, + { url = "https://files.pythonhosted.org/packages/46/3b/bec82a3ea06f66711520f75a40c8fc0b113b2a75edb36aa633eb11c4f50f/contourpy-1.3.3-cp313-cp313-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:644a6853d15b2512d67881586bd03f462c7ab755db95f16f14d7e238f2852c67", size = 385169, upload-time = "2025-07-26T12:01:45.219Z" }, + { url = "https://files.pythonhosted.org/packages/4b/32/e0f13a1c5b0f8572d0ec6ae2f6c677b7991fafd95da523159c19eff0696a/contourpy-1.3.3-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4debd64f124ca62069f313a9cb86656ff087786016d76927ae2cf37846b006c9", size = 362859, upload-time = "2025-07-26T12:01:46.519Z" }, + { url = 
"https://files.pythonhosted.org/packages/33/71/e2a7945b7de4e58af42d708a219f3b2f4cff7386e6b6ab0a0fa0033c49a9/contourpy-1.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a15459b0f4615b00bbd1e91f1b9e19b7e63aea7483d03d804186f278c0af2659", size = 1332062, upload-time = "2025-07-26T12:01:48.964Z" }, + { url = "https://files.pythonhosted.org/packages/12/fc/4e87ac754220ccc0e807284f88e943d6d43b43843614f0a8afa469801db0/contourpy-1.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ca0fdcd73925568ca027e0b17ab07aad764be4706d0a925b89227e447d9737b7", size = 1403932, upload-time = "2025-07-26T12:01:51.979Z" }, + { url = "https://files.pythonhosted.org/packages/a6/2e/adc197a37443f934594112222ac1aa7dc9a98faf9c3842884df9a9d8751d/contourpy-1.3.3-cp313-cp313-win32.whl", hash = "sha256:b20c7c9a3bf701366556e1b1984ed2d0cedf999903c51311417cf5f591d8c78d", size = 185024, upload-time = "2025-07-26T12:01:53.245Z" }, + { url = "https://files.pythonhosted.org/packages/18/0b/0098c214843213759692cc638fce7de5c289200a830e5035d1791d7a2338/contourpy-1.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:1cadd8b8969f060ba45ed7c1b714fe69185812ab43bd6b86a9123fe8f99c3263", size = 226578, upload-time = "2025-07-26T12:01:54.422Z" }, + { url = "https://files.pythonhosted.org/packages/8a/9a/2f6024a0c5995243cd63afdeb3651c984f0d2bc727fd98066d40e141ad73/contourpy-1.3.3-cp313-cp313-win_arm64.whl", hash = "sha256:fd914713266421b7536de2bfa8181aa8c699432b6763a0ea64195ebe28bff6a9", size = 193524, upload-time = "2025-07-26T12:01:55.73Z" }, + { url = "https://files.pythonhosted.org/packages/c0/b3/f8a1a86bd3298513f500e5b1f5fd92b69896449f6cab6a146a5d52715479/contourpy-1.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:88df9880d507169449d434c293467418b9f6cbe82edd19284aa0409e7fdb933d", size = 306730, upload-time = "2025-07-26T12:01:57.051Z" }, + { url = 
"https://files.pythonhosted.org/packages/3f/11/4780db94ae62fc0c2053909b65dc3246bd7cecfc4f8a20d957ad43aa4ad8/contourpy-1.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:d06bb1f751ba5d417047db62bca3c8fde202b8c11fb50742ab3ab962c81e8216", size = 287897, upload-time = "2025-07-26T12:01:58.663Z" }, + { url = "https://files.pythonhosted.org/packages/ae/15/e59f5f3ffdd6f3d4daa3e47114c53daabcb18574a26c21f03dc9e4e42ff0/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e4e6b05a45525357e382909a4c1600444e2a45b4795163d3b22669285591c1ae", size = 326751, upload-time = "2025-07-26T12:02:00.343Z" }, + { url = "https://files.pythonhosted.org/packages/0f/81/03b45cfad088e4770b1dcf72ea78d3802d04200009fb364d18a493857210/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ab3074b48c4e2cf1a960e6bbeb7f04566bf36b1861d5c9d4d8ac04b82e38ba20", size = 375486, upload-time = "2025-07-26T12:02:02.128Z" }, + { url = "https://files.pythonhosted.org/packages/0c/ba/49923366492ffbdd4486e970d421b289a670ae8cf539c1ea9a09822b371a/contourpy-1.3.3-cp313-cp313t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6c3d53c796f8647d6deb1abe867daeb66dcc8a97e8455efa729516b997b8ed99", size = 388106, upload-time = "2025-07-26T12:02:03.615Z" }, + { url = "https://files.pythonhosted.org/packages/9f/52/5b00ea89525f8f143651f9f03a0df371d3cbd2fccd21ca9b768c7a6500c2/contourpy-1.3.3-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:50ed930df7289ff2a8d7afeb9603f8289e5704755c7e5c3bbd929c90c817164b", size = 352548, upload-time = "2025-07-26T12:02:05.165Z" }, + { url = "https://files.pythonhosted.org/packages/32/1d/a209ec1a3a3452d490f6b14dd92e72280c99ae3d1e73da74f8277d4ee08f/contourpy-1.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4feffb6537d64b84877da813a5c30f1422ea5739566abf0bd18065ac040e120a", size = 1322297, upload-time = "2025-07-26T12:02:07.379Z" }, + { url = 
"https://files.pythonhosted.org/packages/bc/9e/46f0e8ebdd884ca0e8877e46a3f4e633f6c9c8c4f3f6e72be3fe075994aa/contourpy-1.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2b7e9480ffe2b0cd2e787e4df64270e3a0440d9db8dc823312e2c940c167df7e", size = 1391023, upload-time = "2025-07-26T12:02:10.171Z" }, + { url = "https://files.pythonhosted.org/packages/b9/70/f308384a3ae9cd2209e0849f33c913f658d3326900d0ff5d378d6a1422d2/contourpy-1.3.3-cp313-cp313t-win32.whl", hash = "sha256:283edd842a01e3dcd435b1c5116798d661378d83d36d337b8dde1d16a5fc9ba3", size = 196157, upload-time = "2025-07-26T12:02:11.488Z" }, + { url = "https://files.pythonhosted.org/packages/b2/dd/880f890a6663b84d9e34a6f88cded89d78f0091e0045a284427cb6b18521/contourpy-1.3.3-cp313-cp313t-win_amd64.whl", hash = "sha256:87acf5963fc2b34825e5b6b048f40e3635dd547f590b04d2ab317c2619ef7ae8", size = 240570, upload-time = "2025-07-26T12:02:12.754Z" }, + { url = "https://files.pythonhosted.org/packages/80/99/2adc7d8ffead633234817ef8e9a87115c8a11927a94478f6bb3d3f4d4f7d/contourpy-1.3.3-cp313-cp313t-win_arm64.whl", hash = "sha256:3c30273eb2a55024ff31ba7d052dde990d7d8e5450f4bbb6e913558b3d6c2301", size = 199713, upload-time = "2025-07-26T12:02:14.4Z" }, +] + +[[package]] +name = "coverage" +version = "7.11.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/38/ee22495420457259d2f3390309505ea98f98a5eed40901cf62196abad006/coverage-7.11.0.tar.gz", hash = "sha256:167bd504ac1ca2af7ff3b81d245dfea0292c5032ebef9d66cc08a7d28c1b8050", size = 811905, upload-time = "2025-10-15T15:15:08.542Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/7f/85e4dfe65e400645464b25c036a26ac226cf3a69d4a50c3934c532491cdd/coverage-7.11.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:cc3f49e65ea6e0d5d9bd60368684fe52a704d46f9e7fc413918f18d046ec40e1", size = 216129, upload-time = "2025-10-15T15:13:25.371Z" }, + { url = 
"https://files.pythonhosted.org/packages/96/5d/dc5fa98fea3c175caf9d360649cb1aa3715e391ab00dc78c4c66fabd7356/coverage-7.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f39ae2f63f37472c17b4990f794035c9890418b1b8cca75c01193f3c8d3e01be", size = 216380, upload-time = "2025-10-15T15:13:26.976Z" }, + { url = "https://files.pythonhosted.org/packages/b2/f5/3da9cc9596708273385189289c0e4d8197d37a386bdf17619013554b3447/coverage-7.11.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7db53b5cdd2917b6eaadd0b1251cf4e7d96f4a8d24e174bdbdf2f65b5ea7994d", size = 247375, upload-time = "2025-10-15T15:13:28.923Z" }, + { url = "https://files.pythonhosted.org/packages/65/6c/f7f59c342359a235559d2bc76b0c73cfc4bac7d61bb0df210965cb1ecffd/coverage-7.11.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10ad04ac3a122048688387828b4537bc9cf60c0bf4869c1e9989c46e45690b82", size = 249978, upload-time = "2025-10-15T15:13:30.525Z" }, + { url = "https://files.pythonhosted.org/packages/e7/8c/042dede2e23525e863bf1ccd2b92689692a148d8b5fd37c37899ba882645/coverage-7.11.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4036cc9c7983a2b1f2556d574d2eb2154ac6ed55114761685657e38782b23f52", size = 251253, upload-time = "2025-10-15T15:13:32.174Z" }, + { url = "https://files.pythonhosted.org/packages/7b/a9/3c58df67bfa809a7bddd786356d9c5283e45d693edb5f3f55d0986dd905a/coverage-7.11.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:7ab934dd13b1c5e94b692b1e01bd87e4488cb746e3a50f798cb9464fd128374b", size = 247591, upload-time = "2025-10-15T15:13:34.147Z" }, + { url = "https://files.pythonhosted.org/packages/26/5b/c7f32efd862ee0477a18c41e4761305de6ddd2d49cdeda0c1116227570fd/coverage-7.11.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59a6e5a265f7cfc05f76e3bb53eca2e0dfe90f05e07e849930fecd6abb8f40b4", size = 249411, upload-time = 
"2025-10-15T15:13:38.425Z" }, + { url = "https://files.pythonhosted.org/packages/76/b5/78cb4f1e86c1611431c990423ec0768122905b03837e1b4c6a6f388a858b/coverage-7.11.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:df01d6c4c81e15a7c88337b795bb7595a8596e92310266b5072c7e301168efbd", size = 247303, upload-time = "2025-10-15T15:13:40.464Z" }, + { url = "https://files.pythonhosted.org/packages/87/c9/23c753a8641a330f45f221286e707c427e46d0ffd1719b080cedc984ec40/coverage-7.11.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:8c934bd088eed6174210942761e38ee81d28c46de0132ebb1801dbe36a390dcc", size = 247157, upload-time = "2025-10-15T15:13:42.087Z" }, + { url = "https://files.pythonhosted.org/packages/c5/42/6e0cc71dc8a464486e944a4fa0d85bdec031cc2969e98ed41532a98336b9/coverage-7.11.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5a03eaf7ec24078ad64a07f02e30060aaf22b91dedf31a6b24d0d98d2bba7f48", size = 248921, upload-time = "2025-10-15T15:13:43.715Z" }, + { url = "https://files.pythonhosted.org/packages/e8/1c/743c2ef665e6858cccb0f84377dfe3a4c25add51e8c7ef19249be92465b6/coverage-7.11.0-cp313-cp313-win32.whl", hash = "sha256:695340f698a5f56f795b2836abe6fb576e7c53d48cd155ad2f80fd24bc63a040", size = 218526, upload-time = "2025-10-15T15:13:45.336Z" }, + { url = "https://files.pythonhosted.org/packages/ff/d5/226daadfd1bf8ddbccefbd3aa3547d7b960fb48e1bdac124e2dd13a2b71a/coverage-7.11.0-cp313-cp313-win_amd64.whl", hash = "sha256:2727d47fce3ee2bac648528e41455d1b0c46395a087a229deac75e9f88ba5a05", size = 219317, upload-time = "2025-10-15T15:13:47.401Z" }, + { url = "https://files.pythonhosted.org/packages/97/54/47db81dcbe571a48a298f206183ba8a7ba79200a37cd0d9f4788fcd2af4a/coverage-7.11.0-cp313-cp313-win_arm64.whl", hash = "sha256:0efa742f431529699712b92ecdf22de8ff198df41e43aeaaadf69973eb93f17a", size = 217948, upload-time = "2025-10-15T15:13:49.096Z" }, + { url = 
"https://files.pythonhosted.org/packages/e5/8b/cb68425420154e7e2a82fd779a8cc01549b6fa83c2ad3679cd6c088ebd07/coverage-7.11.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:587c38849b853b157706407e9ebdca8fd12f45869edb56defbef2daa5fb0812b", size = 216837, upload-time = "2025-10-15T15:13:51.09Z" }, + { url = "https://files.pythonhosted.org/packages/33/55/9d61b5765a025685e14659c8d07037247de6383c0385757544ffe4606475/coverage-7.11.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b971bdefdd75096163dd4261c74be813c4508477e39ff7b92191dea19f24cd37", size = 217061, upload-time = "2025-10-15T15:13:52.747Z" }, + { url = "https://files.pythonhosted.org/packages/52/85/292459c9186d70dcec6538f06ea251bc968046922497377bf4a1dc9a71de/coverage-7.11.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:269bfe913b7d5be12ab13a95f3a76da23cf147be7fa043933320ba5625f0a8de", size = 258398, upload-time = "2025-10-15T15:13:54.45Z" }, + { url = "https://files.pythonhosted.org/packages/1f/e2/46edd73fb8bf51446c41148d81944c54ed224854812b6ca549be25113ee0/coverage-7.11.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:dadbcce51a10c07b7c72b0ce4a25e4b6dcb0c0372846afb8e5b6307a121eb99f", size = 260574, upload-time = "2025-10-15T15:13:56.145Z" }, + { url = "https://files.pythonhosted.org/packages/07/5e/1df469a19007ff82e2ca8fe509822820a31e251f80ee7344c34f6cd2ec43/coverage-7.11.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ed43fa22c6436f7957df036331f8fe4efa7af132054e1844918866cd228af6c", size = 262797, upload-time = "2025-10-15T15:13:58.635Z" }, + { url = "https://files.pythonhosted.org/packages/f9/50/de216b31a1434b94d9b34a964c09943c6be45069ec704bfc379d8d89a649/coverage-7.11.0-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9516add7256b6713ec08359b7b05aeff8850c98d357784c7205b2e60aa2513fa", size = 257361, upload-time = 
"2025-10-15T15:14:00.409Z" }, + { url = "https://files.pythonhosted.org/packages/82/1e/3f9f8344a48111e152e0fd495b6fff13cc743e771a6050abf1627a7ba918/coverage-7.11.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb92e47c92fcbcdc692f428da67db33337fa213756f7adb6a011f7b5a7a20740", size = 260349, upload-time = "2025-10-15T15:14:02.188Z" }, + { url = "https://files.pythonhosted.org/packages/65/9b/3f52741f9e7d82124272f3070bbe316006a7de1bad1093f88d59bfc6c548/coverage-7.11.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:d06f4fc7acf3cabd6d74941d53329e06bab00a8fe10e4df2714f0b134bfc64ef", size = 258114, upload-time = "2025-10-15T15:14:03.907Z" }, + { url = "https://files.pythonhosted.org/packages/0b/8b/918f0e15f0365d50d3986bbd3338ca01178717ac5678301f3f547b6619e6/coverage-7.11.0-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:6fbcee1a8f056af07ecd344482f711f563a9eb1c2cad192e87df00338ec3cdb0", size = 256723, upload-time = "2025-10-15T15:14:06.324Z" }, + { url = "https://files.pythonhosted.org/packages/44/9e/7776829f82d3cf630878a7965a7d70cc6ca94f22c7d20ec4944f7148cb46/coverage-7.11.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dbbf012be5f32533a490709ad597ad8a8ff80c582a95adc8d62af664e532f9ca", size = 259238, upload-time = "2025-10-15T15:14:08.002Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b8/49cf253e1e7a3bedb85199b201862dd7ca4859f75b6cf25ffa7298aa0760/coverage-7.11.0-cp313-cp313t-win32.whl", hash = "sha256:cee6291bb4fed184f1c2b663606a115c743df98a537c969c3c64b49989da96c2", size = 219180, upload-time = "2025-10-15T15:14:09.786Z" }, + { url = "https://files.pythonhosted.org/packages/ac/e1/1a541703826be7ae2125a0fb7f821af5729d56bb71e946e7b933cc7a89a4/coverage-7.11.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a386c1061bf98e7ea4758e4313c0ab5ecf57af341ef0f43a0bf26c2477b5c268", size = 220241, upload-time = "2025-10-15T15:14:11.471Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/d1/5ee0e0a08621140fd418ec4020f595b4d52d7eb429ae6a0c6542b4ba6f14/coverage-7.11.0-cp313-cp313t-win_arm64.whl", hash = "sha256:f9ea02ef40bb83823b2b04964459d281688fe173e20643870bb5d2edf68bc836", size = 218510, upload-time = "2025-10-15T15:14:13.46Z" }, + { url = "https://files.pythonhosted.org/packages/5f/04/642c1d8a448ae5ea1369eac8495740a79eb4e581a9fb0cbdce56bbf56da1/coverage-7.11.0-py3-none-any.whl", hash = "sha256:4b7589765348d78fb4e5fb6ea35d07564e387da2fc5efff62e0222971f155f68", size = 207761, upload-time = "2025-10-15T15:15:06.439Z" }, +] + +[[package]] +name = "cycler" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/95/a3dbbb5028f35eafb79008e7522a75244477d2838f38cbb722248dabc2a8/cycler-0.12.1.tar.gz", hash = "sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c", size = 7615, upload-time = "2023-10-07T05:32:18.335Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30", size = 8321, upload-time = "2023-10-07T05:32:16.783Z" }, +] + +[[package]] +name = "datasets" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, + { name = "filelock" }, + { name = "fsspec", extra = ["http"] }, + { name = "httpx" }, + { name = "huggingface-hub" }, + { name = "multiprocess" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pyarrow" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/47/325206ac160f7699ed9f1798afa8f8f8d5189b03bf3815654859ac1d5cba/datasets-4.3.0.tar.gz", hash = "sha256:bc9118ed9afd92346c5be7ed3aaa00177eb907c25467f9d072a0d22777efbd2b", 
size = 582801, upload-time = "2025-10-23T16:31:51.547Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/51/409a8184ed35453d9cbb3d6b20d524b1115c2c2d117b85d5e9b06cd70b45/datasets-4.3.0-py3-none-any.whl", hash = "sha256:0ea157e72138b3ca6c7d2415f19a164ecf7d4c4fa72da2a570da286882e96903", size = 506846, upload-time = "2025-10-23T16:31:49.965Z" }, +] + +[[package]] +name = "debugpy" +version = "1.8.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/ad/71e708ff4ca377c4230530d6a7aa7992592648c122a2cd2b321cf8b35a76/debugpy-1.8.17.tar.gz", hash = "sha256:fd723b47a8c08892b1a16b2c6239a8b96637c62a59b94bb5dab4bac592a58a8e", size = 1644129, upload-time = "2025-09-17T16:33:20.633Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/76/597e5cb97d026274ba297af8d89138dfd9e695767ba0e0895edb20963f40/debugpy-1.8.17-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:857c1dd5d70042502aef1c6d1c2801211f3ea7e56f75e9c335f434afb403e464", size = 2538386, upload-time = "2025-09-17T16:33:54.594Z" }, + { url = "https://files.pythonhosted.org/packages/5f/60/ce5c34fcdfec493701f9d1532dba95b21b2f6394147234dce21160bd923f/debugpy-1.8.17-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:3bea3b0b12f3946e098cce9b43c3c46e317b567f79570c3f43f0b96d00788088", size = 4292100, upload-time = "2025-09-17T16:33:56.353Z" }, + { url = "https://files.pythonhosted.org/packages/e8/95/7873cf2146577ef71d2a20bf553f12df865922a6f87b9e8ee1df04f01785/debugpy-1.8.17-cp313-cp313-win32.whl", hash = "sha256:e34ee844c2f17b18556b5bbe59e1e2ff4e86a00282d2a46edab73fd7f18f4a83", size = 5277002, upload-time = "2025-09-17T16:33:58.231Z" }, + { url = "https://files.pythonhosted.org/packages/46/11/18c79a1cee5ff539a94ec4aa290c1c069a5580fd5cfd2fb2e282f8e905da/debugpy-1.8.17-cp313-cp313-win_amd64.whl", hash = "sha256:6c5cd6f009ad4fca8e33e5238210dc1e5f42db07d4b6ab21ac7ffa904a196420", size = 5319047, upload-time = 
"2025-09-17T16:34:00.586Z" }, + { url = "https://files.pythonhosted.org/packages/b0/d0/89247ec250369fc76db477720a26b2fce7ba079ff1380e4ab4529d2fe233/debugpy-1.8.17-py2.py3-none-any.whl", hash = "sha256:60c7dca6571efe660ccb7a9508d73ca14b8796c4ed484c2002abba714226cfef", size = 5283210, upload-time = "2025-09-17T16:34:25.835Z" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + +[[package]] +name = "dill" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/80/630b4b88364e9a8c8c5797f4602d0f76ef820909ee32f0bacb9f90654042/dill-0.4.0.tar.gz", hash = "sha256:0633f1d2df477324f53a895b02c901fb961bdbf65a17122586ea7019292cbcf0", size = 186976, upload-time = "2025-04-16T00:41:48.867Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668, upload-time = "2025-04-16T00:41:47.671Z" }, +] + +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = 
"sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + +[[package]] +name = "einops" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e5/81/df4fbe24dff8ba3934af99044188e20a98ed441ad17a274539b74e82e126/einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84", size = 54805, upload-time = "2025-02-09T03:17:00.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, +] + +[[package]] +name = "execnet" +version = "2.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bb/ff/b4c0dc78fbe20c3e59c0c7334de0c27eb4001a2b2017999af398bf730817/execnet-2.1.1.tar.gz", hash = "sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3", size = 166524, upload-time = "2024-04-08T09:04:19.245Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/09/2aea36ff60d16dd8879bdb2f5b3ee0ba8d08cbbdcdfe870e695ce3784385/execnet-2.1.1-py3-none-any.whl", hash = "sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc", size = 40612, upload-time = "2024-04-08T09:04:17.414Z" }, +] + +[[package]] +name = "executing" +version = "2.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/cc/28/c14e053b6762b1044f34a13aab6859bbf40456d37d23aa286ac24cfd9a5d/executing-2.2.1.tar.gz", hash = "sha256:3632cc370565f6648cc328b32435bd120a1e4ebb20c77e3fdde9a13cd1e533c4", size = 1129488, upload-time = "2025-09-01T09:48:10.866Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/ea/53f2148663b321f21b5a606bd5f191517cf40b7072c0497d3c92c4a13b1e/executing-2.2.1-py2.py3-none-any.whl", hash = "sha256:760643d3452b4d777d295bb167ccc74c64a81df23fb5e08eff250c425a4b2017", size = 28317, upload-time = "2025-09-01T09:48:08.5Z" }, +] + +[[package]] +name = "fastapi" +version = "0.120.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-doc" }, + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/cc/28aff6e246ee85bd571b26e4a793b84d42700e3bdc3008c3d747eda7b06d/fastapi-0.120.1.tar.gz", hash = "sha256:b5c6217e9ddca6dfcf54c97986180d4a1955e10c693d74943fc5327700178bff", size = 337616, upload-time = "2025-10-27T17:53:42.954Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/bb/1a74dbe87e9a595bf63052c886dfef965dc5b91d149456a8301eb3d41ce2/fastapi-0.120.1-py3-none-any.whl", hash = "sha256:0e8a2c328e96c117272d8c794d3a97d205f753cc2e69dd7ee387b7488a75601f", size = 108254, upload-time = "2025-10-27T17:53:40.076Z" }, +] + +[[package]] +name = "filelock" +version = "3.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/46/0028a82567109b5ef6e4d2a1f04a583fb513e6cf9527fcdd09afd817deeb/filelock-3.20.0.tar.gz", hash = "sha256:711e943b4ec6be42e1d4e6690b48dc175c822967466bb31c0c293f34334c13f4", size = 18922, upload-time = "2025-10-08T18:03:50.056Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/91/7216b27286936c16f5b4d0c530087e4a54eead683e6b0b73dd0c64844af6/filelock-3.20.0-py3-none-any.whl", hash = 
"sha256:339b4732ffda5cd79b13f4e2711a31b0365ce445d95d243bb996273d072546a2", size = 16054, upload-time = "2025-10-08T18:03:48.35Z" }, +] + +[[package]] +name = "fire" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "termcolor" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/00/f8d10588d2019d6d6452653def1ee807353b21983db48550318424b5ff18/fire-0.7.1.tar.gz", hash = "sha256:3b208f05c736de98fb343310d090dcc4d8c78b2a89ea4f32b837c586270a9cbf", size = 88720, upload-time = "2025-08-16T20:20:24.175Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/4c/93d0f85318da65923e4b91c1c2ff03d8a458cbefebe3bc612a6693c7906d/fire-0.7.1-py3-none-any.whl", hash = "sha256:e43fd8a5033a9001e7e2973bab96070694b9f12f2e0ecf96d4683971b5ab1882", size = 115945, upload-time = "2025-08-16T20:20:22.87Z" }, +] + +[[package]] +name = "fonttools" +version = "4.60.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4b/42/97a13e47a1e51a5a7142475bbcf5107fe3a68fc34aef331c897d5fb98ad0/fonttools-4.60.1.tar.gz", hash = "sha256:ef00af0439ebfee806b25f24c8f92109157ff3fac5731dc7867957812e87b8d9", size = 3559823, upload-time = "2025-09-29T21:13:27.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/5b/cdd2c612277b7ac7ec8c0c9bc41812c43dc7b2d5f2b0897e15fdf5a1f915/fonttools-4.60.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f68576bb4bbf6060c7ab047b1574a1ebe5c50a17de62830079967b211059ebb", size = 2825777, upload-time = "2025-09-29T21:12:01.22Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8a/de9cc0540f542963ba5e8f3a1f6ad48fa211badc3177783b9d5cadf79b5d/fonttools-4.60.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:eedacb5c5d22b7097482fa834bda0dafa3d914a4e829ec83cdea2a01f8c813c4", size = 2348080, upload-time = "2025-09-29T21:12:03.785Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/8b/371ab3cec97ee3fe1126b3406b7abd60c8fec8975fd79a3c75cdea0c3d83/fonttools-4.60.1-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b33a7884fabd72bdf5f910d0cf46be50dce86a0362a65cfc746a4168c67eb96c", size = 4903082, upload-time = "2025-09-29T21:12:06.382Z" }, + { url = "https://files.pythonhosted.org/packages/04/05/06b1455e4bc653fcb2117ac3ef5fa3a8a14919b93c60742d04440605d058/fonttools-4.60.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2409d5fb7b55fd70f715e6d34e7a6e4f7511b8ad29a49d6df225ee76da76dd77", size = 4960125, upload-time = "2025-09-29T21:12:09.314Z" }, + { url = "https://files.pythonhosted.org/packages/8e/37/f3b840fcb2666f6cb97038793606bdd83488dca2d0b0fc542ccc20afa668/fonttools-4.60.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c8651e0d4b3bdeda6602b85fdc2abbefc1b41e573ecb37b6779c4ca50753a199", size = 4901454, upload-time = "2025-09-29T21:12:11.931Z" }, + { url = "https://files.pythonhosted.org/packages/fd/9e/eb76f77e82f8d4a46420aadff12cec6237751b0fb9ef1de373186dcffb5f/fonttools-4.60.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:145daa14bf24824b677b9357c5e44fd8895c2a8f53596e1b9ea3496081dc692c", size = 5044495, upload-time = "2025-09-29T21:12:15.241Z" }, + { url = "https://files.pythonhosted.org/packages/f8/b3/cede8f8235d42ff7ae891bae8d619d02c8ac9fd0cfc450c5927a6200c70d/fonttools-4.60.1-cp313-cp313-win32.whl", hash = "sha256:2299df884c11162617a66b7c316957d74a18e3758c0274762d2cc87df7bc0272", size = 2217028, upload-time = "2025-09-29T21:12:17.96Z" }, + { url = "https://files.pythonhosted.org/packages/75/4d/b022c1577807ce8b31ffe055306ec13a866f2337ecee96e75b24b9b753ea/fonttools-4.60.1-cp313-cp313-win_amd64.whl", hash = "sha256:a3db56f153bd4c5c2b619ab02c5db5192e222150ce5a1bc10f16164714bc39ac", size = 2266200, upload-time = "2025-09-29T21:12:20.14Z" }, + { url = 
"https://files.pythonhosted.org/packages/c7/93/0dd45cd283c32dea1545151d8c3637b4b8c53cdb3a625aeb2885b184d74d/fonttools-4.60.1-py3-none-any.whl", hash = "sha256:906306ac7afe2156fcf0042173d6ebbb05416af70f6b370967b47f8f00103bbb", size = 1143175, upload-time = "2025-09-29T21:13:24.134Z" }, +] + +[[package]] +name = "frozenlist" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/f5/c831fac6cc817d26fd54c7eaccd04ef7e0288806943f7cc5bbf69f3ac1f0/frozenlist-1.8.0.tar.gz", hash = "sha256:3ede829ed8d842f6cd48fc7081d7a41001a56f1f38603f9d49bf3020d59a31ad", size = 45875, upload-time = "2025-10-06T05:38:17.865Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2d/40/0832c31a37d60f60ed79e9dfb5a92e1e2af4f40a16a29abcc7992af9edff/frozenlist-1.8.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8d92f1a84bb12d9e56f818b3a746f3efba93c1b63c8387a73dde655e1e42282a", size = 85717, upload-time = "2025-10-06T05:36:27.341Z" }, + { url = "https://files.pythonhosted.org/packages/30/ba/b0b3de23f40bc55a7057bd38434e25c34fa48e17f20ee273bbde5e0650f3/frozenlist-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96153e77a591c8adc2ee805756c61f59fef4cf4073a9275ee86fe8cba41241f7", size = 49651, upload-time = "2025-10-06T05:36:28.855Z" }, + { url = "https://files.pythonhosted.org/packages/0c/ab/6e5080ee374f875296c4243c381bbdef97a9ac39c6e3ce1d5f7d42cb78d6/frozenlist-1.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f21f00a91358803399890ab167098c131ec2ddd5f8f5fd5fe9c9f2c6fcd91e40", size = 49417, upload-time = "2025-10-06T05:36:29.877Z" }, + { url = "https://files.pythonhosted.org/packages/d5/4e/e4691508f9477ce67da2015d8c00acd751e6287739123113a9fca6f1604e/frozenlist-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:fb30f9626572a76dfe4293c7194a09fb1fe93ba94c7d4f720dfae3b646b45027", size = 234391, upload-time = "2025-10-06T05:36:31.301Z" }, + { 
url = "https://files.pythonhosted.org/packages/40/76/c202df58e3acdf12969a7895fd6f3bc016c642e6726aa63bd3025e0fc71c/frozenlist-1.8.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaa352d7047a31d87dafcacbabe89df0aa506abb5b1b85a2fb91bc3faa02d822", size = 233048, upload-time = "2025-10-06T05:36:32.531Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c0/8746afb90f17b73ca5979c7a3958116e105ff796e718575175319b5bb4ce/frozenlist-1.8.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:03ae967b4e297f58f8c774c7eabcce57fe3c2434817d4385c50661845a058121", size = 226549, upload-time = "2025-10-06T05:36:33.706Z" }, + { url = "https://files.pythonhosted.org/packages/7e/eb/4c7eefc718ff72f9b6c4893291abaae5fbc0c82226a32dcd8ef4f7a5dbef/frozenlist-1.8.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f6292f1de555ffcc675941d65fffffb0a5bcd992905015f85d0592201793e0e5", size = 239833, upload-time = "2025-10-06T05:36:34.947Z" }, + { url = "https://files.pythonhosted.org/packages/c2/4e/e5c02187cf704224f8b21bee886f3d713ca379535f16893233b9d672ea71/frozenlist-1.8.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:29548f9b5b5e3460ce7378144c3010363d8035cea44bc0bf02d57f5a685e084e", size = 245363, upload-time = "2025-10-06T05:36:36.534Z" }, + { url = "https://files.pythonhosted.org/packages/1f/96/cb85ec608464472e82ad37a17f844889c36100eed57bea094518bf270692/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ec3cc8c5d4084591b4237c0a272cc4f50a5b03396a47d9caaf76f5d7b38a4f11", size = 229314, upload-time = "2025-10-06T05:36:38.582Z" }, + { url = "https://files.pythonhosted.org/packages/5d/6f/4ae69c550e4cee66b57887daeebe006fe985917c01d0fff9caab9883f6d0/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = 
"sha256:517279f58009d0b1f2e7c1b130b377a349405da3f7621ed6bfae50b10adf20c1", size = 243365, upload-time = "2025-10-06T05:36:40.152Z" }, + { url = "https://files.pythonhosted.org/packages/7a/58/afd56de246cf11780a40a2c28dc7cbabbf06337cc8ddb1c780a2d97e88d8/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:db1e72ede2d0d7ccb213f218df6a078a9c09a7de257c2fe8fcef16d5925230b1", size = 237763, upload-time = "2025-10-06T05:36:41.355Z" }, + { url = "https://files.pythonhosted.org/packages/cb/36/cdfaf6ed42e2644740d4a10452d8e97fa1c062e2a8006e4b09f1b5fd7d63/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:b4dec9482a65c54a5044486847b8a66bf10c9cb4926d42927ec4e8fd5db7fed8", size = 240110, upload-time = "2025-10-06T05:36:42.716Z" }, + { url = "https://files.pythonhosted.org/packages/03/a8/9ea226fbefad669f11b52e864c55f0bd57d3c8d7eb07e9f2e9a0b39502e1/frozenlist-1.8.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:21900c48ae04d13d416f0e1e0c4d81f7931f73a9dfa0b7a8746fb2fe7dd970ed", size = 233717, upload-time = "2025-10-06T05:36:44.251Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0b/1b5531611e83ba7d13ccc9988967ea1b51186af64c42b7a7af465dcc9568/frozenlist-1.8.0-cp313-cp313-win32.whl", hash = "sha256:8b7b94a067d1c504ee0b16def57ad5738701e4ba10cec90529f13fa03c833496", size = 39628, upload-time = "2025-10-06T05:36:45.423Z" }, + { url = "https://files.pythonhosted.org/packages/d8/cf/174c91dbc9cc49bc7b7aab74d8b734e974d1faa8f191c74af9b7e80848e6/frozenlist-1.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:878be833caa6a3821caf85eb39c5ba92d28e85df26d57afb06b35b2efd937231", size = 43882, upload-time = "2025-10-06T05:36:46.796Z" }, + { url = "https://files.pythonhosted.org/packages/c1/17/502cd212cbfa96eb1388614fe39a3fc9ab87dbbe042b66f97acb57474834/frozenlist-1.8.0-cp313-cp313-win_arm64.whl", hash = "sha256:44389d135b3ff43ba8cc89ff7f51f5a0bb6b63d829c8300f79a2fe4fe61bcc62", size = 39676, upload-time = "2025-10-06T05:36:47.8Z" }, + { url = 
"https://files.pythonhosted.org/packages/d2/5c/3bbfaa920dfab09e76946a5d2833a7cbdf7b9b4a91c714666ac4855b88b4/frozenlist-1.8.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:e25ac20a2ef37e91c1b39938b591457666a0fa835c7783c3a8f33ea42870db94", size = 89235, upload-time = "2025-10-06T05:36:48.78Z" }, + { url = "https://files.pythonhosted.org/packages/d2/d6/f03961ef72166cec1687e84e8925838442b615bd0b8854b54923ce5b7b8a/frozenlist-1.8.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:07cdca25a91a4386d2e76ad992916a85038a9b97561bf7a3fd12d5d9ce31870c", size = 50742, upload-time = "2025-10-06T05:36:49.837Z" }, + { url = "https://files.pythonhosted.org/packages/1e/bb/a6d12b7ba4c3337667d0e421f7181c82dda448ce4e7ad7ecd249a16fa806/frozenlist-1.8.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4e0c11f2cc6717e0a741f84a527c52616140741cd812a50422f83dc31749fb52", size = 51725, upload-time = "2025-10-06T05:36:50.851Z" }, + { url = "https://files.pythonhosted.org/packages/bc/71/d1fed0ffe2c2ccd70b43714c6cab0f4188f09f8a67a7914a6b46ee30f274/frozenlist-1.8.0-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b3210649ee28062ea6099cfda39e147fa1bc039583c8ee4481cb7811e2448c51", size = 284533, upload-time = "2025-10-06T05:36:51.898Z" }, + { url = "https://files.pythonhosted.org/packages/c9/1f/fb1685a7b009d89f9bf78a42d94461bc06581f6e718c39344754a5d9bada/frozenlist-1.8.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:581ef5194c48035a7de2aefc72ac6539823bb71508189e5de01d60c9dcd5fa65", size = 292506, upload-time = "2025-10-06T05:36:53.101Z" }, + { url = "https://files.pythonhosted.org/packages/e6/3b/b991fe1612703f7e0d05c0cf734c1b77aaf7c7d321df4572e8d36e7048c8/frozenlist-1.8.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3ef2d026f16a2b1866e1d86fc4e1291e1ed8a387b2c333809419a2f8b3a77b82", size = 274161, upload-time = 
"2025-10-06T05:36:54.309Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ec/c5c618767bcdf66e88945ec0157d7f6c4a1322f1473392319b7a2501ded7/frozenlist-1.8.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5500ef82073f599ac84d888e3a8c1f77ac831183244bfd7f11eaa0289fb30714", size = 294676, upload-time = "2025-10-06T05:36:55.566Z" }, + { url = "https://files.pythonhosted.org/packages/7c/ce/3934758637d8f8a88d11f0585d6495ef54b2044ed6ec84492a91fa3b27aa/frozenlist-1.8.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:50066c3997d0091c411a66e710f4e11752251e6d2d73d70d8d5d4c76442a199d", size = 300638, upload-time = "2025-10-06T05:36:56.758Z" }, + { url = "https://files.pythonhosted.org/packages/fc/4f/a7e4d0d467298f42de4b41cbc7ddaf19d3cfeabaf9ff97c20c6c7ee409f9/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:5c1c8e78426e59b3f8005e9b19f6ff46e5845895adbde20ece9218319eca6506", size = 283067, upload-time = "2025-10-06T05:36:57.965Z" }, + { url = "https://files.pythonhosted.org/packages/dc/48/c7b163063d55a83772b268e6d1affb960771b0e203b632cfe09522d67ea5/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:eefdba20de0d938cec6a89bd4d70f346a03108a19b9df4248d3cf0d88f1b0f51", size = 292101, upload-time = "2025-10-06T05:36:59.237Z" }, + { url = "https://files.pythonhosted.org/packages/9f/d0/2366d3c4ecdc2fd391e0afa6e11500bfba0ea772764d631bbf82f0136c9d/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:cf253e0e1c3ceb4aaff6df637ce033ff6535fb8c70a764a8f46aafd3d6ab798e", size = 289901, upload-time = "2025-10-06T05:37:00.811Z" }, + { url = "https://files.pythonhosted.org/packages/b8/94/daff920e82c1b70e3618a2ac39fbc01ae3e2ff6124e80739ce5d71c9b920/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:032efa2674356903cd0261c4317a561a6850f3ac864a63fc1583147fb05a79b0", size = 289395, upload-time = 
"2025-10-06T05:37:02.115Z" }, + { url = "https://files.pythonhosted.org/packages/e3/20/bba307ab4235a09fdcd3cc5508dbabd17c4634a1af4b96e0f69bfe551ebd/frozenlist-1.8.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6da155091429aeba16851ecb10a9104a108bcd32f6c1642867eadaee401c1c41", size = 283659, upload-time = "2025-10-06T05:37:03.711Z" }, + { url = "https://files.pythonhosted.org/packages/fd/00/04ca1c3a7a124b6de4f8a9a17cc2fcad138b4608e7a3fc5877804b8715d7/frozenlist-1.8.0-cp313-cp313t-win32.whl", hash = "sha256:0f96534f8bfebc1a394209427d0f8a63d343c9779cda6fc25e8e121b5fd8555b", size = 43492, upload-time = "2025-10-06T05:37:04.915Z" }, + { url = "https://files.pythonhosted.org/packages/59/5e/c69f733a86a94ab10f68e496dc6b7e8bc078ebb415281d5698313e3af3a1/frozenlist-1.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5d63a068f978fc69421fb0e6eb91a9603187527c86b7cd3f534a5b77a592b888", size = 48034, upload-time = "2025-10-06T05:37:06.343Z" }, + { url = "https://files.pythonhosted.org/packages/16/6c/be9d79775d8abe79b05fa6d23da99ad6e7763a1d080fbae7290b286093fd/frozenlist-1.8.0-cp313-cp313t-win_arm64.whl", hash = "sha256:bf0a7e10b077bf5fb9380ad3ae8ce20ef919a6ad93b4552896419ac7e1d8e042", size = 41749, upload-time = "2025-10-06T05:37:07.431Z" }, + { url = "https://files.pythonhosted.org/packages/9a/9a/e35b4a917281c0b8419d4207f4334c8e8c5dbf4f3f5f9ada73958d937dcc/frozenlist-1.8.0-py3-none-any.whl", hash = "sha256:0c18a16eab41e82c295618a77502e17b195883241c563b00f0aa5106fc4eaa0d", size = 13409, upload-time = "2025-10-06T05:38:16.721Z" }, +] + +[[package]] +name = "fsspec" +version = "2025.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/de/e0/bab50af11c2d75c9c4a2a26a5254573c0bd97cea152254401510950486fa/fsspec-2025.9.0.tar.gz", hash = "sha256:19fd429483d25d28b65ec68f9f4adc16c17ea2c7c7bf54ec61360d478fb19c19", size = 304847, upload-time = "2025-09-02T19:10:49.215Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/47/71/70db47e4f6ce3e5c37a607355f80da8860a33226be640226ac52cb05ef2e/fsspec-2025.9.0-py3-none-any.whl", hash = "sha256:530dc2a2af60a414a832059574df4a6e10cce927f6f4a78209390fe38955cfb7", size = 199289, upload-time = "2025-09-02T19:10:47.708Z" }, +] + +[package.optional-dependencies] +http = [ + { name = "aiohttp" }, +] + +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.45" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "hf-xet" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" }, + { url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" }, + { url = 
"https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" }, + { url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" }, + { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, + { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, + { url = 
"https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, + { url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, + { url = "https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" }, + { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, + { url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" }, +] + +[[package]] +name = "httpcore" +version = "1.0.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/94/82699a10bca87a5556c9c59b5963f2d039dbd239f25bc2a63907a05a14cb/httpcore-1.0.9.tar.gz", hash = "sha256:6e34463af53fd2ab5d807f399a9b45ea31c3dfa2276f15a2c3f00afff6e176e8", size = 85484, upload-time = 
"2025-04-24T22:06:22.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/f66802a942d491edb555dd61e3a9961140fd64c90bce1eafd741609d334d/httpcore-1.0.9-py3-none-any.whl", hash = "sha256:2d400746a40668fc9dec9810239072b40b4484b640a8c38fd654a024c7a1bf55", size = 78784, upload-time = "2025-04-24T22:06:20.566Z" }, +] + +[[package]] +name = "httpx" +version = "0.28.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "certifi" }, + { name = "httpcore" }, + { name = "idna" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, +] + +[[package]] +name = "huggingface-hub" +version = "0.36.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "tqdm" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" }, +] + +[[package]] +name = "identify" +version = "2.6.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/e7/685de97986c916a6d93b3876139e00eef26ad5bbbd61925d670ae8013449/identify-2.6.15.tar.gz", hash = "sha256:e4f4864b96c6557ef2a1e1c951771838f4edc9df3a72ec7118b338801b11c7bf", size = 99311, upload-time = "2025-10-02T17:43:40.631Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/1c/e5fd8f973d4f375adb21565739498e2e9a1e54c858a97b9a8ccfdc81da9b/identify-2.6.15-py2.py3-none-any.whl", hash = "sha256:1181ef7608e00704db228516541eb83a88a9f94433a8c80bb9b5bd54b1d81757", size = 99183, upload-time = "2025-10-02T17:43:39.137Z" }, +] + +[[package]] +name = "idna" +version = "3.11" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6d/0703ccc57f3a7233505399edb88de3cbd678da106337b9fcde432b65ed60/idna-3.11.tar.gz", hash = "sha256:795dafcc9c04ed0c1fb032c2aa73654d8e8c5023a7df64a53f39190ada629902", size = 194582, upload-time = "2025-10-12T14:55:20.501Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, +] + +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = 
"sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + +[[package]] +name = "ipykernel" +version = "7.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/a4/4948be6eb88628505b83a1f2f40d90254cab66abf2043b3c40fa07dfce0f/ipykernel-7.1.0.tar.gz", hash = "sha256:58a3fc88533d5930c3546dc7eac66c6d288acde4f801e2001e65edc5dc9cf0db", size = 174579, upload-time = "2025-10-27T09:46:39.471Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl", hash = "sha256:763b5ec6c5b7776f6a8d7ce09b267693b4e5ce75cb50ae696aaefb3c85e1ea4c", size = 117968, upload-time = "2025-10-27T09:46:37.805Z" }, +] + +[[package]] +name = "ipython" +version = "9.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "decorator" }, + { name = "ipython-pygments-lexers" }, + { name = "jedi" }, + { name = "matplotlib-inline" }, + { name = "pexpect", marker = "sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name 
= "stack-data" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2a/34/29b18c62e39ee2f7a6a3bba7efd952729d8aadd45ca17efc34453b717665/ipython-9.6.0.tar.gz", hash = "sha256:5603d6d5d356378be5043e69441a072b50a5b33b4503428c77b04cb8ce7bc731", size = 4396932, upload-time = "2025-09-29T10:55:53.948Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/c5/d5e07995077e48220269c28a221e168c91123ad5ceee44d548f54a057fc0/ipython-9.6.0-py3-none-any.whl", hash = "sha256:5f77efafc886d2f023442479b8149e7d86547ad0a979e9da9f045d252f648196", size = 616170, upload-time = "2025-09-29T10:55:47.676Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + +[[package]] +name = "jaxtyping" +version = "0.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wadler-lindig" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e4/1e/827f9e17b26e21c7d4d934fd1a214284ad05663afedd37c21ed105db366b/jaxtyping-0.3.3.tar.gz", hash = "sha256:8003cfd16ba2ad9b47fdda1d982a575299a81ddfc7997ad0e917c87a0897ea86", size = 45484, upload-time = "2025-10-01T13:46:51.933Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/b8/97/88264b1af140f66ba7ca6eb2f3a108be233ee278bb3f1d5c750243e7458a/jaxtyping-0.3.3-py3-none-any.whl", hash = "sha256:a1c2f0f4351a8deda84b0e3b5c5a50894a1cdae2b82d841279fce4393aff4a7c", size = 55926, upload-time = "2025-10-01T13:46:50.621Z" }, +] + +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + +[[package]] +name = "jinja2" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/df/bf/f7da0350254c0ed7c72f3e33cef02e048281fec7ecec5f032d4aac52226b/jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d", size = 245115, upload-time = "2025-03-05T20:05:02.478Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, +] + +[[package]] +name = "jsonschema" +version = "4.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "jsonschema-specifications" }, + { name = "referencing" }, + { name = 
"rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/69/f7185de793a29082a9f3c7728268ffb31cb5095131a9c139a74078e27336/jsonschema-4.25.1.tar.gz", hash = "sha256:e4a9655ce0da0c0b67a085847e00a3a51449e1157f4f75e9fb5aa545e122eb85", size = 357342, upload-time = "2025-08-18T17:03:50.038Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/9c/8c95d856233c1f82500c2450b8c68576b4cf1c871db3afac5c34ff84e6fd/jsonschema-4.25.1-py3-none-any.whl", hash = "sha256:3fba0169e345c7175110351d456342c364814cfcf3b964ba4587f22915230a63", size = 90040, upload-time = "2025-08-18T17:03:48.373Z" }, +] + +[[package]] +name = "jsonschema-specifications" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "referencing" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/74/a633ee74eb36c44aa6d1095e7cc5569bebf04342ee146178e2d36600708b/jsonschema_specifications-2025.9.1.tar.gz", hash = "sha256:b540987f239e745613c7a9176f3edb72b832a4ac465cf02712288397832b5e8d", size = 32855, upload-time = "2025-09-08T01:34:59.186Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, +] + +[[package]] +name = "jupyter-client" +version = "8.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019, upload-time = "2024-09-17T10:44:17.613Z" } +wheels = [ + 
{ url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105, upload-time = "2024-09-17T10:44:15.218Z" }, +] + +[[package]] +name = "jupyter-core" +version = "5.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/02/49/9d1284d0dc65e2c757b74c6687b6d319b02f822ad039e5c512df9194d9dd/jupyter_core-5.9.1.tar.gz", hash = "sha256:4d09aaff303b9566c3ce657f580bd089ff5c91f5f89cf7d8846c3cdf465b5508", size = 89814, upload-time = "2025-10-16T19:19:18.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, +] + +[[package]] +name = "kiwisolver" +version = "1.4.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/3c/85844f1b0feb11ee581ac23fe5fce65cd049a200c1446708cc1b7f922875/kiwisolver-1.4.9.tar.gz", hash = "sha256:c3b22c26c6fd6811b0ae8363b95ca8ce4ea3c202d3d0975b2914310ceb1bcc4d", size = 97564, upload-time = "2025-08-10T21:27:49.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/31/c1/c2686cda909742ab66c7388e9a1a8521a59eb89f8bcfbee28fc980d07e24/kiwisolver-1.4.9-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a5d0432ccf1c7ab14f9949eec60c5d1f924f17c037e9f8b33352fa05799359b8", size = 123681, upload-time = "2025-08-10T21:26:26.725Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f0/f44f50c9f5b1a1860261092e3bc91ecdc9acda848a8b8c6abfda4a24dd5c/kiwisolver-1.4.9-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:efb3a45b35622bb6c16dbfab491a8f5a391fe0e9d45ef32f4df85658232ca0e2", size = 66464, upload-time = "2025-08-10T21:26:27.733Z" }, + { url = "https://files.pythonhosted.org/packages/2d/7a/9d90a151f558e29c3936b8a47ac770235f436f2120aca41a6d5f3d62ae8d/kiwisolver-1.4.9-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1a12cf6398e8a0a001a059747a1cbf24705e18fe413bc22de7b3d15c67cffe3f", size = 64961, upload-time = "2025-08-10T21:26:28.729Z" }, + { url = "https://files.pythonhosted.org/packages/e9/e9/f218a2cb3a9ffbe324ca29a9e399fa2d2866d7f348ec3a88df87fc248fc5/kiwisolver-1.4.9-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b67e6efbf68e077dd71d1a6b37e43e1a99d0bff1a3d51867d45ee8908b931098", size = 1474607, upload-time = "2025-08-10T21:26:29.798Z" }, + { url = "https://files.pythonhosted.org/packages/d9/28/aac26d4c882f14de59041636292bc838db8961373825df23b8eeb807e198/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5656aa670507437af0207645273ccdfee4f14bacd7f7c67a4306d0dcaeaf6eed", size = 1276546, upload-time = "2025-08-10T21:26:31.401Z" }, + { url = "https://files.pythonhosted.org/packages/8b/ad/8bfc1c93d4cc565e5069162f610ba2f48ff39b7de4b5b8d93f69f30c4bed/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:bfc08add558155345129c7803b3671cf195e6a56e7a12f3dde7c57d9b417f525", size = 1294482, upload-time = "2025-08-10T21:26:32.721Z" }, + { url = "https://files.pythonhosted.org/packages/da/f1/6aca55ff798901d8ce403206d00e033191f63d82dd708a186e0ed2067e9c/kiwisolver-1.4.9-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:40092754720b174e6ccf9e845d0d8c7d8e12c3d71e7fc35f55f3813e96376f78", size = 1343720, upload-time = "2025-08-10T21:26:34.032Z" }, + { url = "https://files.pythonhosted.org/packages/d1/91/eed031876c595c81d90d0f6fc681ece250e14bf6998c3d7c419466b523b7/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:497d05f29a1300d14e02e6441cf0f5ee81c1ff5a304b0d9fb77423974684e08b", size = 2224907, upload-time = "2025-08-10T21:26:35.824Z" }, + { url = "https://files.pythonhosted.org/packages/e9/ec/4d1925f2e49617b9cca9c34bfa11adefad49d00db038e692a559454dfb2e/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:bdd1a81a1860476eb41ac4bc1e07b3f07259e6d55bbf739b79c8aaedcf512799", size = 2321334, upload-time = "2025-08-10T21:26:37.534Z" }, + { url = "https://files.pythonhosted.org/packages/43/cb/450cd4499356f68802750c6ddc18647b8ea01ffa28f50d20598e0befe6e9/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:e6b93f13371d341afee3be9f7c5964e3fe61d5fa30f6a30eb49856935dfe4fc3", size = 2488313, upload-time = "2025-08-10T21:26:39.191Z" }, + { url = "https://files.pythonhosted.org/packages/71/67/fc76242bd99f885651128a5d4fa6083e5524694b7c88b489b1b55fdc491d/kiwisolver-1.4.9-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d75aa530ccfaa593da12834b86a0724f58bff12706659baa9227c2ccaa06264c", size = 2291970, upload-time = "2025-08-10T21:26:40.828Z" }, + { url = "https://files.pythonhosted.org/packages/75/bd/f1a5d894000941739f2ae1b65a32892349423ad49c2e6d0771d0bad3fae4/kiwisolver-1.4.9-cp313-cp313-win_amd64.whl", hash = "sha256:dd0a578400839256df88c16abddf9ba14813ec5f21362e1fe65022e00c883d4d", size = 73894, upload-time = "2025-08-10T21:26:42.33Z" }, + { url = "https://files.pythonhosted.org/packages/95/38/dce480814d25b99a391abbddadc78f7c117c6da34be68ca8b02d5848b424/kiwisolver-1.4.9-cp313-cp313-win_arm64.whl", hash = "sha256:d4188e73af84ca82468f09cadc5ac4db578109e52acb4518d8154698d3a87ca2", size = 64995, upload-time = "2025-08-10T21:26:43.889Z" }, + { url = "https://files.pythonhosted.org/packages/e2/37/7d218ce5d92dadc5ebdd9070d903e0c7cf7edfe03f179433ac4d13ce659c/kiwisolver-1.4.9-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:5a0f2724dfd4e3b3ac5a82436a8e6fd16baa7d507117e4279b660fe8ca38a3a1", size = 126510, upload-time = 
"2025-08-10T21:26:44.915Z" }, + { url = "https://files.pythonhosted.org/packages/23/b0/e85a2b48233daef4b648fb657ebbb6f8367696a2d9548a00b4ee0eb67803/kiwisolver-1.4.9-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:1b11d6a633e4ed84fc0ddafd4ebfd8ea49b3f25082c04ad12b8315c11d504dc1", size = 67903, upload-time = "2025-08-10T21:26:45.934Z" }, + { url = "https://files.pythonhosted.org/packages/44/98/f2425bc0113ad7de24da6bb4dae1343476e95e1d738be7c04d31a5d037fd/kiwisolver-1.4.9-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61874cdb0a36016354853593cffc38e56fc9ca5aa97d2c05d3dcf6922cd55a11", size = 66402, upload-time = "2025-08-10T21:26:47.101Z" }, + { url = "https://files.pythonhosted.org/packages/98/d8/594657886df9f34c4177cc353cc28ca7e6e5eb562d37ccc233bff43bbe2a/kiwisolver-1.4.9-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:60c439763a969a6af93b4881db0eed8fadf93ee98e18cbc35bc8da868d0c4f0c", size = 1582135, upload-time = "2025-08-10T21:26:48.665Z" }, + { url = "https://files.pythonhosted.org/packages/5c/c6/38a115b7170f8b306fc929e166340c24958347308ea3012c2b44e7e295db/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92a2f997387a1b79a75e7803aa7ded2cfbe2823852ccf1ba3bcf613b62ae3197", size = 1389409, upload-time = "2025-08-10T21:26:50.335Z" }, + { url = "https://files.pythonhosted.org/packages/bf/3b/e04883dace81f24a568bcee6eb3001da4ba05114afa622ec9b6fafdc1f5e/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a31d512c812daea6d8b3be3b2bfcbeb091dbb09177706569bcfc6240dcf8b41c", size = 1401763, upload-time = "2025-08-10T21:26:51.867Z" }, + { url = "https://files.pythonhosted.org/packages/9f/80/20ace48e33408947af49d7d15c341eaee69e4e0304aab4b7660e234d6288/kiwisolver-1.4.9-cp313-cp313t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:52a15b0f35dad39862d376df10c5230155243a2c1a436e39eb55623ccbd68185", size = 1453643, upload-time = 
"2025-08-10T21:26:53.592Z" }, + { url = "https://files.pythonhosted.org/packages/64/31/6ce4380a4cd1f515bdda976a1e90e547ccd47b67a1546d63884463c92ca9/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a30fd6fdef1430fd9e1ba7b3398b5ee4e2887783917a687d86ba69985fb08748", size = 2330818, upload-time = "2025-08-10T21:26:55.051Z" }, + { url = "https://files.pythonhosted.org/packages/fa/e9/3f3fcba3bcc7432c795b82646306e822f3fd74df0ee81f0fa067a1f95668/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:cc9617b46837c6468197b5945e196ee9ca43057bb7d9d1ae688101e4e1dddf64", size = 2419963, upload-time = "2025-08-10T21:26:56.421Z" }, + { url = "https://files.pythonhosted.org/packages/99/43/7320c50e4133575c66e9f7dadead35ab22d7c012a3b09bb35647792b2a6d/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:0ab74e19f6a2b027ea4f845a78827969af45ce790e6cb3e1ebab71bdf9f215ff", size = 2594639, upload-time = "2025-08-10T21:26:57.882Z" }, + { url = "https://files.pythonhosted.org/packages/65/d6/17ae4a270d4a987ef8a385b906d2bdfc9fce502d6dc0d3aea865b47f548c/kiwisolver-1.4.9-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dba5ee5d3981160c28d5490f0d1b7ed730c22470ff7f6cc26cfcfaacb9896a07", size = 2391741, upload-time = "2025-08-10T21:26:59.237Z" }, + { url = "https://files.pythonhosted.org/packages/2a/8f/8f6f491d595a9e5912971f3f863d81baddccc8a4d0c3749d6a0dd9ffc9df/kiwisolver-1.4.9-cp313-cp313t-win_arm64.whl", hash = "sha256:0749fd8f4218ad2e851e11cc4dc05c7cbc0cbc4267bdfdb31782e65aace4ee9c", size = 68646, upload-time = "2025-08-10T21:27:00.52Z" }, +] + +[[package]] +name = "markupsafe" +version = "3.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698", size = 80313, upload-time = "2025-09-27T18:37:40.426Z" 
} +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795", size = 11622, upload-time = "2025-09-27T18:36:41.777Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219", size = 12029, upload-time = "2025-09-27T18:36:43.257Z" }, + { url = "https://files.pythonhosted.org/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6", size = 24374, upload-time = "2025-09-27T18:36:44.508Z" }, + { url = "https://files.pythonhosted.org/packages/a9/21/9b05698b46f218fc0e118e1f8168395c65c8a2c750ae2bab54fc4bd4e0e8/markupsafe-3.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ccfcd093f13f0f0b7fdd0f198b90053bf7b2f02a3927a30e63f3ccc9df56b676", size = 22980, upload-time = "2025-09-27T18:36:45.385Z" }, + { url = "https://files.pythonhosted.org/packages/7f/71/544260864f893f18b6827315b988c146b559391e6e7e8f7252839b1b846a/markupsafe-3.0.3-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:509fa21c6deb7a7a273d629cf5ec029bc209d1a51178615ddf718f5918992ab9", size = 21990, upload-time = "2025-09-27T18:36:46.916Z" }, + { url = "https://files.pythonhosted.org/packages/c2/28/b50fc2f74d1ad761af2f5dcce7492648b983d00a65b8c0e0cb457c82ebbe/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a4afe79fb3de0b7097d81da19090f4df4f8d3a2b3adaa8764138aac2e44f3af1", size = 23784, upload-time = 
"2025-09-27T18:36:47.884Z" }, + { url = "https://files.pythonhosted.org/packages/ed/76/104b2aa106a208da8b17a2fb72e033a5a9d7073c68f7e508b94916ed47a9/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:795e7751525cae078558e679d646ae45574b47ed6e7771863fcc079a6171a0fc", size = 21588, upload-time = "2025-09-27T18:36:48.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/99/16a5eb2d140087ebd97180d95249b00a03aa87e29cc224056274f2e45fd6/markupsafe-3.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8485f406a96febb5140bfeca44a73e3ce5116b2501ac54fe953e488fb1d03b12", size = 23041, upload-time = "2025-09-27T18:36:49.797Z" }, + { url = "https://files.pythonhosted.org/packages/19/bc/e7140ed90c5d61d77cea142eed9f9c303f4c4806f60a1044c13e3f1471d0/markupsafe-3.0.3-cp313-cp313-win32.whl", hash = "sha256:bdd37121970bfd8be76c5fb069c7751683bdf373db1ed6c010162b2a130248ed", size = 14543, upload-time = "2025-09-27T18:36:51.584Z" }, + { url = "https://files.pythonhosted.org/packages/05/73/c4abe620b841b6b791f2edc248f556900667a5a1cf023a6646967ae98335/markupsafe-3.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:9a1abfdc021a164803f4d485104931fb8f8c1efd55bc6b748d2f5774e78b62c5", size = 15113, upload-time = "2025-09-27T18:36:52.537Z" }, + { url = "https://files.pythonhosted.org/packages/f0/3a/fa34a0f7cfef23cf9500d68cb7c32dd64ffd58a12b09225fb03dd37d5b80/markupsafe-3.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:7e68f88e5b8799aa49c85cd116c932a1ac15caaa3f5db09087854d218359e485", size = 13911, upload-time = "2025-09-27T18:36:53.513Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d7/e05cd7efe43a88a17a37b3ae96e79a19e846f3f456fe79c57ca61356ef01/markupsafe-3.0.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:218551f6df4868a8d527e3062d0fb968682fe92054e89978594c28e642c43a73", size = 11658, upload-time = "2025-09-27T18:36:54.819Z" }, + { url = 
"https://files.pythonhosted.org/packages/99/9e/e412117548182ce2148bdeacdda3bb494260c0b0184360fe0d56389b523b/markupsafe-3.0.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:3524b778fe5cfb3452a09d31e7b5adefeea8c5be1d43c4f810ba09f2ceb29d37", size = 12066, upload-time = "2025-09-27T18:36:55.714Z" }, + { url = "https://files.pythonhosted.org/packages/bc/e6/fa0ffcda717ef64a5108eaa7b4f5ed28d56122c9a6d70ab8b72f9f715c80/markupsafe-3.0.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4e885a3d1efa2eadc93c894a21770e4bc67899e3543680313b09f139e149ab19", size = 25639, upload-time = "2025-09-27T18:36:56.908Z" }, + { url = "https://files.pythonhosted.org/packages/96/ec/2102e881fe9d25fc16cb4b25d5f5cde50970967ffa5dddafdb771237062d/markupsafe-3.0.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8709b08f4a89aa7586de0aadc8da56180242ee0ada3999749b183aa23df95025", size = 23569, upload-time = "2025-09-27T18:36:57.913Z" }, + { url = "https://files.pythonhosted.org/packages/4b/30/6f2fce1f1f205fc9323255b216ca8a235b15860c34b6798f810f05828e32/markupsafe-3.0.3-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:b8512a91625c9b3da6f127803b166b629725e68af71f8184ae7e7d54686a56d6", size = 23284, upload-time = "2025-09-27T18:36:58.833Z" }, + { url = "https://files.pythonhosted.org/packages/58/47/4a0ccea4ab9f5dcb6f79c0236d954acb382202721e704223a8aafa38b5c8/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9b79b7a16f7fedff2495d684f2b59b0457c3b493778c9eed31111be64d58279f", size = 24801, upload-time = "2025-09-27T18:36:59.739Z" }, + { url = "https://files.pythonhosted.org/packages/6a/70/3780e9b72180b6fecb83a4814d84c3bf4b4ae4bf0b19c27196104149734c/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:12c63dfb4a98206f045aa9563db46507995f7ef6d83b2f68eda65c307c6829eb", size = 22769, upload-time = "2025-09-27T18:37:00.719Z" }, + { 
url = "https://files.pythonhosted.org/packages/98/c5/c03c7f4125180fc215220c035beac6b9cb684bc7a067c84fc69414d315f5/markupsafe-3.0.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:8f71bc33915be5186016f675cd83a1e08523649b0e33efdb898db577ef5bb009", size = 23642, upload-time = "2025-09-27T18:37:01.673Z" }, + { url = "https://files.pythonhosted.org/packages/80/d6/2d1b89f6ca4bff1036499b1e29a1d02d282259f3681540e16563f27ebc23/markupsafe-3.0.3-cp313-cp313t-win32.whl", hash = "sha256:69c0b73548bc525c8cb9a251cddf1931d1db4d2258e9599c28c07ef3580ef354", size = 14612, upload-time = "2025-09-27T18:37:02.639Z" }, + { url = "https://files.pythonhosted.org/packages/2b/98/e48a4bfba0a0ffcf9925fe2d69240bfaa19c6f7507b8cd09c70684a53c1e/markupsafe-3.0.3-cp313-cp313t-win_amd64.whl", hash = "sha256:1b4b79e8ebf6b55351f0d91fe80f893b4743f104bff22e90697db1590e47a218", size = 15200, upload-time = "2025-09-27T18:37:03.582Z" }, + { url = "https://files.pythonhosted.org/packages/0e/72/e3cc540f351f316e9ed0f092757459afbc595824ca724cbc5a5d4263713f/markupsafe-3.0.3-cp313-cp313t-win_arm64.whl", hash = "sha256:ad2cf8aa28b8c020ab2fc8287b0f823d0a7d8630784c31e9ee5edea20f406287", size = 13973, upload-time = "2025-09-27T18:37:04.929Z" }, +] + +[[package]] +name = "matplotlib" +version = "3.10.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "contourpy" }, + { name = "cycler" }, + { name = "fonttools" }, + { name = "kiwisolver" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "pyparsing" }, + { name = "python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/e2/d2d5295be2f44c678ebaf3544ba32d20c1f9ef08c49fe47f496180e1db15/matplotlib-3.10.7.tar.gz", hash = "sha256:a06ba7e2a2ef9131c79c49e63dad355d2d878413a0376c1727c8b9335ff731c7", size = 34804865, upload-time = "2025-10-09T00:28:00.669Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/02/9c/207547916a02c78f6bdd83448d9b21afbc42f6379ed887ecf610984f3b4e/matplotlib-3.10.7-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1d9d3713a237970569156cfb4de7533b7c4eacdd61789726f444f96a0d28f57f", size = 8273212, upload-time = "2025-10-09T00:26:56.752Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d0/b3d3338d467d3fc937f0bb7f256711395cae6f78e22cef0656159950adf0/matplotlib-3.10.7-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:37a1fea41153dd6ee061d21ab69c9cf2cf543160b1b85d89cd3d2e2a7902ca4c", size = 8128713, upload-time = "2025-10-09T00:26:59.001Z" }, + { url = "https://files.pythonhosted.org/packages/22/ff/6425bf5c20d79aa5b959d1ce9e65f599632345391381c9a104133fe0b171/matplotlib-3.10.7-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b3c4ea4948d93c9c29dc01c0c23eef66f2101bf75158c291b88de6525c55c3d1", size = 8698527, upload-time = "2025-10-09T00:27:00.69Z" }, + { url = "https://files.pythonhosted.org/packages/d0/7f/ccdca06f4c2e6c7989270ed7829b8679466682f4cfc0f8c9986241c023b6/matplotlib-3.10.7-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:22df30ffaa89f6643206cf13877191c63a50e8f800b038bc39bee9d2d4957632", size = 9529690, upload-time = "2025-10-09T00:27:02.664Z" }, + { url = "https://files.pythonhosted.org/packages/b8/95/b80fc2c1f269f21ff3d193ca697358e24408c33ce2b106a7438a45407b63/matplotlib-3.10.7-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b69676845a0a66f9da30e87f48be36734d6748024b525ec4710be40194282c84", size = 9593732, upload-time = "2025-10-09T00:27:04.653Z" }, + { url = "https://files.pythonhosted.org/packages/e1/b6/23064a96308b9aeceeffa65e96bcde459a2ea4934d311dee20afde7407a0/matplotlib-3.10.7-cp313-cp313-win_amd64.whl", hash = "sha256:744991e0cc863dd669c8dc9136ca4e6e0082be2070b9d793cbd64bec872a6815", size = 8122727, upload-time = "2025-10-09T00:27:06.814Z" }, + { url = 
"https://files.pythonhosted.org/packages/b3/a6/2faaf48133b82cf3607759027f82b5c702aa99cdfcefb7f93d6ccf26a424/matplotlib-3.10.7-cp313-cp313-win_arm64.whl", hash = "sha256:fba2974df0bf8ce3c995fa84b79cde38326e0f7b5409e7a3a481c1141340bcf7", size = 7992958, upload-time = "2025-10-09T00:27:08.567Z" }, + { url = "https://files.pythonhosted.org/packages/4a/f0/b018fed0b599bd48d84c08794cb242227fe3341952da102ee9d9682db574/matplotlib-3.10.7-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:932c55d1fa7af4423422cb6a492a31cbcbdbe68fd1a9a3f545aa5e7a143b5355", size = 8316849, upload-time = "2025-10-09T00:27:10.254Z" }, + { url = "https://files.pythonhosted.org/packages/b0/b7/bb4f23856197659f275e11a2a164e36e65e9b48ea3e93c4ec25b4f163198/matplotlib-3.10.7-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e38c2d581d62ee729a6e144c47a71b3f42fb4187508dbbf4fe71d5612c3433b", size = 8178225, upload-time = "2025-10-09T00:27:12.241Z" }, + { url = "https://files.pythonhosted.org/packages/62/56/0600609893ff277e6f3ab3c0cef4eafa6e61006c058e84286c467223d4d5/matplotlib-3.10.7-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:786656bb13c237bbcebcd402f65f44dd61ead60ee3deb045af429d889c8dbc67", size = 8711708, upload-time = "2025-10-09T00:27:13.879Z" }, + { url = "https://files.pythonhosted.org/packages/d8/1a/6bfecb0cafe94d6658f2f1af22c43b76cf7a1c2f0dc34ef84cbb6809617e/matplotlib-3.10.7-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09d7945a70ea43bf9248f4b6582734c2fe726723204a76eca233f24cffc7ef67", size = 9541409, upload-time = "2025-10-09T00:27:15.684Z" }, + { url = "https://files.pythonhosted.org/packages/08/50/95122a407d7f2e446fd865e2388a232a23f2b81934960ea802f3171518e4/matplotlib-3.10.7-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d0b181e9fa8daf1d9f2d4c547527b167cb8838fc587deabca7b5c01f97199e84", size = 9594054, upload-time = "2025-10-09T00:27:17.547Z" }, + { url = 
"https://files.pythonhosted.org/packages/13/76/75b194a43b81583478a81e78a07da8d9ca6ddf50dd0a2ccabf258059481d/matplotlib-3.10.7-cp313-cp313t-win_amd64.whl", hash = "sha256:31963603041634ce1a96053047b40961f7a29eb8f9a62e80cc2c0427aa1d22a2", size = 8200100, upload-time = "2025-10-09T00:27:20.039Z" }, + { url = "https://files.pythonhosted.org/packages/f5/9e/6aefebdc9f8235c12bdeeda44cc0383d89c1e41da2c400caf3ee2073a3ce/matplotlib-3.10.7-cp313-cp313t-win_arm64.whl", hash = "sha256:aebed7b50aa6ac698c90f60f854b47e48cd2252b30510e7a1feddaf5a3f72cbf", size = 8042131, upload-time = "2025-10-09T00:27:21.608Z" }, +] + +[[package]] +name = "matplotlib-inline" +version = "0.2.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/74/97e72a36efd4ae2bccb3463284300f8953f199b5ffbc04cbbb0ec78f74b1/matplotlib_inline-0.2.1.tar.gz", hash = "sha256:e1ee949c340d771fc39e241ea75683deb94762c8fa5f2927ec57c83c4dffa9fe", size = 8110, upload-time = "2025-10-23T09:00:22.126Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, +] + +[[package]] +name = "mpmath" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/47/dd32fa426cc72114383ac549964eecb20ecfd886d1e5ccf5340b55b02f57/mpmath-1.3.0.tar.gz", hash = "sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f", size = 508106, upload-time = "2023-03-07T16:47:11.061Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/e3/7d92a15f894aa0c9c4b49b8ee9ac9850d6e63b03c9c32c0367a13ae62209/mpmath-1.3.0-py3-none-any.whl", hash = 
"sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c", size = 536198, upload-time = "2023-03-07T16:47:09.197Z" }, +] + +[[package]] +name = "multidict" +version = "6.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/80/1e/5492c365f222f907de1039b91f922b93fa4f764c713ee858d235495d8f50/multidict-6.7.0.tar.gz", hash = "sha256:c6e99d9a65ca282e578dfea819cfa9c0a62b2499d8677392e09feaf305e9e6f5", size = 101834, upload-time = "2025-10-06T14:52:30.657Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/86/33272a544eeb36d66e4d9a920602d1a2f57d4ebea4ef3cdfe5a912574c95/multidict-6.7.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:bee7c0588aa0076ce77c0ea5d19a68d76ad81fcd9fe8501003b9a24f9d4000f6", size = 76135, upload-time = "2025-10-06T14:49:54.26Z" }, + { url = "https://files.pythonhosted.org/packages/91/1c/eb97db117a1ebe46d457a3d235a7b9d2e6dcab174f42d1b67663dd9e5371/multidict-6.7.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7ef6b61cad77091056ce0e7ce69814ef72afacb150b7ac6a3e9470def2198159", size = 45117, upload-time = "2025-10-06T14:49:55.82Z" }, + { url = "https://files.pythonhosted.org/packages/f1/d8/6c3442322e41fb1dd4de8bd67bfd11cd72352ac131f6368315617de752f1/multidict-6.7.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:9c0359b1ec12b1d6849c59f9d319610b7f20ef990a6d454ab151aa0e3b9f78ca", size = 43472, upload-time = "2025-10-06T14:49:57.048Z" }, + { url = "https://files.pythonhosted.org/packages/75/3f/e2639e80325af0b6c6febdf8e57cc07043ff15f57fa1ef808f4ccb5ac4cd/multidict-6.7.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cd240939f71c64bd658f186330603aac1a9a81bf6273f523fca63673cb7378a8", size = 249342, upload-time = "2025-10-06T14:49:58.368Z" }, + { url = 
"https://files.pythonhosted.org/packages/5d/cc/84e0585f805cbeaa9cbdaa95f9a3d6aed745b9d25700623ac89a6ecff400/multidict-6.7.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a60a4d75718a5efa473ebd5ab685786ba0c67b8381f781d1be14da49f1a2dc60", size = 257082, upload-time = "2025-10-06T14:49:59.89Z" }, + { url = "https://files.pythonhosted.org/packages/b0/9c/ac851c107c92289acbbf5cfb485694084690c1b17e555f44952c26ddc5bd/multidict-6.7.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53a42d364f323275126aff81fb67c5ca1b7a04fda0546245730a55c8c5f24bc4", size = 240704, upload-time = "2025-10-06T14:50:01.485Z" }, + { url = "https://files.pythonhosted.org/packages/50/cc/5f93e99427248c09da95b62d64b25748a5f5c98c7c2ab09825a1d6af0e15/multidict-6.7.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3b29b980d0ddbecb736735ee5bef69bb2ddca56eff603c86f3f29a1128299b4f", size = 266355, upload-time = "2025-10-06T14:50:02.955Z" }, + { url = "https://files.pythonhosted.org/packages/ec/0c/2ec1d883ceb79c6f7f6d7ad90c919c898f5d1c6ea96d322751420211e072/multidict-6.7.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f8a93b1c0ed2d04b97a5e9336fd2d33371b9a6e29ab7dd6503d63407c20ffbaf", size = 267259, upload-time = "2025-10-06T14:50:04.446Z" }, + { url = "https://files.pythonhosted.org/packages/c6/2d/f0b184fa88d6630aa267680bdb8623fb69cb0d024b8c6f0d23f9a0f406d3/multidict-6.7.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ff96e8815eecacc6645da76c413eb3b3d34cfca256c70b16b286a687d013c32", size = 254903, upload-time = "2025-10-06T14:50:05.98Z" }, + { url = "https://files.pythonhosted.org/packages/06/c9/11ea263ad0df7dfabcad404feb3c0dd40b131bc7f232d5537f2fb1356951/multidict-6.7.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:7516c579652f6a6be0e266aec0acd0db80829ca305c3d771ed898538804c2036", size = 252365, upload-time = "2025-10-06T14:50:07.511Z" }, + { url = "https://files.pythonhosted.org/packages/41/88/d714b86ee2c17d6e09850c70c9d310abac3d808ab49dfa16b43aba9d53fd/multidict-6.7.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:040f393368e63fb0f3330e70c26bfd336656bed925e5cbe17c9da839a6ab13ec", size = 250062, upload-time = "2025-10-06T14:50:09.074Z" }, + { url = "https://files.pythonhosted.org/packages/15/fe/ad407bb9e818c2b31383f6131ca19ea7e35ce93cf1310fce69f12e89de75/multidict-6.7.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b3bc26a951007b1057a1c543af845f1c7e3e71cc240ed1ace7bf4484aa99196e", size = 249683, upload-time = "2025-10-06T14:50:10.714Z" }, + { url = "https://files.pythonhosted.org/packages/8c/a4/a89abdb0229e533fb925e7c6e5c40201c2873efebc9abaf14046a4536ee6/multidict-6.7.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:7b022717c748dd1992a83e219587aabe45980d88969f01b316e78683e6285f64", size = 261254, upload-time = "2025-10-06T14:50:12.28Z" }, + { url = "https://files.pythonhosted.org/packages/8d/aa/0e2b27bd88b40a4fb8dc53dd74eecac70edaa4c1dd0707eb2164da3675b3/multidict-6.7.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:9600082733859f00d79dee64effc7aef1beb26adb297416a4ad2116fd61374bd", size = 257967, upload-time = "2025-10-06T14:50:14.16Z" }, + { url = "https://files.pythonhosted.org/packages/d0/8e/0c67b7120d5d5f6d874ed85a085f9dc770a7f9d8813e80f44a9fec820bb7/multidict-6.7.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:94218fcec4d72bc61df51c198d098ce2b378e0ccbac41ddbed5ef44092913288", size = 250085, upload-time = "2025-10-06T14:50:15.639Z" }, + { url = "https://files.pythonhosted.org/packages/ba/55/b73e1d624ea4b8fd4dd07a3bb70f6e4c7c6c5d9d640a41c6ffe5cdbd2a55/multidict-6.7.0-cp313-cp313-win32.whl", hash = "sha256:a37bd74c3fa9d00be2d7b8eca074dc56bd8077ddd2917a839bd989612671ed17", size = 41713, upload-time = "2025-10-06T14:50:17.066Z" }, + 
{ url = "https://files.pythonhosted.org/packages/32/31/75c59e7d3b4205075b4c183fa4ca398a2daf2303ddf616b04ae6ef55cffe/multidict-6.7.0-cp313-cp313-win_amd64.whl", hash = "sha256:30d193c6cc6d559db42b6bcec8a5d395d34d60c9877a0b71ecd7c204fcf15390", size = 45915, upload-time = "2025-10-06T14:50:18.264Z" }, + { url = "https://files.pythonhosted.org/packages/31/2a/8987831e811f1184c22bc2e45844934385363ee61c0a2dcfa8f71b87e608/multidict-6.7.0-cp313-cp313-win_arm64.whl", hash = "sha256:ea3334cabe4d41b7ccd01e4d349828678794edbc2d3ae97fc162a3312095092e", size = 43077, upload-time = "2025-10-06T14:50:19.853Z" }, + { url = "https://files.pythonhosted.org/packages/e8/68/7b3a5170a382a340147337b300b9eb25a9ddb573bcdfff19c0fa3f31ffba/multidict-6.7.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:ad9ce259f50abd98a1ca0aa6e490b58c316a0fce0617f609723e40804add2c00", size = 83114, upload-time = "2025-10-06T14:50:21.223Z" }, + { url = "https://files.pythonhosted.org/packages/55/5c/3fa2d07c84df4e302060f555bbf539310980362236ad49f50eeb0a1c1eb9/multidict-6.7.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:07f5594ac6d084cbb5de2df218d78baf55ef150b91f0ff8a21cc7a2e3a5a58eb", size = 48442, upload-time = "2025-10-06T14:50:22.871Z" }, + { url = "https://files.pythonhosted.org/packages/fc/56/67212d33239797f9bd91962bb899d72bb0f4c35a8652dcdb8ed049bef878/multidict-6.7.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:0591b48acf279821a579282444814a2d8d0af624ae0bc600aa4d1b920b6e924b", size = 46885, upload-time = "2025-10-06T14:50:24.258Z" }, + { url = "https://files.pythonhosted.org/packages/46/d1/908f896224290350721597a61a69cd19b89ad8ee0ae1f38b3f5cd12ea2ac/multidict-6.7.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:749a72584761531d2b9467cfbdfd29487ee21124c304c4b6cb760d8777b27f9c", size = 242588, upload-time = "2025-10-06T14:50:25.716Z" }, + { url = 
"https://files.pythonhosted.org/packages/ab/67/8604288bbd68680eee0ab568fdcb56171d8b23a01bcd5cb0c8fedf6e5d99/multidict-6.7.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6b4c3d199f953acd5b446bf7c0de1fe25d94e09e79086f8dc2f48a11a129cdf1", size = 249966, upload-time = "2025-10-06T14:50:28.192Z" }, + { url = "https://files.pythonhosted.org/packages/20/33/9228d76339f1ba51e3efef7da3ebd91964d3006217aae13211653193c3ff/multidict-6.7.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:9fb0211dfc3b51efea2f349ec92c114d7754dd62c01f81c3e32b765b70c45c9b", size = 228618, upload-time = "2025-10-06T14:50:29.82Z" }, + { url = "https://files.pythonhosted.org/packages/f8/2d/25d9b566d10cab1c42b3b9e5b11ef79c9111eaf4463b8c257a3bd89e0ead/multidict-6.7.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a027ec240fe73a8d6281872690b988eed307cd7d91b23998ff35ff577ca688b5", size = 257539, upload-time = "2025-10-06T14:50:31.731Z" }, + { url = "https://files.pythonhosted.org/packages/b6/b1/8d1a965e6637fc33de3c0d8f414485c2b7e4af00f42cab3d84e7b955c222/multidict-6.7.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1d964afecdf3a8288789df2f5751dc0a8261138c3768d9af117ed384e538fad", size = 256345, upload-time = "2025-10-06T14:50:33.26Z" }, + { url = "https://files.pythonhosted.org/packages/ba/0c/06b5a8adbdeedada6f4fb8d8f193d44a347223b11939b42953eeb6530b6b/multidict-6.7.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:caf53b15b1b7df9fbd0709aa01409000a2b4dd03a5f6f5cc548183c7c8f8b63c", size = 247934, upload-time = "2025-10-06T14:50:34.808Z" }, + { url = "https://files.pythonhosted.org/packages/8f/31/b2491b5fe167ca044c6eb4b8f2c9f3b8a00b24c432c365358eadac5d7625/multidict-6.7.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = 
"sha256:654030da3197d927f05a536a66186070e98765aa5142794c9904555d3a9d8fb5", size = 245243, upload-time = "2025-10-06T14:50:36.436Z" }, + { url = "https://files.pythonhosted.org/packages/61/1a/982913957cb90406c8c94f53001abd9eafc271cb3e70ff6371590bec478e/multidict-6.7.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:2090d3718829d1e484706a2f525e50c892237b2bf9b17a79b059cb98cddc2f10", size = 235878, upload-time = "2025-10-06T14:50:37.953Z" }, + { url = "https://files.pythonhosted.org/packages/be/c0/21435d804c1a1cf7a2608593f4d19bca5bcbd7a81a70b253fdd1c12af9c0/multidict-6.7.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:2d2cfeec3f6f45651b3d408c4acec0ebf3daa9bc8a112a084206f5db5d05b754", size = 243452, upload-time = "2025-10-06T14:50:39.574Z" }, + { url = "https://files.pythonhosted.org/packages/54/0a/4349d540d4a883863191be6eb9a928846d4ec0ea007d3dcd36323bb058ac/multidict-6.7.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:4ef089f985b8c194d341eb2c24ae6e7408c9a0e2e5658699c92f497437d88c3c", size = 252312, upload-time = "2025-10-06T14:50:41.612Z" }, + { url = "https://files.pythonhosted.org/packages/26/64/d5416038dbda1488daf16b676e4dbfd9674dde10a0cc8f4fc2b502d8125d/multidict-6.7.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e93a0617cd16998784bf4414c7e40f17a35d2350e5c6f0bd900d3a8e02bd3762", size = 246935, upload-time = "2025-10-06T14:50:43.972Z" }, + { url = "https://files.pythonhosted.org/packages/9f/8c/8290c50d14e49f35e0bd4abc25e1bc7711149ca9588ab7d04f886cdf03d9/multidict-6.7.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f0feece2ef8ebc42ed9e2e8c78fc4aa3cf455733b507c09ef7406364c94376c6", size = 243385, upload-time = "2025-10-06T14:50:45.648Z" }, + { url = "https://files.pythonhosted.org/packages/ef/a0/f83ae75e42d694b3fbad3e047670e511c138be747bc713cf1b10d5096416/multidict-6.7.0-cp313-cp313t-win32.whl", hash = "sha256:19a1d55338ec1be74ef62440ca9e04a2f001a04d0cc49a4983dc320ff0f3212d", size = 47777, upload-time = 
"2025-10-06T14:50:47.154Z" }, + { url = "https://files.pythonhosted.org/packages/dc/80/9b174a92814a3830b7357307a792300f42c9e94664b01dee8e457551fa66/multidict-6.7.0-cp313-cp313t-win_amd64.whl", hash = "sha256:3da4fb467498df97e986af166b12d01f05d2e04f978a9c1c680ea1988e0bc4b6", size = 53104, upload-time = "2025-10-06T14:50:48.851Z" }, + { url = "https://files.pythonhosted.org/packages/cc/28/04baeaf0428d95bb7a7bea0e691ba2f31394338ba424fb0679a9ed0f4c09/multidict-6.7.0-cp313-cp313t-win_arm64.whl", hash = "sha256:b4121773c49a0776461f4a904cdf6264c88e42218aaa8407e803ca8025872792", size = 45503, upload-time = "2025-10-06T14:50:50.16Z" }, + { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" }, +] + +[[package]] +name = "multiprocess" +version = "0.70.16" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, + { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, + { url 
= "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, +] + +[[package]] +name = "muutils" +version = "0.8.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/c8/556e999e5e5662ca2d74aa486962b2e7a955e58723af6cadca293be0bd37/muutils-0.8.12.tar.gz", hash = "sha256:ffc0d2c5b0e3bbf4c442dd810880aec7d9f95995e7677e14dc72f0a5ef12b993", size = 3348223, upload-time = "2025-10-28T17:52:25.769Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/65/d6e07cbff0caf10b2c77fad77e9138973c689dbe50c5ecb3b96764630276/muutils-0.8.12-py3-none-any.whl", hash = "sha256:19ecc6f2cab6e162d6f84f6f0d96377dc387a0e7105334c0b6d8eb90934eaeea", size = 129087, upload-time = "2025-10-28T17:52:23.013Z" }, +] + +[[package]] +name = "narwhals" +version = "2.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/e5/ef07d31c2e07d99eecac8e14ace5c20aeb00ecba4ed5bb00343136380524/narwhals-2.10.0.tar.gz", hash = "sha256:1c05bbef2048a4045263de7d98c3d06140583eb13d796dd733b2157f05d24485", size = 582423, upload-time = "2025-10-27T17:55:55.632Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/29/13/024ae0586d901f8a6f99e2d29b4ae217e8ef11d3fd944cdfc3bbde5f2a08/narwhals-2.10.0-py3-none-any.whl", hash = "sha256:baed44e8fc38e800e3a585e3fa9843a7079a6fad5fbffbecee4348d6ac52298c", size = 418077, upload-time = "2025-10-27T17:55:53.709Z" }, +] + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + +[[package]] +name = "networkx" +version = "3.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/4f/ccdb8ad3a38e583f214547fd2f7ff1fc160c43a75af88e6aec213404b96a/networkx-3.5.tar.gz", hash = "sha256:d4c6f9cf81f52d69230866796b82afbccdec3db7ae4fbd1b65ea750feed50037", size = 2471065, upload-time = "2025-05-29T11:35:07.804Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, +] + +[[package]] +name = "nodeenv" +version = "1.9.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/16/fc88b08840de0e0a72a2f9d8c6bae36be573e475a6326ae854bcc549fc45/nodeenv-1.9.1.tar.gz", hash = 
"sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f", size = 47437, upload-time = "2024-06-04T18:44:11.171Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/1d/1b658dbd2b9fa9c4c9f32accbfc0205d532c8c6194dc0f2a4c0428e7128a/nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9", size = 22314, upload-time = "2024-06-04T18:44:08.352Z" }, +] + +[[package]] +name = "nodejs-wheel-binaries" +version = "22.20.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/54/02f58c8119e2f1984e2572cc77a7b469dbaf4f8d171ad376e305749ef48e/nodejs_wheel_binaries-22.20.0.tar.gz", hash = "sha256:a62d47c9fd9c32191dff65bbe60261504f26992a0a19fe8b4d523256a84bd351", size = 8058, upload-time = "2025-09-26T09:48:00.906Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/24/6d/333e5458422f12318e3c3e6e7f194353aa68b0d633217c7e89833427ca01/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:455add5ac4f01c9c830ab6771dbfad0fdf373f9b040d3aabe8cca9b6c56654fb", size = 53246314, upload-time = "2025-09-26T09:47:32.536Z" }, + { url = "https://files.pythonhosted.org/packages/56/30/dcd6879d286a35b3c4c8f9e5e0e1bcf4f9e25fe35310fc77ecf97f915a23/nodejs_wheel_binaries-22.20.0-py2.py3-none-macosx_11_0_x86_64.whl", hash = "sha256:5d8c12f97eea7028b34a84446eb5ca81829d0c428dfb4e647e09ac617f4e21fa", size = 53644391, upload-time = "2025-09-26T09:47:36.093Z" }, + { url = "https://files.pythonhosted.org/packages/58/be/c7b2e7aa3bb281d380a1c531f84d0ccfe225832dfc3bed1ca171753b9630/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7a2b0989194148f66e9295d8f11bc463bde02cbe276517f4d20a310fb84780ae", size = 60282516, upload-time = "2025-09-26T09:47:39.88Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/c5/8befacf4190e03babbae54cb0809fb1a76e1600ec3967ab8ee9f8fc85b65/nodejs_wheel_binaries-22.20.0-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b5c500aa4dc046333ecb0a80f183e069e5c30ce637f1c1a37166b2c0b642dc21", size = 60347290, upload-time = "2025-09-26T09:47:43.712Z" }, + { url = "https://files.pythonhosted.org/packages/c0/bd/cfffd1e334277afa0714962c6ec432b5fe339340a6bca2e5fa8e678e7590/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3279eb1b99521f0d20a850bbfc0159a658e0e85b843b3cf31b090d7da9f10dfc", size = 62178798, upload-time = "2025-09-26T09:47:47.752Z" }, + { url = "https://files.pythonhosted.org/packages/08/14/10b83a9c02faac985b3e9f5e65d63a34fc0f46b48d8a2c3e4caa3e1e7318/nodejs_wheel_binaries-22.20.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d29705797b33bade62d79d8f106c2453c8a26442a9b2a5576610c0f7e7c351ed", size = 62772957, upload-time = "2025-09-26T09:47:51.266Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a9/c6a480259aa0d6b270aac2c6ba73a97444b9267adde983a5b7e34f17e45a/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_amd64.whl", hash = "sha256:4bd658962f24958503541963e5a6f2cc512a8cb301e48a69dc03c879f40a28ae", size = 40120431, upload-time = "2025-09-26T09:47:54.363Z" }, + { url = "https://files.pythonhosted.org/packages/42/b1/6a4eb2c6e9efa028074b0001b61008c9d202b6b46caee9e5d1b18c088216/nodejs_wheel_binaries-22.20.0-py2.py3-none-win_arm64.whl", hash = "sha256:1fccac931faa210d22b6962bcdbc99269d16221d831b9a118bbb80fe434a60b8", size = 38844133, upload-time = "2025-09-26T09:47:57.357Z" }, +] + +[[package]] +name = "numpy" +version = "2.3.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/f4/098d2270d52b41f1bd7db9fc288aaa0400cb48c2a3e2af6fa365d9720947/numpy-2.3.4.tar.gz", hash = "sha256:a7d018bfedb375a8d979ac758b120ba846a7fe764911a64465fd87b8729f4a6a", size = 20582187, 
upload-time = "2025-10-15T16:18:11.77Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/7e/b72610cc91edf138bc588df5150957a4937221ca6058b825b4725c27be62/numpy-2.3.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c090d4860032b857d94144d1a9976b8e36709e40386db289aaf6672de2a81966", size = 20950335, upload-time = "2025-10-15T16:16:10.304Z" }, + { url = "https://files.pythonhosted.org/packages/3e/46/bdd3370dcea2f95ef14af79dbf81e6927102ddf1cc54adc0024d61252fd9/numpy-2.3.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a13fc473b6db0be619e45f11f9e81260f7302f8d180c49a22b6e6120022596b3", size = 14179878, upload-time = "2025-10-15T16:16:12.595Z" }, + { url = "https://files.pythonhosted.org/packages/ac/01/5a67cb785bda60f45415d09c2bc245433f1c68dd82eef9c9002c508b5a65/numpy-2.3.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:3634093d0b428e6c32c3a69b78e554f0cd20ee420dcad5a9f3b2a63762ce4197", size = 5108673, upload-time = "2025-10-15T16:16:14.877Z" }, + { url = "https://files.pythonhosted.org/packages/c2/cd/8428e23a9fcebd33988f4cb61208fda832800ca03781f471f3727a820704/numpy-2.3.4-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:043885b4f7e6e232d7df4f51ffdef8c36320ee9d5f227b380ea636722c7ed12e", size = 6641438, upload-time = "2025-10-15T16:16:16.805Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d1/913fe563820f3c6b079f992458f7331278dcd7ba8427e8e745af37ddb44f/numpy-2.3.4-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4ee6a571d1e4f0ea6d5f22d6e5fbd6ed1dc2b18542848e1e7301bd190500c9d7", size = 14281290, upload-time = "2025-10-15T16:16:18.764Z" }, + { url = "https://files.pythonhosted.org/packages/9e/7e/7d306ff7cb143e6d975cfa7eb98a93e73495c4deabb7d1b5ecf09ea0fd69/numpy-2.3.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fc8a63918b04b8571789688b2780ab2b4a33ab44bfe8ccea36d3eba51228c953", size = 16636543, upload-time = "2025-10-15T16:16:21.072Z" }, + { url = 
"https://files.pythonhosted.org/packages/47/6a/8cfc486237e56ccfb0db234945552a557ca266f022d281a2f577b98e955c/numpy-2.3.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:40cc556d5abbc54aabe2b1ae287042d7bdb80c08edede19f0c0afb36ae586f37", size = 16056117, upload-time = "2025-10-15T16:16:23.369Z" }, + { url = "https://files.pythonhosted.org/packages/b1/0e/42cb5e69ea901e06ce24bfcc4b5664a56f950a70efdcf221f30d9615f3f3/numpy-2.3.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ecb63014bb7f4ce653f8be7f1df8cbc6093a5a2811211770f6606cc92b5a78fd", size = 18577788, upload-time = "2025-10-15T16:16:27.496Z" }, + { url = "https://files.pythonhosted.org/packages/86/92/41c3d5157d3177559ef0a35da50f0cda7fa071f4ba2306dd36818591a5bc/numpy-2.3.4-cp313-cp313-win32.whl", hash = "sha256:e8370eb6925bb8c1c4264fec52b0384b44f675f191df91cbe0140ec9f0955646", size = 6282620, upload-time = "2025-10-15T16:16:29.811Z" }, + { url = "https://files.pythonhosted.org/packages/09/97/fd421e8bc50766665ad35536c2bb4ef916533ba1fdd053a62d96cc7c8b95/numpy-2.3.4-cp313-cp313-win_amd64.whl", hash = "sha256:56209416e81a7893036eea03abcb91c130643eb14233b2515c90dcac963fe99d", size = 12784672, upload-time = "2025-10-15T16:16:31.589Z" }, + { url = "https://files.pythonhosted.org/packages/ad/df/5474fb2f74970ca8eb978093969b125a84cc3d30e47f82191f981f13a8a0/numpy-2.3.4-cp313-cp313-win_arm64.whl", hash = "sha256:a700a4031bc0fd6936e78a752eefb79092cecad2599ea9c8039c548bc097f9bc", size = 10196702, upload-time = "2025-10-15T16:16:33.902Z" }, + { url = "https://files.pythonhosted.org/packages/11/83/66ac031464ec1767ea3ed48ce40f615eb441072945e98693bec0bcd056cc/numpy-2.3.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:86966db35c4040fdca64f0816a1c1dd8dbd027d90fca5a57e00e1ca4cd41b879", size = 21049003, upload-time = "2025-10-15T16:16:36.101Z" }, + { url = "https://files.pythonhosted.org/packages/5f/99/5b14e0e686e61371659a1d5bebd04596b1d72227ce36eed121bb0aeab798/numpy-2.3.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = 
"sha256:838f045478638b26c375ee96ea89464d38428c69170360b23a1a50fa4baa3562", size = 14302980, upload-time = "2025-10-15T16:16:39.124Z" }, + { url = "https://files.pythonhosted.org/packages/2c/44/e9486649cd087d9fc6920e3fc3ac2aba10838d10804b1e179fb7cbc4e634/numpy-2.3.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:d7315ed1dab0286adca467377c8381cd748f3dc92235f22a7dfc42745644a96a", size = 5231472, upload-time = "2025-10-15T16:16:41.168Z" }, + { url = "https://files.pythonhosted.org/packages/3e/51/902b24fa8887e5fe2063fd61b1895a476d0bbf46811ab0c7fdf4bd127345/numpy-2.3.4-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:84f01a4d18b2cc4ade1814a08e5f3c907b079c847051d720fad15ce37aa930b6", size = 6739342, upload-time = "2025-10-15T16:16:43.777Z" }, + { url = "https://files.pythonhosted.org/packages/34/f1/4de9586d05b1962acdcdb1dc4af6646361a643f8c864cef7c852bf509740/numpy-2.3.4-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:817e719a868f0dacde4abdfc5c1910b301877970195db9ab6a5e2c4bd5b121f7", size = 14354338, upload-time = "2025-10-15T16:16:46.081Z" }, + { url = "https://files.pythonhosted.org/packages/1f/06/1c16103b425de7969d5a76bdf5ada0804b476fed05d5f9e17b777f1cbefd/numpy-2.3.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:85e071da78d92a214212cacea81c6da557cab307f2c34b5f85b628e94803f9c0", size = 16702392, upload-time = "2025-10-15T16:16:48.455Z" }, + { url = "https://files.pythonhosted.org/packages/34/b2/65f4dc1b89b5322093572b6e55161bb42e3e0487067af73627f795cc9d47/numpy-2.3.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:2ec646892819370cf3558f518797f16597b4e4669894a2ba712caccc9da53f1f", size = 16134998, upload-time = "2025-10-15T16:16:51.114Z" }, + { url = "https://files.pythonhosted.org/packages/d4/11/94ec578896cdb973aaf56425d6c7f2aff4186a5c00fac15ff2ec46998b46/numpy-2.3.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:035796aaaddfe2f9664b9a9372f089cfc88bd795a67bd1bfe15e6e770934cf64", size 
= 18651574, upload-time = "2025-10-15T16:16:53.429Z" }, + { url = "https://files.pythonhosted.org/packages/62/b7/7efa763ab33dbccf56dade36938a77345ce8e8192d6b39e470ca25ff3cd0/numpy-2.3.4-cp313-cp313t-win32.whl", hash = "sha256:fea80f4f4cf83b54c3a051f2f727870ee51e22f0248d3114b8e755d160b38cfb", size = 6413135, upload-time = "2025-10-15T16:16:55.992Z" }, + { url = "https://files.pythonhosted.org/packages/43/70/aba4c38e8400abcc2f345e13d972fb36c26409b3e644366db7649015f291/numpy-2.3.4-cp313-cp313t-win_amd64.whl", hash = "sha256:15eea9f306b98e0be91eb344a94c0e630689ef302e10c2ce5f7e11905c704f9c", size = 12928582, upload-time = "2025-10-15T16:16:57.943Z" }, + { url = "https://files.pythonhosted.org/packages/67/63/871fad5f0073fc00fbbdd7232962ea1ac40eeaae2bba66c76214f7954236/numpy-2.3.4-cp313-cp313t-win_arm64.whl", hash = "sha256:b6c231c9c2fadbae4011ca5e7e83e12dc4a5072f1a1d85a0a7b3ed754d145a40", size = 10266691, upload-time = "2025-10-15T16:17:00.048Z" }, +] + +[[package]] +name = "nvidia-cublas-cu12" +version = "12.8.4.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/61/e24b560ab2e2eaeb3c839129175fb330dfcfc29e5203196e5541a4c44682/nvidia_cublas_cu12-12.8.4.1-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:8ac4e771d5a348c551b2a426eda6193c19aa630236b418086020df5ba9667142", size = 594346921, upload-time = "2025-03-07T01:44:31.254Z" }, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/02/2adcaa145158bf1a8295d83591d22e4103dbfd821bcaf6f3f53151ca4ffa/nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ea0cb07ebda26bb9b29ba82cda34849e73c166c18162d3913575b0c9db9a6182", size = 10248621, upload-time = "2025-03-07T01:40:21.213Z" }, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.8.93" +source = { registry = 
"https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/6b/32f747947df2da6994e999492ab306a903659555dddc0fbdeb9d71f75e52/nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:a7756528852ef889772a84c6cd89d41dfa74667e24cca16bb31f8f061e3e9994", size = 88040029, upload-time = "2025-03-07T01:42:13.562Z" }, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0d/9b/a997b638fcd068ad6e4d53b8551a7d30fe8b404d6f1804abf1df69838932/nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adade8dcbd0edf427b7204d480d6066d33902cab2a4707dcfc48a2d0fd44ab90", size = 954765, upload-time = "2025-03-07T01:40:01.615Z" }, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "9.10.2.21" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, +] + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.3.3.83" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, +] + +[[package]] +name = "nvidia-cufile-cu12" +version = "1.13.1.3" 
+source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bb/fe/1bcba1dfbfb8d01be8d93f07bfc502c93fa23afa6fd5ab3fc7c1df71038a/nvidia_cufile_cu12-1.13.1.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1d069003be650e131b21c932ec3d8969c1715379251f8d23a1860554b1cb24fc", size = 1197834, upload-time = "2025-03-07T01:45:50.723Z" }, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.9.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/aa/6584b56dc84ebe9cf93226a5cde4d99080c8e90ab40f0c27bda7a0f29aa1/nvidia_curand_cu12-10.3.9.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:b32331d4f4df5d6eefa0554c565b626c7216f87a06a4f56fab27c3b68a830ec9", size = 63619976, upload-time = "2025-03-07T01:46:23.323Z" }, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.7.3.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, +] + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.5.8.93" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = 
"2025-03-07T01:48:13.779Z" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, +] + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/5b/4e4fff7bad39adf89f735f2bc87248c81db71205b62bcc0d5ca5b606b3c3/nvidia_nccl_cu12-2.27.3-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:adf27ccf4238253e0b826bce3ff5fa532d65fc42322c8bfdfaf28024c0fbe039", size = 322364134, upload-time = "2025-06-03T21:58:04.013Z" }, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.8.93" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl", hash = "sha256:81ff63371a7ebd6e6451970684f916be2eab07321b73c9d244dc2b4da7f73b88", size = 39254836, upload-time = "2025-03-07T01:49:55.661Z" }, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.8.90" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f", size = 89954, upload-time = "2025-03-07T01:42:44.131Z" }, +] + +[[package]] +name = "packaging" +version = "25.0" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727, upload-time = "2025-04-19T11:48:59.673Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469, upload-time = "2025-04-19T11:48:57.875Z" }, +] + +[[package]] +name = "pandas" +version = "2.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "python-dateutil" }, + { name = "pytz" }, + { name = "tzdata" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b", size = 4495223, upload-time = "2025-09-29T23:34:51.853Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cd/4b/18b035ee18f97c1040d94debd8f2e737000ad70ccc8f5513f4eefad75f4b/pandas-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56851a737e3470de7fa88e6131f41281ed440d29a9268dcbf0002da5ac366713", size = 11544671, upload-time = "2025-09-29T23:21:05.024Z" }, + { url = "https://files.pythonhosted.org/packages/31/94/72fac03573102779920099bcac1c3b05975c2cb5f01eac609faf34bed1ca/pandas-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdcd9d1167f4885211e401b3036c0c8d9e274eee67ea8d0758a256d60704cfe8", size = 10680807, upload-time = "2025-09-29T23:21:15.979Z" }, + { url = "https://files.pythonhosted.org/packages/16/87/9472cf4a487d848476865321de18cc8c920b8cab98453ab79dbbc98db63a/pandas-2.3.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = 
"sha256:e32e7cc9af0f1cc15548288a51a3b681cc2a219faa838e995f7dc53dbab1062d", size = 11709872, upload-time = "2025-09-29T23:21:27.165Z" }, + { url = "https://files.pythonhosted.org/packages/15/07/284f757f63f8a8d69ed4472bfd85122bd086e637bf4ed09de572d575a693/pandas-2.3.3-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:318d77e0e42a628c04dc56bcef4b40de67918f7041c2b061af1da41dcff670ac", size = 12306371, upload-time = "2025-09-29T23:21:40.532Z" }, + { url = "https://files.pythonhosted.org/packages/33/81/a3afc88fca4aa925804a27d2676d22dcd2031c2ebe08aabd0ae55b9ff282/pandas-2.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4e0a175408804d566144e170d0476b15d78458795bb18f1304fb94160cabf40c", size = 12765333, upload-time = "2025-09-29T23:21:55.77Z" }, + { url = "https://files.pythonhosted.org/packages/8d/0f/b4d4ae743a83742f1153464cf1a8ecfafc3ac59722a0b5c8602310cb7158/pandas-2.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:93c2d9ab0fc11822b5eece72ec9587e172f63cff87c00b062f6e37448ced4493", size = 13418120, upload-time = "2025-09-29T23:22:10.109Z" }, + { url = "https://files.pythonhosted.org/packages/4f/c7/e54682c96a895d0c808453269e0b5928a07a127a15704fedb643e9b0a4c8/pandas-2.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:f8bfc0e12dc78f777f323f55c58649591b2cd0c43534e8355c51d3fede5f4dee", size = 10993991, upload-time = "2025-09-29T23:25:04.889Z" }, + { url = "https://files.pythonhosted.org/packages/f9/ca/3f8d4f49740799189e1395812f3bf23b5e8fc7c190827d55a610da72ce55/pandas-2.3.3-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:75ea25f9529fdec2d2e93a42c523962261e567d250b0013b16210e1d40d7c2e5", size = 12048227, upload-time = "2025-09-29T23:22:24.343Z" }, + { url = "https://files.pythonhosted.org/packages/0e/5a/f43efec3e8c0cc92c4663ccad372dbdff72b60bdb56b2749f04aa1d07d7e/pandas-2.3.3-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:74ecdf1d301e812db96a465a525952f4dde225fdb6d8e5a521d47e1f42041e21", size = 11411056, upload-time = 
"2025-09-29T23:22:37.762Z" }, + { url = "https://files.pythonhosted.org/packages/46/b1/85331edfc591208c9d1a63a06baa67b21d332e63b7a591a5ba42a10bb507/pandas-2.3.3-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6435cb949cb34ec11cc9860246ccb2fdc9ecd742c12d3304989017d53f039a78", size = 11645189, upload-time = "2025-09-29T23:22:51.688Z" }, + { url = "https://files.pythonhosted.org/packages/44/23/78d645adc35d94d1ac4f2a3c4112ab6f5b8999f4898b8cdf01252f8df4a9/pandas-2.3.3-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:900f47d8f20860de523a1ac881c4c36d65efcb2eb850e6948140fa781736e110", size = 12121912, upload-time = "2025-09-29T23:23:05.042Z" }, + { url = "https://files.pythonhosted.org/packages/53/da/d10013df5e6aaef6b425aa0c32e1fc1f3e431e4bcabd420517dceadce354/pandas-2.3.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a45c765238e2ed7d7c608fc5bc4a6f88b642f2f01e70c0c23d2224dd21829d86", size = 12712160, upload-time = "2025-09-29T23:23:28.57Z" }, + { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233, upload-time = "2025-09-29T23:24:24.876Z" }, +] + +[[package]] +name = "parso" +version = "0.8.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d4/de/53e0bcf53d13e005bd8c92e7855142494f41171b34c2536b86187474184d/parso-0.8.5.tar.gz", hash = "sha256:034d7354a9a018bdce352f48b2a8a450f05e9d6ee85db84764e9b6bd96dafe5a", size = 401205, upload-time = "2025-08-23T15:15:28.028Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/32/f8e3c85d1d5250232a5d3477a2a28cc291968ff175caeadaf3cc19ce0e4a/parso-0.8.5-py2.py3-none-any.whl", hash = "sha256:646204b5ee239c396d040b90f9e272e9a8017c630092bf59980beb62fd033887", size = 106668, upload-time = 
"2025-08-23T15:15:25.663Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + +[[package]] +name = "pillow" +version = "11.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/0d/d0d6dea55cd152ce3d6767bb38a8fc10e33796ba4ba210cbab9354b6d238/pillow-11.3.0.tar.gz", hash = "sha256:3828ee7586cd0b2091b6209e5ad53e20d0649bbe87164a459d0676e035e8f523", size = 47113069, upload-time = "2025-07-01T09:16:30.666Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/93/0952f2ed8db3a5a4c7a11f91965d6184ebc8cd7cbb7941a260d5f018cd2d/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:1c627742b539bba4309df89171356fcb3cc5a9178355b2727d1b74a6cf155fbd", size = 2128328, upload-time = "2025-07-01T09:14:35.276Z" }, + { url = "https://files.pythonhosted.org/packages/4b/e8/100c3d114b1a0bf4042f27e0f87d2f25e857e838034e98ca98fe7b8c0a9c/pillow-11.3.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30b7c02f3899d10f13d7a48163c8969e4e653f8b43416d23d13d1bbfdc93b9f8", size = 2170652, upload-time = "2025-07-01T09:14:37.203Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/86/3f758a28a6e381758545f7cdb4942e1cb79abd271bea932998fc0db93cb6/pillow-11.3.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:7859a4cc7c9295f5838015d8cc0a9c215b77e43d07a25e460f35cf516df8626f", size = 2227443, upload-time = "2025-07-01T09:14:39.344Z" }, + { url = "https://files.pythonhosted.org/packages/01/f4/91d5b3ffa718df2f53b0dc109877993e511f4fd055d7e9508682e8aba092/pillow-11.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec1ee50470b0d050984394423d96325b744d55c701a439d2bd66089bff963d3c", size = 5278474, upload-time = "2025-07-01T09:14:41.843Z" }, + { url = "https://files.pythonhosted.org/packages/f9/0e/37d7d3eca6c879fbd9dba21268427dffda1ab00d4eb05b32923d4fbe3b12/pillow-11.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7db51d222548ccfd274e4572fdbf3e810a5e66b00608862f947b163e613b67dd", size = 4686038, upload-time = "2025-07-01T09:14:44.008Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b0/3426e5c7f6565e752d81221af9d3676fdbb4f352317ceafd42899aaf5d8a/pillow-11.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2d6fcc902a24ac74495df63faad1884282239265c6839a0a6416d33faedfae7e", size = 5864407, upload-time = "2025-07-03T13:10:15.628Z" }, + { url = "https://files.pythonhosted.org/packages/fc/c1/c6c423134229f2a221ee53f838d4be9d82bab86f7e2f8e75e47b6bf6cd77/pillow-11.3.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f0f5d8f4a08090c6d6d578351a2b91acf519a54986c055af27e7a93feae6d3f1", size = 7639094, upload-time = "2025-07-03T13:10:21.857Z" }, + { url = "https://files.pythonhosted.org/packages/ba/c9/09e6746630fe6372c67c648ff9deae52a2bc20897d51fa293571977ceb5d/pillow-11.3.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c37d8ba9411d6003bba9e518db0db0c58a680ab9fe5179f040b0463644bc9805", size = 5973503, upload-time = "2025-07-01T09:14:45.698Z" }, + { url = 
"https://files.pythonhosted.org/packages/d5/1c/a2a29649c0b1983d3ef57ee87a66487fdeb45132df66ab30dd37f7dbe162/pillow-11.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:13f87d581e71d9189ab21fe0efb5a23e9f28552d5be6979e84001d3b8505abe8", size = 6642574, upload-time = "2025-07-01T09:14:47.415Z" }, + { url = "https://files.pythonhosted.org/packages/36/de/d5cc31cc4b055b6c6fd990e3e7f0f8aaf36229a2698501bcb0cdf67c7146/pillow-11.3.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:023f6d2d11784a465f09fd09a34b150ea4672e85fb3d05931d89f373ab14abb2", size = 6084060, upload-time = "2025-07-01T09:14:49.636Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ea/502d938cbaeec836ac28a9b730193716f0114c41325db428e6b280513f09/pillow-11.3.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:45dfc51ac5975b938e9809451c51734124e73b04d0f0ac621649821a63852e7b", size = 6721407, upload-time = "2025-07-01T09:14:51.962Z" }, + { url = "https://files.pythonhosted.org/packages/45/9c/9c5e2a73f125f6cbc59cc7087c8f2d649a7ae453f83bd0362ff7c9e2aee2/pillow-11.3.0-cp313-cp313-win32.whl", hash = "sha256:a4d336baed65d50d37b88ca5b60c0fa9d81e3a87d4a7930d3880d1624d5b31f3", size = 6273841, upload-time = "2025-07-01T09:14:54.142Z" }, + { url = "https://files.pythonhosted.org/packages/23/85/397c73524e0cd212067e0c969aa245b01d50183439550d24d9f55781b776/pillow-11.3.0-cp313-cp313-win_amd64.whl", hash = "sha256:0bce5c4fd0921f99d2e858dc4d4d64193407e1b99478bc5cacecba2311abde51", size = 6978450, upload-time = "2025-07-01T09:14:56.436Z" }, + { url = "https://files.pythonhosted.org/packages/17/d2/622f4547f69cd173955194b78e4d19ca4935a1b0f03a302d655c9f6aae65/pillow-11.3.0-cp313-cp313-win_arm64.whl", hash = "sha256:1904e1264881f682f02b7f8167935cce37bc97db457f8e7849dc3a6a52b99580", size = 2423055, upload-time = "2025-07-01T09:14:58.072Z" }, + { url = 
"https://files.pythonhosted.org/packages/dd/80/a8a2ac21dda2e82480852978416cfacd439a4b490a501a288ecf4fe2532d/pillow-11.3.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:4c834a3921375c48ee6b9624061076bc0a32a60b5532b322cc0ea64e639dd50e", size = 5281110, upload-time = "2025-07-01T09:14:59.79Z" }, + { url = "https://files.pythonhosted.org/packages/44/d6/b79754ca790f315918732e18f82a8146d33bcd7f4494380457ea89eb883d/pillow-11.3.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5e05688ccef30ea69b9317a9ead994b93975104a677a36a8ed8106be9260aa6d", size = 4689547, upload-time = "2025-07-01T09:15:01.648Z" }, + { url = "https://files.pythonhosted.org/packages/49/20/716b8717d331150cb00f7fdd78169c01e8e0c219732a78b0e59b6bdb2fd6/pillow-11.3.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1019b04af07fc0163e2810167918cb5add8d74674b6267616021ab558dc98ced", size = 5901554, upload-time = "2025-07-03T13:10:27.018Z" }, + { url = "https://files.pythonhosted.org/packages/74/cf/a9f3a2514a65bb071075063a96f0a5cf949c2f2fce683c15ccc83b1c1cab/pillow-11.3.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f944255db153ebb2b19c51fe85dd99ef0ce494123f21b9db4877ffdfc5590c7c", size = 7669132, upload-time = "2025-07-03T13:10:33.01Z" }, + { url = "https://files.pythonhosted.org/packages/98/3c/da78805cbdbee9cb43efe8261dd7cc0b4b93f2ac79b676c03159e9db2187/pillow-11.3.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1f85acb69adf2aaee8b7da124efebbdb959a104db34d3a2cb0f3793dbae422a8", size = 6005001, upload-time = "2025-07-01T09:15:03.365Z" }, + { url = "https://files.pythonhosted.org/packages/6c/fa/ce044b91faecf30e635321351bba32bab5a7e034c60187fe9698191aef4f/pillow-11.3.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:05f6ecbeff5005399bb48d198f098a9b4b6bdf27b8487c7f38ca16eeb070cd59", size = 6668814, upload-time = "2025-07-01T09:15:05.655Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/51/90f9291406d09bf93686434f9183aba27b831c10c87746ff49f127ee80cb/pillow-11.3.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a7bc6e6fd0395bc052f16b1a8670859964dbd7003bd0af2ff08342eb6e442cfe", size = 6113124, upload-time = "2025-07-01T09:15:07.358Z" }, + { url = "https://files.pythonhosted.org/packages/cd/5a/6fec59b1dfb619234f7636d4157d11fb4e196caeee220232a8d2ec48488d/pillow-11.3.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:83e1b0161c9d148125083a35c1c5a89db5b7054834fd4387499e06552035236c", size = 6747186, upload-time = "2025-07-01T09:15:09.317Z" }, + { url = "https://files.pythonhosted.org/packages/49/6b/00187a044f98255225f172de653941e61da37104a9ea60e4f6887717e2b5/pillow-11.3.0-cp313-cp313t-win32.whl", hash = "sha256:2a3117c06b8fb646639dce83694f2f9eac405472713fcb1ae887469c0d4f6788", size = 6277546, upload-time = "2025-07-01T09:15:11.311Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5c/6caaba7e261c0d75bab23be79f1d06b5ad2a2ae49f028ccec801b0e853d6/pillow-11.3.0-cp313-cp313t-win_amd64.whl", hash = "sha256:857844335c95bea93fb39e0fa2726b4d9d758850b34075a7e3ff4f4fa3aa3b31", size = 6985102, upload-time = "2025-07-01T09:15:13.164Z" }, + { url = "https://files.pythonhosted.org/packages/f3/7e/b623008460c09a0cb38263c93b828c666493caee2eb34ff67f778b87e58c/pillow-11.3.0-cp313-cp313t-win_arm64.whl", hash = "sha256:8797edc41f3e8536ae4b10897ee2f637235c94f27404cac7297f7b607dd0716e", size = 2424803, upload-time = "2025-07-01T09:15:15.695Z" }, +] + +[[package]] +name = "platformdirs" +version = "4.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/9611380c2bdb1225fdef633e2a9610622310fed35ab11dac9620972ee088/platformdirs-4.5.0.tar.gz", hash = "sha256:70ddccdd7c99fc5942e9fc25636a8b34d04c24b335100223152c2803e4063312", size = 21632, upload-time = "2025-10-08T17:44:48.791Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, +] + +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + +[[package]] +name = "pre-commit" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ff/29/7cf5bbc236333876e4b41f56e06857a87937ce4bf91e117a6991a2dbb02a/pre_commit-4.3.0.tar.gz", hash = "sha256:499fe450cc9d42e9d58e606262795ecb64dd05438943c62b66f6a8673da30b16", size = 193792, upload-time = "2025-08-09T18:56:14.651Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5b/a5/987a405322d78a73b66e39e4a90e4ef156fd7141bf71df987e50717c321b/pre_commit-4.3.0-py2.py3-none-any.whl", hash = "sha256:2b0747ad7e6e967169136edffee14c16e148a778a54e4f967921aa1ebf2308d8", size = 220965, upload-time = "2025-08-09T18:56:13.192Z" }, +] + +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = 
{ url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + +[[package]] +name = "propcache" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/da/e9fc233cf63743258bff22b3dfa7ea5baef7b5bc324af47a0ad89b8ffc6f/propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d", size = 46442, upload-time = "2025-10-08T19:49:02.291Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/df/6d9c1b6ac12b003837dde8a10231a7344512186e87b36e855bef32241942/propcache-0.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:43eedf29202c08550aac1d14e0ee619b0430aaef78f85864c1a892294fbc28cf", size = 77750, upload-time = "2025-10-08T19:47:07.648Z" }, + { url = "https://files.pythonhosted.org/packages/8b/e8/677a0025e8a2acf07d3418a2e7ba529c9c33caf09d3c1f25513023c1db56/propcache-0.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d62cdfcfd89ccb8de04e0eda998535c406bf5e060ffd56be6c586cbcc05b3311", size = 44780, upload-time = "2025-10-08T19:47:08.851Z" }, + { url = "https://files.pythonhosted.org/packages/89/a4/92380f7ca60f99ebae761936bc48a72a639e8a47b29050615eef757cb2a7/propcache-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cae65ad55793da34db5f54e4029b89d3b9b9490d8abe1b4c7ab5d4b8ec7ebf74", size = 46308, upload-time = "2025-10-08T19:47:09.982Z" }, + { url = 
"https://files.pythonhosted.org/packages/2d/48/c5ac64dee5262044348d1d78a5f85dd1a57464a60d30daee946699963eb3/propcache-0.4.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:333ddb9031d2704a301ee3e506dc46b1fe5f294ec198ed6435ad5b6a085facfe", size = 208182, upload-time = "2025-10-08T19:47:11.319Z" }, + { url = "https://files.pythonhosted.org/packages/c6/0c/cd762dd011a9287389a6a3eb43aa30207bde253610cca06824aeabfe9653/propcache-0.4.1-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:fd0858c20f078a32cf55f7e81473d96dcf3b93fd2ccdb3d40fdf54b8573df3af", size = 211215, upload-time = "2025-10-08T19:47:13.146Z" }, + { url = "https://files.pythonhosted.org/packages/30/3e/49861e90233ba36890ae0ca4c660e95df565b2cd15d4a68556ab5865974e/propcache-0.4.1-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:678ae89ebc632c5c204c794f8dab2837c5f159aeb59e6ed0539500400577298c", size = 218112, upload-time = "2025-10-08T19:47:14.913Z" }, + { url = "https://files.pythonhosted.org/packages/f1/8b/544bc867e24e1bd48f3118cecd3b05c694e160a168478fa28770f22fd094/propcache-0.4.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d472aeb4fbf9865e0c6d622d7f4d54a4e101a89715d8904282bb5f9a2f476c3f", size = 204442, upload-time = "2025-10-08T19:47:16.277Z" }, + { url = "https://files.pythonhosted.org/packages/50/a6/4282772fd016a76d3e5c0df58380a5ea64900afd836cec2c2f662d1b9bb3/propcache-0.4.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:4d3df5fa7e36b3225954fba85589da77a0fe6a53e3976de39caf04a0db4c36f1", size = 199398, upload-time = "2025-10-08T19:47:17.962Z" }, + { url = "https://files.pythonhosted.org/packages/3e/ec/d8a7cd406ee1ddb705db2139f8a10a8a427100347bd698e7014351c7af09/propcache-0.4.1-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:ee17f18d2498f2673e432faaa71698032b0127ebf23ae5974eeaf806c279df24", size 
= 196920, upload-time = "2025-10-08T19:47:19.355Z" }, + { url = "https://files.pythonhosted.org/packages/f6/6c/f38ab64af3764f431e359f8baf9e0a21013e24329e8b85d2da32e8ed07ca/propcache-0.4.1-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:580e97762b950f993ae618e167e7be9256b8353c2dcd8b99ec100eb50f5286aa", size = 203748, upload-time = "2025-10-08T19:47:21.338Z" }, + { url = "https://files.pythonhosted.org/packages/d6/e3/fa846bd70f6534d647886621388f0a265254d30e3ce47e5c8e6e27dbf153/propcache-0.4.1-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:501d20b891688eb8e7aa903021f0b72d5a55db40ffaab27edefd1027caaafa61", size = 205877, upload-time = "2025-10-08T19:47:23.059Z" }, + { url = "https://files.pythonhosted.org/packages/e2/39/8163fc6f3133fea7b5f2827e8eba2029a0277ab2c5beee6c1db7b10fc23d/propcache-0.4.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9a0bd56e5b100aef69bd8562b74b46254e7c8812918d3baa700c8a8009b0af66", size = 199437, upload-time = "2025-10-08T19:47:24.445Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/caa9089970ca49c7c01662bd0eeedfe85494e863e8043565aeb6472ce8fe/propcache-0.4.1-cp313-cp313-win32.whl", hash = "sha256:bcc9aaa5d80322bc2fb24bb7accb4a30f81e90ab8d6ba187aec0744bc302ad81", size = 37586, upload-time = "2025-10-08T19:47:25.736Z" }, + { url = "https://files.pythonhosted.org/packages/f5/ab/f76ec3c3627c883215b5c8080debb4394ef5a7a29be811f786415fc1e6fd/propcache-0.4.1-cp313-cp313-win_amd64.whl", hash = "sha256:381914df18634f5494334d201e98245c0596067504b9372d8cf93f4bb23e025e", size = 40790, upload-time = "2025-10-08T19:47:26.847Z" }, + { url = "https://files.pythonhosted.org/packages/59/1b/e71ae98235f8e2ba5004d8cb19765a74877abf189bc53fc0c80d799e56c3/propcache-0.4.1-cp313-cp313-win_arm64.whl", hash = "sha256:8873eb4460fd55333ea49b7d189749ecf6e55bf85080f11b1c4530ed3034cba1", size = 37158, upload-time = "2025-10-08T19:47:27.961Z" }, + { url = 
"https://files.pythonhosted.org/packages/83/ce/a31bbdfc24ee0dcbba458c8175ed26089cf109a55bbe7b7640ed2470cfe9/propcache-0.4.1-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:92d1935ee1f8d7442da9c0c4fa7ac20d07e94064184811b685f5c4fada64553b", size = 81451, upload-time = "2025-10-08T19:47:29.445Z" }, + { url = "https://files.pythonhosted.org/packages/25/9c/442a45a470a68456e710d96cacd3573ef26a1d0a60067e6a7d5e655621ed/propcache-0.4.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:473c61b39e1460d386479b9b2f337da492042447c9b685f28be4f74d3529e566", size = 46374, upload-time = "2025-10-08T19:47:30.579Z" }, + { url = "https://files.pythonhosted.org/packages/f4/bf/b1d5e21dbc3b2e889ea4327044fb16312a736d97640fb8b6aa3f9c7b3b65/propcache-0.4.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c0ef0aaafc66fbd87842a3fe3902fd889825646bc21149eafe47be6072725835", size = 48396, upload-time = "2025-10-08T19:47:31.79Z" }, + { url = "https://files.pythonhosted.org/packages/f4/04/5b4c54a103d480e978d3c8a76073502b18db0c4bc17ab91b3cb5092ad949/propcache-0.4.1-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f95393b4d66bfae908c3ca8d169d5f79cd65636ae15b5e7a4f6e67af675adb0e", size = 275950, upload-time = "2025-10-08T19:47:33.481Z" }, + { url = "https://files.pythonhosted.org/packages/b4/c1/86f846827fb969c4b78b0af79bba1d1ea2156492e1b83dea8b8a6ae27395/propcache-0.4.1-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c07fda85708bc48578467e85099645167a955ba093be0a2dcba962195676e859", size = 273856, upload-time = "2025-10-08T19:47:34.906Z" }, + { url = "https://files.pythonhosted.org/packages/36/1d/fc272a63c8d3bbad6878c336c7a7dea15e8f2d23a544bda43205dfa83ada/propcache-0.4.1-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:af223b406d6d000830c6f65f1e6431783fc3f713ba3e6cc8c024d5ee96170a4b", size = 280420, upload-time = 
"2025-10-08T19:47:36.338Z" }, + { url = "https://files.pythonhosted.org/packages/07/0c/01f2219d39f7e53d52e5173bcb09c976609ba30209912a0680adfb8c593a/propcache-0.4.1-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a78372c932c90ee474559c5ddfffd718238e8673c340dc21fe45c5b8b54559a0", size = 263254, upload-time = "2025-10-08T19:47:37.692Z" }, + { url = "https://files.pythonhosted.org/packages/2d/18/cd28081658ce597898f0c4d174d4d0f3c5b6d4dc27ffafeef835c95eb359/propcache-0.4.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:564d9f0d4d9509e1a870c920a89b2fec951b44bf5ba7d537a9e7c1ccec2c18af", size = 261205, upload-time = "2025-10-08T19:47:39.659Z" }, + { url = "https://files.pythonhosted.org/packages/7a/71/1f9e22eb8b8316701c2a19fa1f388c8a3185082607da8e406a803c9b954e/propcache-0.4.1-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:17612831fda0138059cc5546f4d12a2aacfb9e47068c06af35c400ba58ba7393", size = 247873, upload-time = "2025-10-08T19:47:41.084Z" }, + { url = "https://files.pythonhosted.org/packages/4a/65/3d4b61f36af2b4eddba9def857959f1016a51066b4f1ce348e0cf7881f58/propcache-0.4.1-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:41a89040cb10bd345b3c1a873b2bf36413d48da1def52f268a055f7398514874", size = 262739, upload-time = "2025-10-08T19:47:42.51Z" }, + { url = "https://files.pythonhosted.org/packages/2a/42/26746ab087faa77c1c68079b228810436ccd9a5ce9ac85e2b7307195fd06/propcache-0.4.1-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e35b88984e7fa64aacecea39236cee32dd9bd8c55f57ba8a75cf2399553f9bd7", size = 263514, upload-time = "2025-10-08T19:47:43.927Z" }, + { url = "https://files.pythonhosted.org/packages/94/13/630690fe201f5502d2403dd3cfd451ed8858fe3c738ee88d095ad2ff407b/propcache-0.4.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6f8b465489f927b0df505cbe26ffbeed4d6d8a2bbc61ce90eb074ff129ef0ab1", size = 257781, upload-time = "2025-10-08T19:47:45.448Z" }, + { url = 
"https://files.pythonhosted.org/packages/92/f7/1d4ec5841505f423469efbfc381d64b7b467438cd5a4bbcbb063f3b73d27/propcache-0.4.1-cp313-cp313t-win32.whl", hash = "sha256:2ad890caa1d928c7c2965b48f3a3815c853180831d0e5503d35cf00c472f4717", size = 41396, upload-time = "2025-10-08T19:47:47.202Z" }, + { url = "https://files.pythonhosted.org/packages/48/f0/615c30622316496d2cbbc29f5985f7777d3ada70f23370608c1d3e081c1f/propcache-0.4.1-cp313-cp313t-win_amd64.whl", hash = "sha256:f7ee0e597f495cf415bcbd3da3caa3bd7e816b74d0d52b8145954c5e6fd3ff37", size = 44897, upload-time = "2025-10-08T19:47:48.336Z" }, + { url = "https://files.pythonhosted.org/packages/fd/ca/6002e46eccbe0e33dcd4069ef32f7f1c9e243736e07adca37ae8c4830ec3/propcache-0.4.1-cp313-cp313t-win_arm64.whl", hash = "sha256:929d7cbe1f01bb7baffb33dc14eb5691c95831450a26354cd210a8155170c93a", size = 39789, upload-time = "2025-10-08T19:47:49.876Z" }, + { url = "https://files.pythonhosted.org/packages/5b/5a/bc7b4a4ef808fa59a816c17b20c4bef6884daebbdf627ff2a161da67da19/propcache-0.4.1-py3-none-any.whl", hash = "sha256:af2a6052aeb6cf17d3e46ee169099044fd8224cbaf75c76a2ef596e8163e2237", size = 13305, upload-time = "2025-10-08T19:49:00.792Z" }, +] + +[[package]] +name = "protobuf" +version = "6.33.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/19/ff/64a6c8f420818bb873713988ca5492cba3a7946be57e027ac63495157d97/protobuf-6.33.0.tar.gz", hash = "sha256:140303d5c8d2037730c548f8c7b93b20bb1dc301be280c378b82b8894589c954", size = 443463, upload-time = "2025-10-15T20:39:52.159Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/ee/52b3fa8feb6db4a833dfea4943e175ce645144532e8a90f72571ad85df4e/protobuf-6.33.0-cp310-abi3-win32.whl", hash = "sha256:d6101ded078042a8f17959eccd9236fb7a9ca20d3b0098bbcb91533a5680d035", size = 425593, upload-time = "2025-10-15T20:39:40.29Z" }, + { url = 
"https://files.pythonhosted.org/packages/7b/c6/7a465f1825872c55e0341ff4a80198743f73b69ce5d43ab18043699d1d81/protobuf-6.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:9a031d10f703f03768f2743a1c403af050b6ae1f3480e9c140f39c45f81b13ee", size = 436882, upload-time = "2025-10-15T20:39:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/e1/a9/b6eee662a6951b9c3640e8e452ab3e09f117d99fc10baa32d1581a0d4099/protobuf-6.33.0-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:905b07a65f1a4b72412314082c7dbfae91a9e8b68a0cc1577515f8df58ecf455", size = 427521, upload-time = "2025-10-15T20:39:43.803Z" }, + { url = "https://files.pythonhosted.org/packages/10/35/16d31e0f92c6d2f0e77c2a3ba93185130ea13053dd16200a57434c882f2b/protobuf-6.33.0-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e0697ece353e6239b90ee43a9231318302ad8353c70e6e45499fa52396debf90", size = 324445, upload-time = "2025-10-15T20:39:44.932Z" }, + { url = "https://files.pythonhosted.org/packages/e6/eb/2a981a13e35cda8b75b5585aaffae2eb904f8f351bdd3870769692acbd8a/protobuf-6.33.0-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:e0a1715e4f27355afd9570f3ea369735afc853a6c3951a6afe1f80d8569ad298", size = 339159, upload-time = "2025-10-15T20:39:46.186Z" }, + { url = "https://files.pythonhosted.org/packages/21/51/0b1cbad62074439b867b4e04cc09b93f6699d78fd191bed2bbb44562e077/protobuf-6.33.0-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:35be49fd3f4fefa4e6e2aacc35e8b837d6703c37a2168a55ac21e9b1bc7559ef", size = 323172, upload-time = "2025-10-15T20:39:47.465Z" }, + { url = "https://files.pythonhosted.org/packages/07/d1/0a28c21707807c6aacd5dc9c3704b2aa1effbf37adebd8caeaf68b17a636/protobuf-6.33.0-py3-none-any.whl", hash = "sha256:25c9e1963c6734448ea2d308cfa610e692b801304ba0908d7bfa564ac5132995", size = 170477, upload-time = "2025-10-15T20:39:51.311Z" }, +] + +[[package]] +name = "psutil" +version = "7.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/cd/ec/7b8e6b9b1d22708138630ef34c53ab2b61032c04f16adfdbb96791c8c70c/psutil-7.1.2.tar.gz", hash = "sha256:aa225cdde1335ff9684708ee8c72650f6598d5ed2114b9a7c5802030b1785018", size = 487424, upload-time = "2025-10-25T10:46:34.931Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/d9/b56cc9f883140ac10021a8c9b0f4e16eed1ba675c22513cdcbce3ba64014/psutil-7.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0cc5c6889b9871f231ed5455a9a02149e388fffcb30b607fb7a8896a6d95f22e", size = 238575, upload-time = "2025-10-25T10:46:38.728Z" }, + { url = "https://files.pythonhosted.org/packages/36/eb/28d22de383888deb252c818622196e709da98816e296ef95afda33f1c0a2/psutil-7.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8e9e77a977208d84aa363a4a12e0f72189d58bbf4e46b49aae29a2c6e93ef206", size = 239297, upload-time = "2025-10-25T10:46:41.347Z" }, + { url = "https://files.pythonhosted.org/packages/89/5d/220039e2f28cc129626e54d63892ab05c0d56a29818bfe7268dcb5008932/psutil-7.1.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7d9623a5e4164d2220ecceb071f4b333b3c78866141e8887c072129185f41278", size = 280420, upload-time = "2025-10-25T10:46:44.122Z" }, + { url = "https://files.pythonhosted.org/packages/ba/7a/286f0e1c167445b2ef4a6cbdfc8c59fdb45a5a493788950cf8467201dc73/psutil-7.1.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:364b1c10fe4ed59c89ec49e5f1a70da353b27986fa8233b4b999df4742a5ee2f", size = 283049, upload-time = "2025-10-25T10:46:47.095Z" }, + { url = "https://files.pythonhosted.org/packages/aa/cc/7eb93260794a42e39b976f3a4dde89725800b9f573b014fac142002a5c98/psutil-7.1.2-cp313-cp313t-win_amd64.whl", hash = "sha256:f101ef84de7e05d41310e3ccbdd65a6dd1d9eed85e8aaf0758405d022308e204", size = 248713, upload-time = "2025-10-25T10:46:49.573Z" }, + { url = 
"https://files.pythonhosted.org/packages/ab/1a/0681a92b53366e01f0a099f5237d0c8a2f79d322ac589cccde5e30c8a4e2/psutil-7.1.2-cp313-cp313t-win_arm64.whl", hash = "sha256:20c00824048a95de67f00afedc7b08b282aa08638585b0206a9fb51f28f1a165", size = 244644, upload-time = "2025-10-25T10:46:51.924Z" }, + { url = "https://files.pythonhosted.org/packages/ae/89/b9f8d47ddbc52d7301fc868e8224e5f44ed3c7f55e6d0f54ecaf5dd9ff5e/psutil-7.1.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c9ba5c19f2d46203ee8c152c7b01df6eec87d883cfd8ee1af2ef2727f6b0f814", size = 237244, upload-time = "2025-10-25T10:47:07.086Z" }, + { url = "https://files.pythonhosted.org/packages/c8/7a/8628c2f6b240680a67d73d8742bb9ff39b1820a693740e43096d5dcb01e5/psutil-7.1.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a486030d2fe81bec023f703d3d155f4823a10a47c36784c84f1cc7f8d39bedb", size = 238101, upload-time = "2025-10-25T10:47:09.523Z" }, + { url = "https://files.pythonhosted.org/packages/30/28/5e27f4d5a0e347f8e3cc16cd7d35533dbce086c95807f1f0e9cd77e26c10/psutil-7.1.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3efd8fc791492e7808a51cb2b94889db7578bfaea22df931424f874468e389e3", size = 258675, upload-time = "2025-10-25T10:47:11.082Z" }, + { url = "https://files.pythonhosted.org/packages/e5/5c/79cf60c9acf36d087f0db0f82066fca4a780e97e5b3a2e4c38209c03d170/psutil-7.1.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2aeb9b64f481b8eabfc633bd39e0016d4d8bbcd590d984af764d80bf0851b8a", size = 260203, upload-time = "2025-10-25T10:47:13.226Z" }, + { url = "https://files.pythonhosted.org/packages/f7/03/0a464404c51685dcb9329fdd660b1721e076ccd7b3d97dee066bcc9ffb15/psutil-7.1.2-cp37-abi3-win_amd64.whl", hash = "sha256:8e17852114c4e7996fe9da4745c2bdef001ebbf2f260dec406290e66628bdb91", size = 246714, upload-time = "2025-10-25T10:47:15.093Z" }, + { url = 
"https://files.pythonhosted.org/packages/6a/32/97ca2090f2f1b45b01b6aa7ae161cfe50671de097311975ca6eea3e7aabc/psutil-7.1.2-cp37-abi3-win_arm64.whl", hash = "sha256:3e988455e61c240cc879cb62a008c2699231bf3e3d061d7fce4234463fd2abb4", size = 243742, upload-time = "2025-10-25T10:47:17.302Z" }, +] + +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + +[[package]] +name = "pyarrow" +version = "22.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/53/04a7fdc63e6056116c9ddc8b43bc28c12cdd181b85cbeadb79278475f3ae/pyarrow-22.0.0.tar.gz", hash = 
"sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9", size = 1151151, upload-time = "2025-10-24T12:30:00.762Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/d6/d0fac16a2963002fc22c8fa75180a838737203d558f0ed3b564c4a54eef5/pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a", size = 34204629, upload-time = "2025-10-24T10:06:20.274Z" }, + { url = "https://files.pythonhosted.org/packages/c6/9c/1d6357347fbae062ad3f17082f9ebc29cc733321e892c0d2085f42a2212b/pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901", size = 35985783, upload-time = "2025-10-24T10:06:27.301Z" }, + { url = "https://files.pythonhosted.org/packages/ff/c0/782344c2ce58afbea010150df07e3a2f5fdad299cd631697ae7bd3bac6e3/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691", size = 45020999, upload-time = "2025-10-24T10:06:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8b/5362443737a5307a7b67c1017c42cd104213189b4970bf607e05faf9c525/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e0a15757fccb38c410947df156f9749ae4a3c89b2393741a50521f39a8cf202a", size = 47724601, upload-time = "2025-10-24T10:06:43.551Z" }, + { url = "https://files.pythonhosted.org/packages/69/4d/76e567a4fc2e190ee6072967cb4672b7d9249ac59ae65af2d7e3047afa3b/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cedb9dd9358e4ea1d9bce3665ce0797f6adf97ff142c8e25b46ba9cdd508e9b6", size = 48001050, upload-time = "2025-10-24T10:06:52.284Z" }, + { url = "https://files.pythonhosted.org/packages/01/5e/5653f0535d2a1aef8223cee9d92944cb6bccfee5cf1cd3f462d7cb022790/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:252be4a05f9d9185bb8c18e83764ebcfea7185076c07a7a662253af3a8c07941", size = 50307877, 
upload-time = "2025-10-24T10:07:02.405Z" }, + { url = "https://files.pythonhosted.org/packages/2d/f8/1d0bd75bf9328a3b826e24a16e5517cd7f9fbf8d34a3184a4566ef5a7f29/pyarrow-22.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:a4893d31e5ef780b6edcaf63122df0f8d321088bb0dee4c8c06eccb1ca28d145", size = 27977099, upload-time = "2025-10-24T10:08:07.259Z" }, + { url = "https://files.pythonhosted.org/packages/90/81/db56870c997805bf2b0f6eeeb2d68458bf4654652dccdcf1bf7a42d80903/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:f7fe3dbe871294ba70d789be16b6e7e52b418311e166e0e3cba9522f0f437fb1", size = 34336685, upload-time = "2025-10-24T10:07:11.47Z" }, + { url = "https://files.pythonhosted.org/packages/1c/98/0727947f199aba8a120f47dfc229eeb05df15bcd7a6f1b669e9f882afc58/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:ba95112d15fd4f1105fb2402c4eab9068f0554435e9b7085924bcfaac2cc306f", size = 36032158, upload-time = "2025-10-24T10:07:18.626Z" }, + { url = "https://files.pythonhosted.org/packages/96/b4/9babdef9c01720a0785945c7cf550e4acd0ebcd7bdd2e6f0aa7981fa85e2/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c064e28361c05d72eed8e744c9605cbd6d2bb7481a511c74071fd9b24bc65d7d", size = 44892060, upload-time = "2025-10-24T10:07:26.002Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ca/2f8804edd6279f78a37062d813de3f16f29183874447ef6d1aadbb4efa0f/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6f9762274496c244d951c819348afbcf212714902742225f649cf02823a6a10f", size = 47504395, upload-time = "2025-10-24T10:07:34.09Z" }, + { url = "https://files.pythonhosted.org/packages/b9/f0/77aa5198fd3943682b2e4faaf179a674f0edea0d55d326d83cb2277d9363/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a9d9ffdc2ab696f6b15b4d1f7cec6658e1d788124418cb30030afbae31c64746", size = 48066216, upload-time = "2025-10-24T10:07:43.528Z" }, + { url = 
"https://files.pythonhosted.org/packages/79/87/a1937b6e78b2aff18b706d738c9e46ade5bfcf11b294e39c87706a0089ac/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ec1a15968a9d80da01e1d30349b2b0d7cc91e96588ee324ce1b5228175043e95", size = 50288552, upload-time = "2025-10-24T10:07:53.519Z" }, + { url = "https://files.pythonhosted.org/packages/60/ae/b5a5811e11f25788ccfdaa8f26b6791c9807119dffcf80514505527c384c/pyarrow-22.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:bba208d9c7decf9961998edf5c65e3ea4355d5818dd6cd0f6809bec1afb951cc", size = 28262504, upload-time = "2025-10-24T10:08:00.932Z" }, +] + +[[package]] +name = "pycparser" +version = "2.23" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/cf/d2d3b9f5699fb1e4615c8e32ff220203e43b248e1dfcc6736ad9057731ca/pycparser-2.23.tar.gz", hash = "sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2", size = 173734, upload-time = "2025-09-09T13:23:47.91Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/e3/59cd50310fc9b59512193629e1984c1f95e5c8ae6e5d8c69532ccc65a7fe/pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934", size = 118140, upload-time = "2025-09-09T13:23:46.651Z" }, +] + +[[package]] +name = "pydantic" +version = "2.11.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ae/54/ecab642b3bed45f7d5f59b38443dcb36ef50f85af192e6ece103dbfe9587/pydantic-2.11.10.tar.gz", hash = "sha256:dc280f0982fbda6c38fada4e476dc0a4f3aeaf9c6ad4c28df68a666ec3c61423", size = 788494, upload-time = "2025-10-04T10:40:41.338Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/bd/1f/73c53fcbfb0b5a78f91176df41945ca466e71e9d9d836e5c522abda39ee7/pydantic-2.11.10-py3-none-any.whl", hash = "sha256:802a655709d49bd004c31e865ef37da30b540786a46bfce02333e0e24b5fe29a", size = 444823, upload-time = "2025-10-04T10:40:39.055Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.33.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ad/88/5f2260bdfae97aabf98f1778d43f69574390ad787afb646292a638c923d4/pydantic_core-2.33.2.tar.gz", hash = "sha256:7cb8bc3605c29176e1b105350d2e6474142d7c1bd1d9327c4a9bdb46bf827acc", size = 435195, upload-time = "2025-04-23T18:33:52.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/8c/99040727b41f56616573a28771b1bfa08a3d3fe74d3d513f01251f79f172/pydantic_core-2.33.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1082dd3e2d7109ad8b7da48e1d4710c8d06c253cbc4a27c1cff4fbcaa97a9e3f", size = 2015688, upload-time = "2025-04-23T18:31:53.175Z" }, + { url = "https://files.pythonhosted.org/packages/3a/cc/5999d1eb705a6cefc31f0b4a90e9f7fc400539b1a1030529700cc1b51838/pydantic_core-2.33.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f517ca031dfc037a9c07e748cefd8d96235088b83b4f4ba8939105d20fa1dcd6", size = 1844808, upload-time = "2025-04-23T18:31:54.79Z" }, + { url = "https://files.pythonhosted.org/packages/6f/5e/a0a7b8885c98889a18b6e376f344da1ef323d270b44edf8174d6bce4d622/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a9f2c9dd19656823cb8250b0724ee9c60a82f3cdf68a080979d13092a3b0fef", size = 1885580, upload-time = "2025-04-23T18:31:57.393Z" }, + { url = "https://files.pythonhosted.org/packages/3b/2a/953581f343c7d11a304581156618c3f592435523dd9d79865903272c256a/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:2b0a451c263b01acebe51895bfb0e1cc842a5c666efe06cdf13846c7418caa9a", size = 1973859, upload-time = "2025-04-23T18:31:59.065Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/f1a813904771c03a3f97f676c62cca0c0a4138654107c1b61f19c644868b/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ea40a64d23faa25e62a70ad163571c0b342b8bf66d5fa612ac0dec4f069d916", size = 2120810, upload-time = "2025-04-23T18:32:00.78Z" }, + { url = "https://files.pythonhosted.org/packages/aa/c3/053389835a996e18853ba107a63caae0b9deb4a276c6b472931ea9ae6e48/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0fb2d542b4d66f9470e8065c5469ec676978d625a8b7a363f07d9a501a9cb36a", size = 2676498, upload-time = "2025-04-23T18:32:02.418Z" }, + { url = "https://files.pythonhosted.org/packages/eb/3c/f4abd740877a35abade05e437245b192f9d0ffb48bbbbd708df33d3cda37/pydantic_core-2.33.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9fdac5d6ffa1b5a83bca06ffe7583f5576555e6c8b3a91fbd25ea7780f825f7d", size = 2000611, upload-time = "2025-04-23T18:32:04.152Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/63ef2fed1837d1121a894d0ce88439fe3e3b3e48c7543b2a4479eb99c2bd/pydantic_core-2.33.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04a1a413977ab517154eebb2d326da71638271477d6ad87a769102f7c2488c56", size = 2107924, upload-time = "2025-04-23T18:32:06.129Z" }, + { url = "https://files.pythonhosted.org/packages/04/8f/2551964ef045669801675f1cfc3b0d74147f4901c3ffa42be2ddb1f0efc4/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:c8e7af2f4e0194c22b5b37205bfb293d166a7344a5b0d0eaccebc376546d77d5", size = 2063196, upload-time = "2025-04-23T18:32:08.178Z" }, + { url = "https://files.pythonhosted.org/packages/26/bd/d9602777e77fc6dbb0c7db9ad356e9a985825547dce5ad1d30ee04903918/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_armv7l.whl", hash = 
"sha256:5c92edd15cd58b3c2d34873597a1e20f13094f59cf88068adb18947df5455b4e", size = 2236389, upload-time = "2025-04-23T18:32:10.242Z" }, + { url = "https://files.pythonhosted.org/packages/42/db/0e950daa7e2230423ab342ae918a794964b053bec24ba8af013fc7c94846/pydantic_core-2.33.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:65132b7b4a1c0beded5e057324b7e16e10910c106d43675d9bd87d4f38dde162", size = 2239223, upload-time = "2025-04-23T18:32:12.382Z" }, + { url = "https://files.pythonhosted.org/packages/58/4d/4f937099c545a8a17eb52cb67fe0447fd9a373b348ccfa9a87f141eeb00f/pydantic_core-2.33.2-cp313-cp313-win32.whl", hash = "sha256:52fb90784e0a242bb96ec53f42196a17278855b0f31ac7c3cc6f5c1ec4811849", size = 1900473, upload-time = "2025-04-23T18:32:14.034Z" }, + { url = "https://files.pythonhosted.org/packages/a0/75/4a0a9bac998d78d889def5e4ef2b065acba8cae8c93696906c3a91f310ca/pydantic_core-2.33.2-cp313-cp313-win_amd64.whl", hash = "sha256:c083a3bdd5a93dfe480f1125926afcdbf2917ae714bdb80b36d34318b2bec5d9", size = 1955269, upload-time = "2025-04-23T18:32:15.783Z" }, + { url = "https://files.pythonhosted.org/packages/f9/86/1beda0576969592f1497b4ce8e7bc8cbdf614c352426271b1b10d5f0aa64/pydantic_core-2.33.2-cp313-cp313-win_arm64.whl", hash = "sha256:e80b087132752f6b3d714f041ccf74403799d3b23a72722ea2e6ba2e892555b9", size = 1893921, upload-time = "2025-04-23T18:32:18.473Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7d/e09391c2eebeab681df2b74bfe6c43422fffede8dc74187b2b0bf6fd7571/pydantic_core-2.33.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:61c18fba8e5e9db3ab908620af374db0ac1baa69f0f32df4f61ae23f15e586ac", size = 1806162, upload-time = "2025-04-23T18:32:20.188Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3d/847b6b1fed9f8ed3bb95a9ad04fbd0b212e832d4f0f50ff4d9ee5a9f15cf/pydantic_core-2.33.2-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95237e53bb015f67b63c91af7518a62a8660376a6a0db19b89acc77a4d6199f5", size = 1981560, 
upload-time = "2025-04-23T18:32:22.354Z" }, + { url = "https://files.pythonhosted.org/packages/6f/9a/e73262f6c6656262b5fdd723ad90f518f579b7bc8622e43a942eec53c938/pydantic_core-2.33.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c2fc0a768ef76c15ab9238afa6da7f69895bb5d1ee83aeea2e3509af4472d0b9", size = 1935777, upload-time = "2025-04-23T18:32:25.088Z" }, +] + +[[package]] +name = "pydeck" +version = "0.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jinja2" }, + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/ca/40e14e196864a0f61a92abb14d09b3d3da98f94ccb03b49cf51688140dab/pydeck-0.9.1.tar.gz", hash = "sha256:f74475ae637951d63f2ee58326757f8d4f9cd9f2a457cf42950715003e2cb605", size = 3832240, upload-time = "2024-05-10T15:36:21.153Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/4c/b888e6cf58bd9db9c93f40d1c6be8283ff49d88919231afe93a6bcf61626/pydeck-0.9.1-py2.py3-none-any.whl", hash = "sha256:b3f75ba0d273fc917094fa61224f3f6076ca8752b93d46faf3bcfd9f9d59b038", size = 6900403, upload-time = "2024-05-10T15:36:17.36Z" }, +] + +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631, upload-time = "2025-06-21T13:39:12.283Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217, upload-time = "2025-06-21T13:39:07.939Z" }, +] + +[[package]] +name = "pyparsing" +version = "3.2.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/f2/a5/181488fc2b9d093e3972d2a472855aae8a03f000592dbfce716a512b3359/pyparsing-3.2.5.tar.gz", hash = "sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6", size = 1099274, upload-time = "2025-09-21T04:11:06.277Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/5e/1aa9a93198c6b64513c9d7752de7422c06402de6600a8767da1524f9570b/pyparsing-3.2.5-py3-none-any.whl", hash = "sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e", size = 113890, upload-time = "2025-09-21T04:11:04.117Z" }, +] + +[[package]] +name = "pytest" +version = "8.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/5c/00a0e072241553e1a7496d638deababa67c5058571567b92a7eaa258397c/pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01", size = 1519618, upload-time = "2025-09-04T14:34:22.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, +] + +[[package]] +name = "pytest-cov" +version = "7.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage" }, + { name = "pluggy" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5e/f7/c933acc76f5208b3b00089573cf6a2bc26dc80a8aece8f52bb7d6b1855ca/pytest_cov-7.0.0.tar.gz", hash = "sha256:33c97eda2e049a0c5298e91f519302a1334c26ac65c1a483d6206fd458361af1", size = 54328, upload-time = "2025-09-09T10:57:02.113Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, +] + +[[package]] +name = "pytest-xdist" +version = "3.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "execnet" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/78/b4/439b179d1ff526791eb921115fca8e44e596a13efeda518b9d845a619450/pytest_xdist-3.8.0.tar.gz", hash = "sha256:7e578125ec9bc6050861aa93f2d59f1d8d085595d6551c2c90b6f4fad8d3a9f1", size = 88069, upload-time = "2025-07-01T13:30:59.346Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/31/d4e37e9e550c2b92a9cbc2e4d0b7420a27224968580b5a447f420847c975/pytest_xdist-3.8.0-py3-none-any.whl", hash = "sha256:202ca578cfeb7370784a8c33d6d05bc6e13b4f25b5053c30a152269fd10f0b88", size = 46396, upload-time = "2025-07-01T13:30:56.632Z" }, +] + +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + +[[package]] +name = "pytz" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, +] + +[[package]] +name = "pyyaml" +version = "6.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f", size = 130960, upload-time = "2025-09-25T21:33:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8", size = 181669, upload-time = "2025-09-25T21:32:23.673Z" }, + { 
url = "https://files.pythonhosted.org/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1", size = 173252, upload-time = "2025-09-25T21:32:25.149Z" }, + { url = "https://files.pythonhosted.org/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c", size = 767081, upload-time = "2025-09-25T21:32:26.575Z" }, + { url = "https://files.pythonhosted.org/packages/49/1e/a55ca81e949270d5d4432fbbd19dfea5321eda7c41a849d443dc92fd1ff7/pyyaml-6.0.3-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a33284e20b78bd4a18c8c2282d549d10bc8408a2a7ff57653c0cf0b9be0afce5", size = 841159, upload-time = "2025-09-25T21:32:27.727Z" }, + { url = "https://files.pythonhosted.org/packages/74/27/e5b8f34d02d9995b80abcef563ea1f8b56d20134d8f4e5e81733b1feceb2/pyyaml-6.0.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f29edc409a6392443abf94b9cf89ce99889a1dd5376d94316ae5145dfedd5d6", size = 801626, upload-time = "2025-09-25T21:32:28.878Z" }, + { url = "https://files.pythonhosted.org/packages/f9/11/ba845c23988798f40e52ba45f34849aa8a1f2d4af4b798588010792ebad6/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7057c9a337546edc7973c0d3ba84ddcdf0daa14533c2065749c9075001090e6", size = 753613, upload-time = "2025-09-25T21:32:30.178Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e0/7966e1a7bfc0a45bf0a7fb6b98ea03fc9b8d84fa7f2229e9659680b69ee3/pyyaml-6.0.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eda16858a3cab07b80edaf74336ece1f986ba330fdb8ee0d6c0d68fe82bc96be", size = 794115, upload-time = "2025-09-25T21:32:31.353Z" }, + { url 
= "https://files.pythonhosted.org/packages/de/94/980b50a6531b3019e45ddeada0626d45fa85cbe22300844a7983285bed3b/pyyaml-6.0.3-cp313-cp313-win32.whl", hash = "sha256:d0eae10f8159e8fdad514efdc92d74fd8d682c933a6dd088030f3834bc8e6b26", size = 137427, upload-time = "2025-09-25T21:32:32.58Z" }, + { url = "https://files.pythonhosted.org/packages/97/c9/39d5b874e8b28845e4ec2202b5da735d0199dbe5b8fb85f91398814a9a46/pyyaml-6.0.3-cp313-cp313-win_amd64.whl", hash = "sha256:79005a0d97d5ddabfeeea4cf676af11e647e41d81c9a7722a193022accdb6b7c", size = 154090, upload-time = "2025-09-25T21:32:33.659Z" }, + { url = "https://files.pythonhosted.org/packages/73/e8/2bdf3ca2090f68bb3d75b44da7bbc71843b19c9f2b9cb9b0f4ab7a5a4329/pyyaml-6.0.3-cp313-cp313-win_arm64.whl", hash = "sha256:5498cd1645aa724a7c71c8f378eb29ebe23da2fc0d7a08071d89469bf1d2defb", size = 140246, upload-time = "2025-09-25T21:32:34.663Z" }, +] + +[[package]] +name = "pyzmq" +version = "27.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/04/0b/3c9baedbdf613ecaa7aa07027780b8867f57b6293b6ee50de316c9f3222b/pyzmq-27.1.0.tar.gz", hash = "sha256:ac0765e3d44455adb6ddbf4417dcce460fc40a05978c08efdf2948072f6db540", size = 281750, upload-time = "2025-09-08T23:10:18.157Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/92/e7/038aab64a946d535901103da16b953c8c9cc9c961dadcbf3609ed6428d23/pyzmq-27.1.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:452631b640340c928fa343801b0d07eb0c3789a5ffa843f6e1a9cee0ba4eb4fc", size = 1306279, upload-time = "2025-09-08T23:08:03.807Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5e/c3c49fdd0f535ef45eefcc16934648e9e59dace4a37ee88fc53f6cd8e641/pyzmq-27.1.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1c179799b118e554b66da67d88ed66cd37a169f1f23b5d9f0a231b4e8d44a113", size = 895645, upload-time = 
"2025-09-08T23:08:05.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/e5/b0b2504cb4e903a74dcf1ebae157f9e20ebb6ea76095f6cfffea28c42ecd/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3837439b7f99e60312f0c926a6ad437b067356dc2bc2ec96eb395fd0fe804233", size = 652574, upload-time = "2025-09-08T23:08:06.828Z" }, + { url = "https://files.pythonhosted.org/packages/f8/9b/c108cdb55560eaf253f0cbdb61b29971e9fb34d9c3499b0e96e4e60ed8a5/pyzmq-27.1.0-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43ad9a73e3da1fab5b0e7e13402f0b2fb934ae1c876c51d0afff0e7c052eca31", size = 840995, upload-time = "2025-09-08T23:08:08.396Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bb/b79798ca177b9eb0825b4c9998c6af8cd2a7f15a6a1a4272c1d1a21d382f/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0de3028d69d4cdc475bfe47a6128eb38d8bc0e8f4d69646adfbcd840facbac28", size = 1642070, upload-time = "2025-09-08T23:08:09.989Z" }, + { url = "https://files.pythonhosted.org/packages/9c/80/2df2e7977c4ede24c79ae39dcef3899bfc5f34d1ca7a5b24f182c9b7a9ca/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:cf44a7763aea9298c0aa7dbf859f87ed7012de8bda0f3977b6fb1d96745df856", size = 2021121, upload-time = "2025-09-08T23:08:11.907Z" }, + { url = "https://files.pythonhosted.org/packages/46/bd/2d45ad24f5f5ae7e8d01525eb76786fa7557136555cac7d929880519e33a/pyzmq-27.1.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f30f395a9e6fbca195400ce833c731e7b64c3919aa481af4d88c3759e0cb7496", size = 1878550, upload-time = "2025-09-08T23:08:13.513Z" }, + { url = "https://files.pythonhosted.org/packages/e6/2f/104c0a3c778d7c2ab8190e9db4f62f0b6957b53c9d87db77c284b69f33ea/pyzmq-27.1.0-cp312-abi3-win32.whl", hash = "sha256:250e5436a4ba13885494412b3da5d518cd0d3a278a1ae640e113c073a5f88edd", size = 559184, upload-time = "2025-09-08T23:08:15.163Z" }, + { url = 
"https://files.pythonhosted.org/packages/fc/7f/a21b20d577e4100c6a41795842028235998a643b1ad406a6d4163ea8f53e/pyzmq-27.1.0-cp312-abi3-win_amd64.whl", hash = "sha256:9ce490cf1d2ca2ad84733aa1d69ce6855372cb5ce9223802450c9b2a7cba0ccf", size = 619480, upload-time = "2025-09-08T23:08:17.192Z" }, + { url = "https://files.pythonhosted.org/packages/78/c2/c012beae5f76b72f007a9e91ee9401cb88c51d0f83c6257a03e785c81cc2/pyzmq-27.1.0-cp312-abi3-win_arm64.whl", hash = "sha256:75a2f36223f0d535a0c919e23615fc85a1e23b71f40c7eb43d7b1dedb4d8f15f", size = 552993, upload-time = "2025-09-08T23:08:18.926Z" }, + { url = "https://files.pythonhosted.org/packages/60/cb/84a13459c51da6cec1b7b1dc1a47e6db6da50b77ad7fd9c145842750a011/pyzmq-27.1.0-cp313-cp313-android_24_arm64_v8a.whl", hash = "sha256:93ad4b0855a664229559e45c8d23797ceac03183c7b6f5b4428152a6b06684a5", size = 1122436, upload-time = "2025-09-08T23:08:20.801Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b6/94414759a69a26c3dd674570a81813c46a078767d931a6c70ad29fc585cb/pyzmq-27.1.0-cp313-cp313-android_24_x86_64.whl", hash = "sha256:fbb4f2400bfda24f12f009cba62ad5734148569ff4949b1b6ec3b519444342e6", size = 1156301, upload-time = "2025-09-08T23:08:22.47Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ad/15906493fd40c316377fd8a8f6b1f93104f97a752667763c9b9c1b71d42d/pyzmq-27.1.0-cp313-cp313t-macosx_10_15_universal2.whl", hash = "sha256:e343d067f7b151cfe4eb3bb796a7752c9d369eed007b91231e817071d2c2fec7", size = 1341197, upload-time = "2025-09-08T23:08:24.286Z" }, + { url = "https://files.pythonhosted.org/packages/14/1d/d343f3ce13db53a54cb8946594e567410b2125394dafcc0268d8dda027e0/pyzmq-27.1.0-cp313-cp313t-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:08363b2011dec81c354d694bdecaef4770e0ae96b9afea70b3f47b973655cc05", size = 897275, upload-time = "2025-09-08T23:08:26.063Z" }, + { url = 
"https://files.pythonhosted.org/packages/69/2d/d83dd6d7ca929a2fc67d2c3005415cdf322af7751d773524809f9e585129/pyzmq-27.1.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d54530c8c8b5b8ddb3318f481297441af102517602b569146185fa10b63f4fa9", size = 660469, upload-time = "2025-09-08T23:08:27.623Z" }, + { url = "https://files.pythonhosted.org/packages/3e/cd/9822a7af117f4bc0f1952dbe9ef8358eb50a24928efd5edf54210b850259/pyzmq-27.1.0-cp313-cp313t-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f3afa12c392f0a44a2414056d730eebc33ec0926aae92b5ad5cf26ebb6cc128", size = 847961, upload-time = "2025-09-08T23:08:29.672Z" }, + { url = "https://files.pythonhosted.org/packages/9a/12/f003e824a19ed73be15542f172fd0ec4ad0b60cf37436652c93b9df7c585/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c65047adafe573ff023b3187bb93faa583151627bc9c51fc4fb2c561ed689d39", size = 1650282, upload-time = "2025-09-08T23:08:31.349Z" }, + { url = "https://files.pythonhosted.org/packages/d5/4a/e82d788ed58e9a23995cee70dbc20c9aded3d13a92d30d57ec2291f1e8a3/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:90e6e9441c946a8b0a667356f7078d96411391a3b8f80980315455574177ec97", size = 2024468, upload-time = "2025-09-08T23:08:33.543Z" }, + { url = "https://files.pythonhosted.org/packages/d9/94/2da0a60841f757481e402b34bf4c8bf57fa54a5466b965de791b1e6f747d/pyzmq-27.1.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:add071b2d25f84e8189aaf0882d39a285b42fa3853016ebab234a5e78c7a43db", size = 1885394, upload-time = "2025-09-08T23:08:35.51Z" }, + { url = "https://files.pythonhosted.org/packages/4f/6f/55c10e2e49ad52d080dc24e37adb215e5b0d64990b57598abc2e3f01725b/pyzmq-27.1.0-cp313-cp313t-win32.whl", hash = "sha256:7ccc0700cfdf7bd487bea8d850ec38f204478681ea02a582a8da8171b7f90a1c", size = 574964, upload-time = "2025-09-08T23:08:37.178Z" }, + { url = 
"https://files.pythonhosted.org/packages/87/4d/2534970ba63dd7c522d8ca80fb92777f362c0f321900667c615e2067cb29/pyzmq-27.1.0-cp313-cp313t-win_amd64.whl", hash = "sha256:8085a9fba668216b9b4323be338ee5437a235fe275b9d1610e422ccc279733e2", size = 641029, upload-time = "2025-09-08T23:08:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fa/f8aea7a28b0641f31d40dea42d7ef003fded31e184ef47db696bc74cd610/pyzmq-27.1.0-cp313-cp313t-win_arm64.whl", hash = "sha256:6bb54ca21bcfe361e445256c15eedf083f153811c37be87e0514934d6913061e", size = 561541, upload-time = "2025-09-08T23:08:42.668Z" }, +] + +[[package]] +name = "referencing" +version = "0.37.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "rpds-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8", size = 78036, upload-time = "2025-10-13T15:30:48.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" }, +] + +[[package]] +name = "regex" +version = "2025.10.23" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/c8/1d2160d36b11fbe0a61acb7c3c81ab032d9ec8ad888ac9e0a61b85ab99dd/regex-2025.10.23.tar.gz", hash = "sha256:8cbaf8ceb88f96ae2356d01b9adf5e6306fa42fa6f7eab6b97794e37c959ac26", size = 401266, upload-time = "2025-10-21T15:58:20.23Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/c6/195a6217a43719d5a6a12cc192a22d12c40290cecfa577f00f4fb822f07d/regex-2025.10.23-cp313-cp313-macosx_10_13_universal2.whl", hash = 
"sha256:b7690f95404a1293923a296981fd943cca12c31a41af9c21ba3edd06398fc193", size = 488956, upload-time = "2025-10-21T15:55:42.887Z" }, + { url = "https://files.pythonhosted.org/packages/4c/93/181070cd1aa2fa541ff2d3afcf763ceecd4937b34c615fa92765020a6c90/regex-2025.10.23-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1a32d77aeaea58a13230100dd8797ac1a84c457f3af2fdf0d81ea689d5a9105b", size = 290997, upload-time = "2025-10-21T15:55:44.53Z" }, + { url = "https://files.pythonhosted.org/packages/b6/c5/9d37fbe3a40ed8dda78c23e1263002497540c0d1522ed75482ef6c2000f0/regex-2025.10.23-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b24b29402f264f70a3c81f45974323b41764ff7159655360543b7cabb73e7d2f", size = 288686, upload-time = "2025-10-21T15:55:46.186Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e7/db610ff9f10c2921f9b6ac0c8d8be4681b28ddd40fc0549429366967e61f/regex-2025.10.23-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:563824a08c7c03d96856d84b46fdb3bbb7cfbdf79da7ef68725cda2ce169c72a", size = 798466, upload-time = "2025-10-21T15:55:48.24Z" }, + { url = "https://files.pythonhosted.org/packages/90/10/aab883e1fa7fe2feb15ac663026e70ca0ae1411efa0c7a4a0342d9545015/regex-2025.10.23-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a0ec8bdd88d2e2659c3518087ee34b37e20bd169419ffead4240a7004e8ed03b", size = 863996, upload-time = "2025-10-21T15:55:50.478Z" }, + { url = "https://files.pythonhosted.org/packages/a2/b0/8f686dd97a51f3b37d0238cd00a6d0f9ccabe701f05b56de1918571d0d61/regex-2025.10.23-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b577601bfe1d33913fcd9276d7607bbac827c4798d9e14d04bf37d417a6c41cb", size = 912145, upload-time = "2025-10-21T15:55:52.215Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/ca/639f8cd5b08797bca38fc5e7e07f76641a428cf8c7fca05894caf045aa32/regex-2025.10.23-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c9f2c68ac6cb3de94eea08a437a75eaa2bd33f9e97c84836ca0b610a5804368", size = 803370, upload-time = "2025-10-21T15:55:53.944Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1e/a40725bb76959eddf8abc42a967bed6f4851b39f5ac4f20e9794d7832aa5/regex-2025.10.23-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89f8b9ea3830c79468e26b0e21c3585f69f105157c2154a36f6b7839f8afb351", size = 787767, upload-time = "2025-10-21T15:55:56.004Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d8/8ee9858062936b0f99656dce390aa667c6e7fb0c357b1b9bf76fb5e2e708/regex-2025.10.23-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:98fd84c4e4ea185b3bb5bf065261ab45867d8875032f358a435647285c722673", size = 858335, upload-time = "2025-10-21T15:55:58.185Z" }, + { url = "https://files.pythonhosted.org/packages/d8/0a/ed5faaa63fa8e3064ab670e08061fbf09e3a10235b19630cf0cbb9e48c0a/regex-2025.10.23-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:1e11d3e5887b8b096f96b4154dfb902f29c723a9556639586cd140e77e28b313", size = 850402, upload-time = "2025-10-21T15:56:00.023Z" }, + { url = "https://files.pythonhosted.org/packages/79/14/d05f617342f4b2b4a23561da500ca2beab062bfcc408d60680e77ecaf04d/regex-2025.10.23-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f13450328a6634348d47a88367e06b64c9d84980ef6a748f717b13f8ce64e87", size = 789739, upload-time = "2025-10-21T15:56:01.967Z" }, + { url = "https://files.pythonhosted.org/packages/f9/7b/e8ce8eef42a15f2c3461f8b3e6e924bbc86e9605cb534a393aadc8d3aff8/regex-2025.10.23-cp313-cp313-win32.whl", hash = "sha256:37be9296598a30c6a20236248cb8b2c07ffd54d095b75d3a2a2ee5babdc51df1", size = 266054, upload-time = "2025-10-21T15:56:05.291Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/2d/55184ed6be6473187868d2f2e6a0708195fc58270e62a22cbf26028f2570/regex-2025.10.23-cp313-cp313-win_amd64.whl", hash = "sha256:ea7a3c283ce0f06fe789365841e9174ba05f8db16e2fd6ae00a02df9572c04c0", size = 276917, upload-time = "2025-10-21T15:56:07.303Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d4/927eced0e2bd45c45839e556f987f8c8f8683268dd3c00ad327deb3b0172/regex-2025.10.23-cp313-cp313-win_arm64.whl", hash = "sha256:d9a4953575f300a7bab71afa4cd4ac061c7697c89590a2902b536783eeb49a4f", size = 270105, upload-time = "2025-10-21T15:56:09.857Z" }, + { url = "https://files.pythonhosted.org/packages/3e/b3/95b310605285573341fc062d1d30b19a54f857530e86c805f942c4ff7941/regex-2025.10.23-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:7d6606524fa77b3912c9ef52a42ef63c6cfbfc1077e9dc6296cd5da0da286044", size = 491850, upload-time = "2025-10-21T15:56:11.685Z" }, + { url = "https://files.pythonhosted.org/packages/a4/8f/207c2cec01e34e56db1eff606eef46644a60cf1739ecd474627db90ad90b/regex-2025.10.23-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c037aadf4d64bdc38af7db3dbd34877a057ce6524eefcb2914d6d41c56f968cc", size = 292537, upload-time = "2025-10-21T15:56:13.963Z" }, + { url = "https://files.pythonhosted.org/packages/98/3b/025240af4ada1dc0b5f10d73f3e5122d04ce7f8908ab8881e5d82b9d61b6/regex-2025.10.23-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:99018c331fb2529084a0c9b4c713dfa49fafb47c7712422e49467c13a636c656", size = 290904, upload-time = "2025-10-21T15:56:16.016Z" }, + { url = "https://files.pythonhosted.org/packages/81/8e/104ac14e2d3450c43db18ec03e1b96b445a94ae510b60138f00ce2cb7ca1/regex-2025.10.23-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fd8aba965604d70306eb90a35528f776e59112a7114a5162824d43b76fa27f58", size = 807311, upload-time = "2025-10-21T15:56:17.818Z" }, + { url = 
"https://files.pythonhosted.org/packages/19/63/78aef90141b7ce0be8a18e1782f764f6997ad09de0e05251f0d2503a914a/regex-2025.10.23-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:238e67264b4013e74136c49f883734f68656adf8257bfa13b515626b31b20f8e", size = 873241, upload-time = "2025-10-21T15:56:19.941Z" }, + { url = "https://files.pythonhosted.org/packages/b3/a8/80eb1201bb49ae4dba68a1b284b4211ed9daa8e74dc600018a10a90399fb/regex-2025.10.23-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b2eb48bd9848d66fd04826382f5e8491ae633de3233a3d64d58ceb4ecfa2113a", size = 914794, upload-time = "2025-10-21T15:56:22.488Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d5/1984b6ee93281f360a119a5ca1af6a8ca7d8417861671388bf750becc29b/regex-2025.10.23-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d36591ce06d047d0c0fe2fc5f14bfbd5b4525d08a7b6a279379085e13f0e3d0e", size = 812581, upload-time = "2025-10-21T15:56:24.319Z" }, + { url = "https://files.pythonhosted.org/packages/c4/39/11ebdc6d9927172a64ae237d16763145db6bd45ebb4055c17b88edab72a7/regex-2025.10.23-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b5d4ece8628d6e364302006366cea3ee887db397faebacc5dacf8ef19e064cf8", size = 795346, upload-time = "2025-10-21T15:56:26.232Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b4/89a591bcc08b5e436af43315284bd233ba77daf0cf20e098d7af12f006c1/regex-2025.10.23-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:39a7e8083959cb1c4ff74e483eecb5a65d3b3e1d821b256e54baf61782c906c6", size = 868214, upload-time = "2025-10-21T15:56:28.597Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ff/58ba98409c1dbc8316cdb20dafbc63ed267380a07780cafecaf5012dabc9/regex-2025.10.23-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:842d449a8fefe546f311656cf8c0d6729b08c09a185f1cad94c756210286d6a8", size = 854540, upload-time = 
"2025-10-21T15:56:30.875Z" }, + { url = "https://files.pythonhosted.org/packages/9a/f2/4a9e9338d67626e2071b643f828a482712ad15889d7268e11e9a63d6f7e9/regex-2025.10.23-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d614986dc68506be8f00474f4f6960e03e4ca9883f7df47744800e7d7c08a494", size = 799346, upload-time = "2025-10-21T15:56:32.725Z" }, + { url = "https://files.pythonhosted.org/packages/63/be/543d35c46bebf6f7bf2be538cca74d6585f25714700c36f37f01b92df551/regex-2025.10.23-cp313-cp313t-win32.whl", hash = "sha256:a5b7a26b51a9df473ec16a1934d117443a775ceb7b39b78670b2e21893c330c9", size = 268657, upload-time = "2025-10-21T15:56:34.577Z" }, + { url = "https://files.pythonhosted.org/packages/14/9f/4dd6b7b612037158bb2c9bcaa710e6fb3c40ad54af441b9c53b3a137a9f1/regex-2025.10.23-cp313-cp313t-win_amd64.whl", hash = "sha256:ce81c5544a5453f61cb6f548ed358cfb111e3b23f3cd42d250a4077a6be2a7b6", size = 280075, upload-time = "2025-10-21T15:56:36.767Z" }, + { url = "https://files.pythonhosted.org/packages/81/7a/5bd0672aa65d38c8da6747c17c8b441bdb53d816c569e3261013af8e83cf/regex-2025.10.23-cp313-cp313t-win_arm64.whl", hash = "sha256:e9bf7f6699f490e4e43c44757aa179dab24d1960999c84ab5c3d5377714ed473", size = 271219, upload-time = "2025-10-21T15:56:39.033Z" }, +] + +[[package]] +name = "requests" +version = "2.32.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "charset-normalizer" }, + { name = "idna" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = 
"sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, +] + +[[package]] +name = "rpds-py" +version = "0.28.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/48/dc/95f074d43452b3ef5d06276696ece4b3b5d696e7c9ad7173c54b1390cd70/rpds_py-0.28.0.tar.gz", hash = "sha256:abd4df20485a0983e2ca334a216249b6186d6e3c1627e106651943dbdb791aea", size = 27419, upload-time = "2025-10-22T22:24:29.327Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/03/ce566d92611dfac0085c2f4b048cd53ed7c274a5c05974b882a908d540a2/rpds_py-0.28.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e9e184408a0297086f880556b6168fa927d677716f83d3472ea333b42171ee3b", size = 366235, upload-time = "2025-10-22T22:22:28.397Z" }, + { url = "https://files.pythonhosted.org/packages/00/34/1c61da1b25592b86fd285bd7bd8422f4c9d748a7373b46126f9ae792a004/rpds_py-0.28.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edd267266a9b0448f33dc465a97cfc5d467594b600fe28e7fa2f36450e03053a", size = 348241, upload-time = "2025-10-22T22:22:30.171Z" }, + { url = "https://files.pythonhosted.org/packages/fc/00/ed1e28616848c61c493a067779633ebf4b569eccaacf9ccbdc0e7cba2b9d/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85beb8b3f45e4e32f6802fb6cd6b17f615ef6c6a52f265371fb916fae02814aa", size = 378079, upload-time = "2025-10-22T22:22:31.644Z" }, + { url = "https://files.pythonhosted.org/packages/11/b2/ccb30333a16a470091b6e50289adb4d3ec656fd9951ba8c5e3aaa0746a67/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d2412be8d00a1b895f8ad827cc2116455196e20ed994bb704bf138fe91a42724", size = 393151, upload-time = "2025-10-22T22:22:33.453Z" }, + { url = 
"https://files.pythonhosted.org/packages/8c/d0/73e2217c3ee486d555cb84920597480627d8c0240ff3062005c6cc47773e/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf128350d384b777da0e68796afdcebc2e9f63f0e9f242217754e647f6d32491", size = 517520, upload-time = "2025-10-22T22:22:34.949Z" }, + { url = "https://files.pythonhosted.org/packages/c4/91/23efe81c700427d0841a4ae7ea23e305654381831e6029499fe80be8a071/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a2036d09b363aa36695d1cc1a97b36865597f4478470b0697b5ee9403f4fe399", size = 408699, upload-time = "2025-10-22T22:22:36.584Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ee/a324d3198da151820a326c1f988caaa4f37fc27955148a76fff7a2d787a9/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8e1e9be4fa6305a16be628959188e4fd5cd6f1b0e724d63c6d8b2a8adf74ea6", size = 385720, upload-time = "2025-10-22T22:22:38.014Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/e68120dc05af8b7cab4a789fccd8cdcf0fe7e6581461038cc5c164cd97d2/rpds_py-0.28.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0a403460c9dd91a7f23fc3188de6d8977f1d9603a351d5db6cf20aaea95b538d", size = 401096, upload-time = "2025-10-22T22:22:39.869Z" }, + { url = "https://files.pythonhosted.org/packages/99/90/c1e070620042459d60df6356b666bb1f62198a89d68881816a7ed121595a/rpds_py-0.28.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d7366b6553cdc805abcc512b849a519167db8f5e5c3472010cd1228b224265cb", size = 411465, upload-time = "2025-10-22T22:22:41.395Z" }, + { url = "https://files.pythonhosted.org/packages/68/61/7c195b30d57f1b8d5970f600efee72a4fad79ec829057972e13a0370fd24/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5b43c6a3726efd50f18d8120ec0551241c38785b68952d240c45ea553912ac41", size = 558832, upload-time = "2025-10-22T22:22:42.871Z" }, + { url = 
"https://files.pythonhosted.org/packages/b0/3d/06f3a718864773f69941d4deccdf18e5e47dd298b4628062f004c10f3b34/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0cb7203c7bc69d7c1585ebb33a2e6074492d2fc21ad28a7b9d40457ac2a51ab7", size = 583230, upload-time = "2025-10-22T22:22:44.877Z" }, + { url = "https://files.pythonhosted.org/packages/66/df/62fc783781a121e77fee9a21ead0a926f1b652280a33f5956a5e7833ed30/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7a52a5169c664dfb495882adc75c304ae1d50df552fbd68e100fdc719dee4ff9", size = 553268, upload-time = "2025-10-22T22:22:46.441Z" }, + { url = "https://files.pythonhosted.org/packages/84/85/d34366e335140a4837902d3dea89b51f087bd6a63c993ebdff59e93ee61d/rpds_py-0.28.0-cp313-cp313-win32.whl", hash = "sha256:2e42456917b6687215b3e606ab46aa6bca040c77af7df9a08a6dcfe8a4d10ca5", size = 217100, upload-time = "2025-10-22T22:22:48.342Z" }, + { url = "https://files.pythonhosted.org/packages/3c/1c/f25a3f3752ad7601476e3eff395fe075e0f7813fbb9862bd67c82440e880/rpds_py-0.28.0-cp313-cp313-win_amd64.whl", hash = "sha256:e0a0311caedc8069d68fc2bf4c9019b58a2d5ce3cd7cb656c845f1615b577e1e", size = 227759, upload-time = "2025-10-22T22:22:50.219Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d6/5f39b42b99615b5bc2f36ab90423ea404830bdfee1c706820943e9a645eb/rpds_py-0.28.0-cp313-cp313-win_arm64.whl", hash = "sha256:04c1b207ab8b581108801528d59ad80aa83bb170b35b0ddffb29c20e411acdc1", size = 217326, upload-time = "2025-10-22T22:22:51.647Z" }, + { url = "https://files.pythonhosted.org/packages/5c/8b/0c69b72d1cee20a63db534be0df271effe715ef6c744fdf1ff23bb2b0b1c/rpds_py-0.28.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:f296ea3054e11fc58ad42e850e8b75c62d9a93a9f981ad04b2e5ae7d2186ff9c", size = 355736, upload-time = "2025-10-22T22:22:53.211Z" }, + { url = "https://files.pythonhosted.org/packages/f7/6d/0c2ee773cfb55c31a8514d2cece856dd299170a49babd50dcffb15ddc749/rpds_py-0.28.0-cp313-cp313t-macosx_11_0_arm64.whl", 
hash = "sha256:5a7306c19b19005ad98468fcefeb7100b19c79fc23a5f24a12e06d91181193fa", size = 342677, upload-time = "2025-10-22T22:22:54.723Z" }, + { url = "https://files.pythonhosted.org/packages/e2/1c/22513ab25a27ea205144414724743e305e8153e6abe81833b5e678650f5a/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5d9b86aa501fed9862a443c5c3116f6ead8bc9296185f369277c42542bd646b", size = 371847, upload-time = "2025-10-22T22:22:56.295Z" }, + { url = "https://files.pythonhosted.org/packages/60/07/68e6ccdb4b05115ffe61d31afc94adef1833d3a72f76c9632d4d90d67954/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e5bbc701eff140ba0e872691d573b3d5d30059ea26e5785acba9132d10c8c31d", size = 381800, upload-time = "2025-10-22T22:22:57.808Z" }, + { url = "https://files.pythonhosted.org/packages/73/bf/6d6d15df80781d7f9f368e7c1a00caf764436518c4877fb28b029c4624af/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5690671cd672a45aa8616d7374fdf334a1b9c04a0cac3c854b1136e92374fe", size = 518827, upload-time = "2025-10-22T22:22:59.826Z" }, + { url = "https://files.pythonhosted.org/packages/7b/d3/2decbb2976cc452cbf12a2b0aaac5f1b9dc5dd9d1f7e2509a3ee00421249/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9f1d92ecea4fa12f978a367c32a5375a1982834649cdb96539dcdc12e609ab1a", size = 399471, upload-time = "2025-10-22T22:23:01.968Z" }, + { url = "https://files.pythonhosted.org/packages/b1/2c/f30892f9e54bd02e5faca3f6a26d6933c51055e67d54818af90abed9748e/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d252db6b1a78d0a3928b6190156042d54c93660ce4d98290d7b16b5296fb7cc", size = 377578, upload-time = "2025-10-22T22:23:03.52Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5d/3bce97e5534157318f29ac06bf2d279dae2674ec12f7cb9c12739cee64d8/rpds_py-0.28.0-cp313-cp313t-manylinux_2_31_riscv64.whl", 
hash = "sha256:d61b355c3275acb825f8777d6c4505f42b5007e357af500939d4a35b19177259", size = 390482, upload-time = "2025-10-22T22:23:05.391Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f0/886bd515ed457b5bd93b166175edb80a0b21a210c10e993392127f1e3931/rpds_py-0.28.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:acbe5e8b1026c0c580d0321c8aae4b0a1e1676861d48d6e8c6586625055b606a", size = 402447, upload-time = "2025-10-22T22:23:06.93Z" }, + { url = "https://files.pythonhosted.org/packages/42/b5/71e8777ac55e6af1f4f1c05b47542a1eaa6c33c1cf0d300dca6a1c6e159a/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:8aa23b6f0fc59b85b4c7d89ba2965af274346f738e8d9fc2455763602e62fd5f", size = 552385, upload-time = "2025-10-22T22:23:08.557Z" }, + { url = "https://files.pythonhosted.org/packages/5d/cb/6ca2d70cbda5a8e36605e7788c4aa3bea7c17d71d213465a5a675079b98d/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7b14b0c680286958817c22d76fcbca4800ddacef6f678f3a7c79a1fe7067fe37", size = 575642, upload-time = "2025-10-22T22:23:10.348Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d4/407ad9960ca7856d7b25c96dcbe019270b5ffdd83a561787bc682c797086/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bcf1d210dfee61a6c86551d67ee1031899c0fdbae88b2d44a569995d43797712", size = 544507, upload-time = "2025-10-22T22:23:12.434Z" }, + { url = "https://files.pythonhosted.org/packages/51/31/2f46fe0efcac23fbf5797c6b6b7e1c76f7d60773e525cb65fcbc582ee0f2/rpds_py-0.28.0-cp313-cp313t-win32.whl", hash = "sha256:3aa4dc0fdab4a7029ac63959a3ccf4ed605fee048ba67ce89ca3168da34a1342", size = 205376, upload-time = "2025-10-22T22:23:13.979Z" }, + { url = "https://files.pythonhosted.org/packages/92/e4/15947bda33cbedfc134490a41841ab8870a72a867a03d4969d886f6594a2/rpds_py-0.28.0-cp313-cp313t-win_amd64.whl", hash = "sha256:7b7d9d83c942855e4fdcfa75d4f96f6b9e272d42fffcb72cd4bb2577db2e2907", size = 215907, upload-time = 
"2025-10-22T22:23:15.5Z" }, +] + +[[package]] +name = "ruff" +version = "0.14.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/34/8218a19b2055b80601e8fd201ec723c74c7fe1ca06d525a43ed07b6d8e85/ruff-0.14.2.tar.gz", hash = "sha256:98da787668f239313d9c902ca7c523fe11b8ec3f39345553a51b25abc4629c96", size = 5539663, upload-time = "2025-10-23T19:37:00.956Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/dd/23eb2db5ad9acae7c845700493b72d3ae214dce0b226f27df89216110f2b/ruff-0.14.2-py3-none-linux_armv6l.whl", hash = "sha256:7cbe4e593505bdec5884c2d0a4d791a90301bc23e49a6b1eb642dd85ef9c64f1", size = 12533390, upload-time = "2025-10-23T19:36:18.044Z" }, + { url = "https://files.pythonhosted.org/packages/5a/8c/5f9acff43ddcf3f85130d0146d0477e28ccecc495f9f684f8f7119b74c0d/ruff-0.14.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8d54b561729cee92f8d89c316ad7a3f9705533f5903b042399b6ae0ddfc62e11", size = 12887187, upload-time = "2025-10-23T19:36:22.664Z" }, + { url = "https://files.pythonhosted.org/packages/99/fa/047646491479074029665022e9f3dc6f0515797f40a4b6014ea8474c539d/ruff-0.14.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5c8753dfa44ebb2cde10ce5b4d2ef55a41fb9d9b16732a2c5df64620dbda44a3", size = 11925177, upload-time = "2025-10-23T19:36:24.778Z" }, + { url = "https://files.pythonhosted.org/packages/15/8b/c44cf7fe6e59ab24a9d939493a11030b503bdc2a16622cede8b7b1df0114/ruff-0.14.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d0bbeffb8d9f4fccf7b5198d566d0bad99a9cb622f1fc3467af96cb8773c9e3", size = 12358285, upload-time = "2025-10-23T19:36:26.979Z" }, + { url = "https://files.pythonhosted.org/packages/45/01/47701b26254267ef40369aea3acb62a7b23e921c27372d127e0f3af48092/ruff-0.14.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7047f0c5a713a401e43a88d36843d9c83a19c584e63d664474675620aaa634a8", size = 12303832, upload-time = 
"2025-10-23T19:36:29.192Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5c/ae7244ca4fbdf2bee9d6405dcd5bc6ae51ee1df66eb7a9884b77b8af856d/ruff-0.14.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bf8d2f9aa1602599217d82e8e0af7fd33e5878c4d98f37906b7c93f46f9a839", size = 13036995, upload-time = "2025-10-23T19:36:31.861Z" }, + { url = "https://files.pythonhosted.org/packages/27/4c/0860a79ce6fd4c709ac01173f76f929d53f59748d0dcdd662519835dae43/ruff-0.14.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1c505b389e19c57a317cf4b42db824e2fca96ffb3d86766c1c9f8b96d32048a7", size = 14512649, upload-time = "2025-10-23T19:36:33.915Z" }, + { url = "https://files.pythonhosted.org/packages/7f/7f/d365de998069720a3abfc250ddd876fc4b81a403a766c74ff9bde15b5378/ruff-0.14.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a307fc45ebd887b3f26b36d9326bb70bf69b01561950cdcc6c0bdf7bb8e0f7cc", size = 14088182, upload-time = "2025-10-23T19:36:36.983Z" }, + { url = "https://files.pythonhosted.org/packages/6c/ea/d8e3e6b209162000a7be1faa41b0a0c16a133010311edc3329753cc6596a/ruff-0.14.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61ae91a32c853172f832c2f40bd05fd69f491db7289fb85a9b941ebdd549781a", size = 13599516, upload-time = "2025-10-23T19:36:39.208Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ea/c7810322086db68989fb20a8d5221dd3b79e49e396b01badca07b433ab45/ruff-0.14.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1967e40286f63ee23c615e8e7e98098dedc7301568bd88991f6e544d8ae096", size = 13272690, upload-time = "2025-10-23T19:36:41.453Z" }, + { url = "https://files.pythonhosted.org/packages/a9/39/10b05acf8c45786ef501d454e00937e1b97964f846bf28883d1f9619928a/ruff-0.14.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:2877f02119cdebf52a632d743a2e302dea422bfae152ebe2f193d3285a3a65df", size = 13496497, upload-time = "2025-10-23T19:36:43.61Z" }, + { url = 
"https://files.pythonhosted.org/packages/59/a1/1f25f8301e13751c30895092485fada29076e5e14264bdacc37202e85d24/ruff-0.14.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e681c5bc777de5af898decdcb6ba3321d0d466f4cb43c3e7cc2c3b4e7b843a05", size = 12266116, upload-time = "2025-10-23T19:36:45.625Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fa/0029bfc9ce16ae78164e6923ef392e5f173b793b26cc39aa1d8b366cf9dc/ruff-0.14.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e21be42d72e224736f0c992cdb9959a2fa53c7e943b97ef5d081e13170e3ffc5", size = 12281345, upload-time = "2025-10-23T19:36:47.618Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ab/ece7baa3c0f29b7683be868c024f0838770c16607bea6852e46b202f1ff6/ruff-0.14.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b8264016f6f209fac16262882dbebf3f8be1629777cf0f37e7aff071b3e9b92e", size = 12629296, upload-time = "2025-10-23T19:36:49.789Z" }, + { url = "https://files.pythonhosted.org/packages/a4/7f/638f54b43f3d4e48c6a68062794e5b367ddac778051806b9e235dfb7aa81/ruff-0.14.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5ca36b4cb4db3067a3b24444463ceea5565ea78b95fe9a07ca7cb7fd16948770", size = 13371610, upload-time = "2025-10-23T19:36:51.882Z" }, + { url = "https://files.pythonhosted.org/packages/8d/35/3654a973ebe5b32e1fd4a08ed2d46755af7267da7ac710d97420d7b8657d/ruff-0.14.2-py3-none-win32.whl", hash = "sha256:41775927d287685e08f48d8eb3f765625ab0b7042cc9377e20e64f4eb0056ee9", size = 12415318, upload-time = "2025-10-23T19:36:53.961Z" }, + { url = "https://files.pythonhosted.org/packages/71/30/3758bcf9e0b6a4193a6f51abf84254aba00887dfa8c20aba18aa366c5f57/ruff-0.14.2-py3-none-win_amd64.whl", hash = "sha256:0df3424aa5c3c08b34ed8ce099df1021e3adaca6e90229273496b839e5a7e1af", size = 13565279, upload-time = "2025-10-23T19:36:56.578Z" }, + { url = "https://files.pythonhosted.org/packages/2e/5d/aa883766f8ef9ffbe6aa24f7192fb71632f31a30e77eb39aa2b0dc4290ac/ruff-0.14.2-py3-none-win_arm64.whl", hash = 
"sha256:ea9d635e83ba21569fbacda7e78afbfeb94911c9434aff06192d9bc23fd5495a", size = 12554956, upload-time = "2025-10-23T19:36:58.714Z" }, +] + +[[package]] +name = "safetensors" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/cc/738f3011628920e027a11754d9cae9abec1aed00f7ae860abbf843755233/safetensors-0.6.2.tar.gz", hash = "sha256:43ff2aa0e6fa2dc3ea5524ac7ad93a9839256b8703761e76e2d0b2a3fa4f15d9", size = 197968, upload-time = "2025-08-08T13:13:58.654Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/b1/3f5fd73c039fc87dba3ff8b5d528bfc5a32b597fea8e7a6a4800343a17c7/safetensors-0.6.2-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:9c85ede8ec58f120bad982ec47746981e210492a6db876882aa021446af8ffba", size = 454797, upload-time = "2025-08-08T13:13:52.066Z" }, + { url = "https://files.pythonhosted.org/packages/8c/c9/bb114c158540ee17907ec470d01980957fdaf87b4aa07914c24eba87b9c6/safetensors-0.6.2-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d6675cf4b39c98dbd7d940598028f3742e0375a6b4d4277e76beb0c35f4b843b", size = 432206, upload-time = "2025-08-08T13:13:50.931Z" }, + { url = "https://files.pythonhosted.org/packages/d3/8e/f70c34e47df3110e8e0bb268d90db8d4be8958a54ab0336c9be4fe86dac8/safetensors-0.6.2-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d2d2b3ce1e2509c68932ca03ab8f20570920cd9754b05063d4368ee52833ecd", size = 473261, upload-time = "2025-08-08T13:13:41.259Z" }, + { url = "https://files.pythonhosted.org/packages/2a/f5/be9c6a7c7ef773e1996dc214e73485286df1836dbd063e8085ee1976f9cb/safetensors-0.6.2-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:93de35a18f46b0f5a6a1f9e26d91b442094f2df02e9fd7acf224cfec4238821a", size = 485117, upload-time = "2025-08-08T13:13:43.506Z" }, + { url = 
"https://files.pythonhosted.org/packages/c9/55/23f2d0a2c96ed8665bf17a30ab4ce5270413f4d74b6d87dd663258b9af31/safetensors-0.6.2-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89a89b505f335640f9120fac65ddeb83e40f1fd081cb8ed88b505bdccec8d0a1", size = 616154, upload-time = "2025-08-08T13:13:45.096Z" }, + { url = "https://files.pythonhosted.org/packages/98/c6/affb0bd9ce02aa46e7acddbe087912a04d953d7a4d74b708c91b5806ef3f/safetensors-0.6.2-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fc4d0d0b937e04bdf2ae6f70cd3ad51328635fe0e6214aa1fc811f3b576b3bda", size = 520713, upload-time = "2025-08-08T13:13:46.25Z" }, + { url = "https://files.pythonhosted.org/packages/fe/5d/5a514d7b88e310c8b146e2404e0dc161282e78634d9358975fd56dfd14be/safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8045db2c872db8f4cbe3faa0495932d89c38c899c603f21e9b6486951a5ecb8f", size = 485835, upload-time = "2025-08-08T13:13:49.373Z" }, + { url = "https://files.pythonhosted.org/packages/7a/7b/4fc3b2ba62c352b2071bea9cfbad330fadda70579f617506ae1a2f129cab/safetensors-0.6.2-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:81e67e8bab9878bb568cffbc5f5e655adb38d2418351dc0859ccac158f753e19", size = 521503, upload-time = "2025-08-08T13:13:47.651Z" }, + { url = "https://files.pythonhosted.org/packages/5a/50/0057e11fe1f3cead9254315a6c106a16dd4b1a19cd247f7cc6414f6b7866/safetensors-0.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0e4d029ab0a0e0e4fdf142b194514695b1d7d3735503ba700cf36d0fc7136ce", size = 652256, upload-time = "2025-08-08T13:13:53.167Z" }, + { url = "https://files.pythonhosted.org/packages/e9/29/473f789e4ac242593ac1656fbece6e1ecd860bb289e635e963667807afe3/safetensors-0.6.2-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:fa48268185c52bfe8771e46325a1e21d317207bcabcb72e65c6e28e9ffeb29c7", size = 747281, upload-time = "2025-08-08T13:13:54.656Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/52/f7324aad7f2df99e05525c84d352dc217e0fa637a4f603e9f2eedfbe2c67/safetensors-0.6.2-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:d83c20c12c2d2f465997c51b7ecb00e407e5f94d7dec3ea0cc11d86f60d3fde5", size = 692286, upload-time = "2025-08-08T13:13:55.884Z" }, + { url = "https://files.pythonhosted.org/packages/ad/fe/cad1d9762868c7c5dc70c8620074df28ebb1a8e4c17d4c0cb031889c457e/safetensors-0.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d944cea65fad0ead848b6ec2c37cc0b197194bec228f8020054742190e9312ac", size = 655957, upload-time = "2025-08-08T13:13:57.029Z" }, + { url = "https://files.pythonhosted.org/packages/59/a7/e2158e17bbe57d104f0abbd95dff60dda916cf277c9f9663b4bf9bad8b6e/safetensors-0.6.2-cp38-abi3-win32.whl", hash = "sha256:cab75ca7c064d3911411461151cb69380c9225798a20e712b102edda2542ddb1", size = 308926, upload-time = "2025-08-08T13:14:01.095Z" }, + { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" }, +] + +[[package]] +name = "scipy" +version = "1.16.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, + { url = 
"https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, + { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, + { url = "https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, + { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 36021756, upload-time = "2025-10-28T17:34:13.482Z" }, + { url = 
"https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, + { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, + { url = "https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, + { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, + { url = "https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, + { url = 
"https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, + { url = "https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, + { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, + { url = 
"https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, +] + +[[package]] +name = "sentry-sdk" +version = "2.42.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/31/04/ec8c1dd9250847303d98516e917978cb1c7083024770d86d657d2ccb5a70/sentry_sdk-2.42.1.tar.gz", hash = "sha256:8598cc6edcfe74cb8074ba6a7c15338cdee93d63d3eb9b9943b4b568354ad5b6", size = 354839, upload-time = "2025-10-20T12:38:40.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/cb/c21b96ff379923310b4fb2c06e8d560d801e24aeb300faa72a04776868fc/sentry_sdk-2.42.1-py2.py3-none-any.whl", hash = "sha256:f8716b50c927d3beb41bc88439dc6bcd872237b596df5b14613e2ade104aee02", size = 380952, upload-time = "2025-10-20T12:38:38.88Z" }, +] + +[[package]] +name = "setuptools" +version = "80.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, +] + +[[package]] +name = "simple-stories-train" +version = "0.0.1" +source = { git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev#efa6175df794b322b5cd5b241e22ce47f117df32" } +dependencies = [ 
+ { name = "datasets" }, + { name = "fire" }, + { name = "ipykernel" }, + { name = "jaxtyping" }, + { name = "pydantic" }, + { name = "pytest" }, + { name = "python-dotenv" }, + { name = "tiktoken" }, + { name = "torch" }, + { name = "torchvision" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "wandb" }, +] + +[[package]] +name = "six" +version = "1.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/e7/b2c673351809dca68a0e064b6af791aa332cf192da575fd474ed7d6f16a2/six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81", size = 34031, upload-time = "2024-12-04T17:35:28.174Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, +] + +[[package]] +name = "smmap" +version = "5.0.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, +] + +[[package]] +name = "sniffio" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/87/a6771e1546d97e7e041b6ae58d80074f81b7d5121207425c964ddf5cfdbd/sniffio-1.3.1.tar.gz", hash = 
"sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc", size = 20372, upload-time = "2024-02-25T23:20:04.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, +] + +[[package]] +name = "spd" +version = "0.0.1" +source = { editable = "." } +dependencies = [ + { name = "datasets" }, + { name = "einops" }, + { name = "fastapi" }, + { name = "fire" }, + { name = "ipykernel" }, + { name = "jaxtyping" }, + { name = "matplotlib" }, + { name = "muutils" }, + { name = "numpy" }, + { name = "pydantic" }, + { name = "python-dotenv" }, + { name = "scipy" }, + { name = "simple-stories-train" }, + { name = "streamlit" }, + { name = "streamlit-antd-components" }, + { name = "sympy" }, + { name = "torch" }, + { name = "torchvision" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "uvicorn" }, + { name = "wandb" }, + { name = "wandb-workspaces" }, +] + +[package.dev-dependencies] +dev = [ + { name = "basedpyright" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "pytest-xdist" }, + { name = "ruff" }, +] + +[package.metadata] +requires-dist = [ + { name = "datasets", specifier = ">=2.21.0" }, + { name = "einops" }, + { name = "fastapi" }, + { name = "fire" }, + { name = "ipykernel" }, + { name = "jaxtyping" }, + { name = "matplotlib" }, + { name = "muutils" }, + { name = "numpy" }, + { name = "pydantic", specifier = "<2.12" }, + { name = "python-dotenv" }, + { name = "scipy", specifier = ">=1.14.1" }, + { name = "simple-stories-train", git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev" }, + { name = "streamlit" }, + { name = "streamlit-antd-components" }, + { name = "sympy" }, + { name = "torch", specifier = ">=2.6" }, + { name = 
"torchvision", specifier = ">=0.23,<0.24" }, + { name = "tqdm" }, + { name = "transformers" }, + { name = "uvicorn" }, + { name = "wandb", specifier = ">=0.20.1" }, + { name = "wandb-workspaces", specifier = "==0.1.12" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "basedpyright", specifier = "<1.32.0" }, + { name = "pre-commit" }, + { name = "pytest" }, + { name = "pytest-cov" }, + { name = "pytest-xdist" }, + { name = "ruff" }, +] + +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + +[[package]] +name = "starlette" +version = "0.49.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, 
+] + +[[package]] +name = "streamlit" +version = "1.50.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "altair" }, + { name = "blinker" }, + { name = "cachetools" }, + { name = "click" }, + { name = "gitpython" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, + { name = "pillow" }, + { name = "protobuf" }, + { name = "pyarrow" }, + { name = "pydeck" }, + { name = "requests" }, + { name = "tenacity" }, + { name = "toml" }, + { name = "tornado" }, + { name = "typing-extensions" }, + { name = "watchdog", marker = "sys_platform != 'darwin'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/f6/f7d3a0146577c1918439d3163707040f7111a7d2e7e2c73fa7adeb169c06/streamlit-1.50.0.tar.gz", hash = "sha256:87221d568aac585274a05ef18a378b03df332b93e08103fffcf3cd84d852af46", size = 9664808, upload-time = "2025-09-23T19:24:00.31Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/38/991bbf9fa3ed3d9c8e69265fc449bdaade8131c7f0f750dbd388c3c477dc/streamlit-1.50.0-py3-none-any.whl", hash = "sha256:9403b8f94c0a89f80cf679c2fcc803d9a6951e0fba542e7611995de3f67b4bb3", size = 10068477, upload-time = "2025-09-23T19:23:57.245Z" }, +] + +[[package]] +name = "streamlit-antd-components" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "streamlit" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/b5/dc/1ed6266b606e3b494b9af3e2c310ea6cbe2e967aa18873d03c5b267b4c81/streamlit_antd_components-0.3.2-py3-none-any.whl", hash = "sha256:5ae28496127202ed266ea167649436a15f3d548a4805ee5d992c6fc0fe103fd6", size = 2805502, upload-time = "2024-01-19T07:11:42.614Z" }, +] + +[[package]] +name = "sympy" +version = "1.14.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mpmath" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", 
hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a2/09/77d55d46fd61b4a135c444fc97158ef34a095e5681d0a6c10b75bf356191/sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5", size = 6299353, upload-time = "2025-04-27T18:04:59.103Z" }, +] + +[[package]] +name = "tenacity" +version = "9.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0a/d4/2b0cd0fe285e14b36db076e78c93766ff1d529d70408bd1d2a5a84f1d929/tenacity-9.1.2.tar.gz", hash = "sha256:1169d376c297e7de388d18b4481760d478b0e99a777cad3a9c86e556f4b697cb", size = 48036, upload-time = "2025-04-02T08:25:09.966Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/30/643397144bfbfec6f6ef821f36f33e57d35946c44a2352d3c9f0ae847619/tenacity-9.1.2-py3-none-any.whl", hash = "sha256:f77bf36710d8b73a50b2dd155c97b870017ad21afe6ab300326b0371b3b05138", size = 28248, upload-time = "2025-04-02T08:25:07.678Z" }, +] + +[[package]] +name = "termcolor" +version = "3.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/87/56/ab275c2b56a5e2342568838f0d5e3e66a32354adcc159b495e374cda43f5/termcolor-3.2.0.tar.gz", hash = "sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58", size = 14423, upload-time = "2025-10-25T19:11:42.586Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/d5/141f53d7c1eb2a80e6d3e9a390228c3222c27705cbe7f048d3623053f3ca/termcolor-3.2.0-py3-none-any.whl", hash = "sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6", size = 7698, upload-time = "2025-10-25T19:11:41.536Z" }, +] + +[[package]] +name = "tiktoken" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "regex" }, + { 
name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3", size = 1050802, upload-time = "2025-10-06T20:22:00.96Z" }, + { url = "https://files.pythonhosted.org/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160", size = 993995, upload-time = "2025-10-06T20:22:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa", size = 1128948, upload-time = "2025-10-06T20:22:03.814Z" }, + { url = "https://files.pythonhosted.org/packages/0c/da/c028aa0babf77315e1cef357d4d768800c5f8a6de04d0eac0f377cb619fa/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:4a1a4fcd021f022bfc81904a911d3df0f6543b9e7627b51411da75ff2fe7a1be", size = 1151986, upload-time = "2025-10-06T20:22:05.173Z" }, + { url = "https://files.pythonhosted.org/packages/a0/5a/886b108b766aa53e295f7216b509be95eb7d60b166049ce2c58416b25f2a/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:981a81e39812d57031efdc9ec59fa32b2a5a5524d20d4776574c4b4bd2e9014a", size = 1194222, upload-time = "2025-10-06T20:22:06.265Z" }, + { url = 
"https://files.pythonhosted.org/packages/f4/f8/4db272048397636ac7a078d22773dd2795b1becee7bc4922fe6207288d57/tiktoken-0.12.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9baf52f84a3f42eef3ff4e754a0db79a13a27921b457ca9832cf944c6be4f8f3", size = 1255097, upload-time = "2025-10-06T20:22:07.403Z" }, + { url = "https://files.pythonhosted.org/packages/8e/32/45d02e2e0ea2be3a9ed22afc47d93741247e75018aac967b713b2941f8ea/tiktoken-0.12.0-cp313-cp313-win_amd64.whl", hash = "sha256:b8a0cd0c789a61f31bf44851defbd609e8dd1e2c8589c614cc1060940ef1f697", size = 879117, upload-time = "2025-10-06T20:22:08.418Z" }, + { url = "https://files.pythonhosted.org/packages/ce/76/994fc868f88e016e6d05b0da5ac24582a14c47893f4474c3e9744283f1d5/tiktoken-0.12.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d5f89ea5680066b68bcb797ae85219c72916c922ef0fcdd3480c7d2315ffff16", size = 1050309, upload-time = "2025-10-06T20:22:10.939Z" }, + { url = "https://files.pythonhosted.org/packages/f6/b8/57ef1456504c43a849821920d582a738a461b76a047f352f18c0b26c6516/tiktoken-0.12.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b4e7ed1c6a7a8a60a3230965bdedba8cc58f68926b835e519341413370e0399a", size = 993712, upload-time = "2025-10-06T20:22:12.115Z" }, + { url = "https://files.pythonhosted.org/packages/72/90/13da56f664286ffbae9dbcfadcc625439142675845baa62715e49b87b68b/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:fc530a28591a2d74bce821d10b418b26a094bf33839e69042a6e86ddb7a7fb27", size = 1128725, upload-time = "2025-10-06T20:22:13.541Z" }, + { url = "https://files.pythonhosted.org/packages/05/df/4f80030d44682235bdaecd7346c90f67ae87ec8f3df4a3442cb53834f7e4/tiktoken-0.12.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:06a9f4f49884139013b138920a4c393aa6556b2f8f536345f11819389c703ebb", size = 1151875, upload-time = "2025-10-06T20:22:14.559Z" }, + { url = 
"https://files.pythonhosted.org/packages/22/1f/ae535223a8c4ef4c0c1192e3f9b82da660be9eb66b9279e95c99288e9dab/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:04f0e6a985d95913cabc96a741c5ffec525a2c72e9df086ff17ebe35985c800e", size = 1194451, upload-time = "2025-10-06T20:22:15.545Z" }, + { url = "https://files.pythonhosted.org/packages/78/a7/f8ead382fce0243cb625c4f266e66c27f65ae65ee9e77f59ea1653b6d730/tiktoken-0.12.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:0ee8f9ae00c41770b5f9b0bb1235474768884ae157de3beb5439ca0fd70f3e25", size = 1253794, upload-time = "2025-10-06T20:22:16.624Z" }, + { url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, +] + +[[package]] +name = "tokenizers" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/33/f4b2d94ada7ab297328fc671fed209368ddb82f965ec2224eb1892674c3a/tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73", size = 3069318, upload-time = "2025-09-19T09:49:11.848Z" }, + { url = "https://files.pythonhosted.org/packages/1c/58/2aa8c874d02b974990e89ff95826a4852a8b2a273c7d1b4411cdd45a4565/tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc", size = 2926478, upload-time = 
"2025-09-19T09:49:09.759Z" }, + { url = "https://files.pythonhosted.org/packages/1e/3b/55e64befa1e7bfea963cf4b787b2cea1011362c4193f5477047532ce127e/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a", size = 3256994, upload-time = "2025-09-19T09:48:56.701Z" }, + { url = "https://files.pythonhosted.org/packages/71/0b/fbfecf42f67d9b7b80fde4aabb2b3110a97fac6585c9470b5bff103a80cb/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7", size = 3153141, upload-time = "2025-09-19T09:48:59.749Z" }, + { url = "https://files.pythonhosted.org/packages/17/a9/b38f4e74e0817af8f8ef925507c63c6ae8171e3c4cb2d5d4624bf58fca69/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21", size = 3508049, upload-time = "2025-09-19T09:49:05.868Z" }, + { url = "https://files.pythonhosted.org/packages/d2/48/dd2b3dac46bb9134a88e35d72e1aa4869579eacc1a27238f1577270773ff/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214", size = 3710730, upload-time = "2025-09-19T09:49:01.832Z" }, + { url = "https://files.pythonhosted.org/packages/93/0e/ccabc8d16ae4ba84a55d41345207c1e2ea88784651a5a487547d80851398/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f", size = 3412560, upload-time = "2025-09-19T09:49:03.867Z" }, + { url = "https://files.pythonhosted.org/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4", size = 
3250221, upload-time = "2025-09-19T09:49:07.664Z" }, + { url = "https://files.pythonhosted.org/packages/d7/a6/2c8486eef79671601ff57b093889a345dd3d576713ef047776015dc66de7/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879", size = 9345569, upload-time = "2025-09-19T09:49:14.214Z" }, + { url = "https://files.pythonhosted.org/packages/6b/16/32ce667f14c35537f5f605fe9bea3e415ea1b0a646389d2295ec348d5657/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446", size = 9271599, upload-time = "2025-09-19T09:49:16.639Z" }, + { url = "https://files.pythonhosted.org/packages/51/7c/a5f7898a3f6baa3fc2685c705e04c98c1094c523051c805cdd9306b8f87e/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a", size = 9533862, upload-time = "2025-09-19T09:49:19.146Z" }, + { url = "https://files.pythonhosted.org/packages/36/65/7e75caea90bc73c1dd8d40438adf1a7bc26af3b8d0a6705ea190462506e1/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390", size = 9681250, upload-time = "2025-09-19T09:49:21.501Z" }, + { url = "https://files.pythonhosted.org/packages/30/2c/959dddef581b46e6209da82df3b78471e96260e2bc463f89d23b1bf0e52a/tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82", size = 2472003, upload-time = "2025-09-19T09:49:27.089Z" }, + { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, +] + +[[package]] +name = "toml" +version = "0.10.2" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/be/ba/1f744cdc819428fc6b5084ec34d9b30660f6f9daaf70eead706e3203ec3c/toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f", size = 22253, upload-time = "2020-11-01T01:40:22.204Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, +] + +[[package]] +name = "torch" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "fsspec" }, + { name = "jinja2" }, + { name = "networkx" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufile-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparselt-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, 
+ { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools" }, + { name = "sympy" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/4e/469ced5a0603245d6a19a556e9053300033f9c5baccf43a3d25ba73e189e/torch-2.8.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2b2f96814e0345f5a5aed9bf9734efa913678ed19caf6dc2cddb7930672d6128", size = 101936856, upload-time = "2025-08-06T14:54:01.526Z" }, + { url = "https://files.pythonhosted.org/packages/16/82/3948e54c01b2109238357c6f86242e6ecbf0c63a1af46906772902f82057/torch-2.8.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:65616ca8ec6f43245e1f5f296603e33923f4c30f93d65e103d9e50c25b35150b", size = 887922844, upload-time = "2025-08-06T14:55:50.78Z" }, + { url = "https://files.pythonhosted.org/packages/e3/54/941ea0a860f2717d86a811adf0c2cd01b3983bdd460d0803053c4e0b8649/torch-2.8.0-cp313-cp313-win_amd64.whl", hash = "sha256:659df54119ae03e83a800addc125856effda88b016dfc54d9f65215c3975be16", size = 241330968, upload-time = "2025-08-06T14:54:45.293Z" }, + { url = "https://files.pythonhosted.org/packages/de/69/8b7b13bba430f5e21d77708b616f767683629fc4f8037564a177d20f90ed/torch-2.8.0-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:1a62a1ec4b0498930e2543535cf70b1bef8c777713de7ceb84cd79115f553767", size = 73915128, upload-time = "2025-08-06T14:54:34.769Z" }, + { url = "https://files.pythonhosted.org/packages/15/0e/8a800e093b7f7430dbaefa80075aee9158ec22e4c4fc3c1a66e4fb96cb4f/torch-2.8.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:83c13411a26fac3d101fe8035a6b0476ae606deb8688e904e796a3534c197def", size 
= 102020139, upload-time = "2025-08-06T14:54:39.047Z" }, + { url = "https://files.pythonhosted.org/packages/4a/15/5e488ca0bc6162c86a33b58642bc577c84ded17c7b72d97e49b5833e2d73/torch-2.8.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:8f0a9d617a66509ded240add3754e462430a6c1fc5589f86c17b433dd808f97a", size = 887990692, upload-time = "2025-08-06T14:56:18.286Z" }, + { url = "https://files.pythonhosted.org/packages/b4/a8/6a04e4b54472fc5dba7ca2341ab219e529f3c07b6941059fbf18dccac31f/torch-2.8.0-cp313-cp313t-win_amd64.whl", hash = "sha256:a7242b86f42be98ac674b88a4988643b9bc6145437ec8f048fea23f72feb5eca", size = 241603453, upload-time = "2025-08-06T14:55:22.945Z" }, + { url = "https://files.pythonhosted.org/packages/04/6e/650bb7f28f771af0cb791b02348db8b7f5f64f40f6829ee82aa6ce99aabe/torch-2.8.0-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7b677e17f5a3e69fdef7eb3b9da72622f8d322692930297e4ccb52fefc6c8211", size = 73632395, upload-time = "2025-08-06T14:55:28.645Z" }, +] + +[[package]] +name = "torchvision" +version = "0.23.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, + { name = "torch" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/37/45a5b9407a7900f71d61b2b2f62db4b7c632debca397f205fdcacb502780/torchvision-0.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1c37e325e09a184b730c3ef51424f383ec5745378dc0eca244520aca29722600", size = 1856886, upload-time = "2025-08-06T14:58:05.491Z" }, + { url = "https://files.pythonhosted.org/packages/ac/da/a06c60fc84fc849377cf035d3b3e9a1c896d52dbad493b963c0f1cdd74d0/torchvision-0.23.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:2f7fd6c15f3697e80627b77934f77705f3bc0e98278b989b2655de01f6903e1d", size = 2353112, upload-time = "2025-08-06T14:58:26.265Z" }, + { url = 
"https://files.pythonhosted.org/packages/a0/27/5ce65ba5c9d3b7d2ccdd79892ab86a2f87ac2ca6638f04bb0280321f1a9c/torchvision-0.23.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a76fafe113b2977be3a21bf78f115438c1f88631d7a87203acb3dd6ae55889e6", size = 8627658, upload-time = "2025-08-06T14:58:15.999Z" }, + { url = "https://files.pythonhosted.org/packages/1f/e4/028a27b60aa578a2fa99d9d7334ff1871bb17008693ea055a2fdee96da0d/torchvision-0.23.0-cp313-cp313-win_amd64.whl", hash = "sha256:07d069cb29691ff566e3b7f11f20d91044f079e1dbdc9d72e0655899a9b06938", size = 1600749, upload-time = "2025-08-06T14:58:10.719Z" }, + { url = "https://files.pythonhosted.org/packages/05/35/72f91ad9ac7c19a849dedf083d347dc1123f0adeb401f53974f84f1d04c8/torchvision-0.23.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:2df618e1143805a7673aaf82cb5720dd9112d4e771983156aaf2ffff692eebf9", size = 2047192, upload-time = "2025-08-06T14:58:11.813Z" }, + { url = "https://files.pythonhosted.org/packages/1d/9d/406cea60a9eb9882145bcd62a184ee61e823e8e1d550cdc3c3ea866a9445/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:2a3299d2b1d5a7aed2d3b6ffb69c672ca8830671967eb1cee1497bacd82fe47b", size = 2359295, upload-time = "2025-08-06T14:58:17.469Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f4/34662f71a70fa1e59de99772142f22257ca750de05ccb400b8d2e3809c1d/torchvision-0.23.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:76bc4c0b63d5114aa81281390f8472a12a6a35ce9906e67ea6044e5af4cab60c", size = 8800474, upload-time = "2025-08-06T14:58:22.53Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f5/b5a2d841a8d228b5dbda6d524704408e19e7ca6b7bb0f24490e081da1fa1/torchvision-0.23.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b9e2dabf0da9c8aa9ea241afb63a8f3e98489e706b22ac3f30416a1be377153b", size = 1527667, upload-time = "2025-08-06T14:58:14.446Z" }, +] + +[[package]] +name = "tornado" +version = "6.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { 
url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" }, + { url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" }, + { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" }, + { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" }, + { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" }, + { url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" }, + { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" }, + { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" }, + { url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" }, +] + +[[package]] +name = "tqdm" +version = "4.67.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] 
+sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, +] + +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + +[[package]] +name = "transformers" +version = "4.57.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "huggingface-hub" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "regex" }, + { name = "requests" }, + { name = "safetensors" }, + { name = "tokenizers" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/68/a39307bcc4116a30b2106f2e689130a48de8bd8a1e635b5e1030e46fcd9e/transformers-4.57.1.tar.gz", hash = "sha256:f06c837959196c75039809636cd964b959f6604b75b8eeec6fdfc0440b89cc55", size = 10142511, upload-time = "2025-10-14T15:39:26.18Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/71/d3/c16c3b3cf7655a67db1144da94b021c200ac1303f82428f2beef6c2e72bb/transformers-4.57.1-py3-none-any.whl", hash = "sha256:b10d05da8fa67dc41644dbbf9bc45a44cb86ae33da6f9295f5fbf5b7890bd267", size = 11990925, upload-time = "2025-10-14T15:39:23.085Z" }, +] + +[[package]] +name = "triton" +version = "3.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "setuptools" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, + { url = "https://files.pythonhosted.org/packages/20/63/8cb444ad5cdb25d999b7d647abac25af0ee37d292afc009940c05b82dda0/triton-3.4.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7936b18a3499ed62059414d7df563e6c163c5e16c3773678a3ee3d417865035d", size = 155659780, upload-time = "2025-07-30T19:58:51.171Z" }, +] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { 
name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, +] + +[[package]] +name = "tzdata" +version = "2025.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, +] + +[[package]] +name = "urllib3" +version = "2.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/22/9ee70a2574a4f4599c47dd506532914ce044817c7752a79b6a51286319bc/urllib3-2.5.0.tar.gz", hash = "sha256:3fc47733c7e419d4bc3f6b3dc2b4f890bb743906a30d56ba4a5bfa4bbff92760", size = 393185, upload-time = "2025-06-18T14:07:41.644Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, 
upload-time = "2025-06-18T14:07:40.39Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/f06b84e2697fef4688ca63bdb2fdf113ca0a3be33f94488f2cadb690b0cf/uvicorn-0.38.0.tar.gz", hash = "sha256:fd97093bdd120a2609fc0d3afe931d4d4ad688b6e75f0f929fde1bc36fe0e91d", size = 80605, upload-time = "2025-10-18T13:46:44.63Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" }, +] + +[[package]] +name = "virtualenv" +version = "20.35.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = "sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" }, +] + +[[package]] +name = "wadler-lindig" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/67/cbae4bf7683a64755c2c1778c418fea96d00e34395bb91743f08bd951571/wadler_lindig-0.1.7.tar.gz", hash = "sha256:81d14d3fe77d441acf3ebd7f4aefac20c74128bf460e84b512806dccf7b2cd55", size = 15842, upload-time = 
"2025-06-18T07:00:42.843Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/96/04e7b441807b26b794da5b11e59ed7f83b2cf8af202bd7eba8ad2fa6046e/wadler_lindig-0.1.7-py3-none-any.whl", hash = "sha256:e3ec83835570fd0a9509f969162aeb9c65618f998b1f42918cfc8d45122fe953", size = 20516, upload-time = "2025-06-18T07:00:41.684Z" }, +] + +[[package]] +name = "wandb" +version = "0.22.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c1/d1/6b70f365ed86bd69debba8ad55dec8606fc21006e7ca703a5a091bd3b719/wandb-0.22.3.tar.gz", hash = "sha256:04468a8ab2769a46f5e384c9c4ada5da0dced005ca689a8424e4b8b5cb2a0291", size = 44337368, upload-time = "2025-10-28T23:59:10.275Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/23/02/87fb60f587ec249f784a40bd91c30de1b2b24d691ee72675d5b66c3d0728/wandb-0.22.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:81b3b6e405f38342b0a080898b7d00c5b9375432f5ba358942a09e65cdcfe781", size = 18758047, upload-time = "2025-10-28T23:58:46.56Z" }, + { url = "https://files.pythonhosted.org/packages/26/88/64081740ef2b2efc7fbcb2139a07a849e42bcb09ae0c56ae50c41bd0ad63/wandb-0.22.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:d29c16817cca6401b4919069ec7570c781eacb67dc0b1ff2e0096a9a59581720", size = 19798011, upload-time = "2025-10-28T23:58:49.718Z" }, + { url = "https://files.pythonhosted.org/packages/19/72/c4f922b33dbb84d1c81ee045ff8791dd14e26d79e1e9bbafff964b7043e2/wandb-0.22.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb955d73a4ba55df9adc61fafbabef5556784d33fc39c7b5c8165d2694ddeb3b", size = 18542713, upload-time = "2025-10-28T23:58:51.927Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/98/3ce5f6e2086d91b0c51b38ae7ff591109e7da2bb25fe1a12eec0cdbaa494/wandb-0.22.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23f3ebe41a26506117a098fdfd2706ed0e50b37899bfbefe3a0628fcbd70c69d", size = 19984910, upload-time = "2025-10-28T23:58:54.641Z" }, + { url = "https://files.pythonhosted.org/packages/5e/57/e68cb38427b60490d6ddf1b992e6c7f36be83be1079d291ce87a8d347f48/wandb-0.22.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2973462bed5d4a653b1a97cf9fc350673bb200fb356a2f4eba34beae9b87e0aa", size = 18581776, upload-time = "2025-10-28T23:58:56.975Z" }, + { url = "https://files.pythonhosted.org/packages/66/6d/543f907ce0c6b6da13628b23d19ca7282c559fd73eb47b04977b9a61d0c6/wandb-0.22.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c5c2bd18f95c1639863c527da0a5818ac6b0e5194f9c691426b265908ddd8b2c", size = 20078800, upload-time = "2025-10-28T23:58:59.217Z" }, + { url = "https://files.pythonhosted.org/packages/da/91/1decaf1a6ac2017481c782e0fad7f90bc9ae4057f3d76d478cb6527f3dd3/wandb-0.22.3-py3-none-win32.whl", hash = "sha256:09ca1edfe0fd6dc30447d368acddb825668e60ee705c98594a6bbfd30d34d47e", size = 19160297, upload-time = "2025-10-28T23:59:01.536Z" }, + { url = "https://files.pythonhosted.org/packages/4c/ba/3b092634279994b0c79fe05220532822be09f3a353ae95c54e7142769db8/wandb-0.22.3-py3-none-win_amd64.whl", hash = "sha256:55403bf93872c9978433d101324f51e43e78c70c809bf6d06ca7b2760e39f497", size = 19160300, upload-time = "2025-10-28T23:59:04.06Z" }, + { url = "https://files.pythonhosted.org/packages/7f/80/4662fce9eebcc8c71f5083e9152ccaf7d43d4ca9c446e1422f9aa784a51c/wandb-0.22.3-py3-none-win_arm64.whl", hash = "sha256:49f66b05882abfa53816cc8d01b3c2435a89c5a090176802fa6928b5979d34d9", size = 17461959, upload-time = "2025-10-28T23:59:07.059Z" }, +] + +[[package]] +name = "wandb-workspaces" +version = "0.1.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { 
name = "wandb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ea/2d/ca96d36a7f8a416e1a962ee96d22922accd742ecc9074073217e7d0feede/wandb_workspaces-0.1.12.tar.gz", hash = "sha256:1e983afd4a758319a8b3acb83b43067b61bebc3db61e2dc09f087ea4e5434468", size = 73555, upload-time = "2025-02-17T16:11:26.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/f0/9811ad9c52bdb33e36e61e162277fa1b9a60f143b4d534731816c0438a7d/wandb_workspaces-0.1.12-py3-none-any.whl", hash = "sha256:a59767794ab13ea3d5cbbc7cef5ca55c370497c7c81834b30ab6005ebe318e1c", size = 84034, upload-time = "2025-02-17T16:11:22.724Z" }, +] + +[[package]] +name = "watchdog" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/db/7d/7f3d619e951c88ed75c6037b246ddcf2d322812ee8ea189be89511721d54/watchdog-6.0.0.tar.gz", hash = "sha256:9ddf7c82fda3ae8e24decda1338ede66e1c99883db93711d8fb941eaa2d8c282", size = 131220, upload-time = "2024-11-01T14:07:13.037Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/c7/ca4bf3e518cb57a686b2feb4f55a1892fd9a3dd13f470fca14e00f80ea36/watchdog-6.0.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:7607498efa04a3542ae3e05e64da8202e58159aa1fa4acddf7678d34a35d4f13", size = 79079, upload-time = "2024-11-01T14:06:59.472Z" }, + { url = "https://files.pythonhosted.org/packages/5c/51/d46dc9332f9a647593c947b4b88e2381c8dfc0942d15b8edc0310fa4abb1/watchdog-6.0.0-py3-none-manylinux2014_armv7l.whl", hash = "sha256:9041567ee8953024c83343288ccc458fd0a2d811d6a0fd68c4c22609e3490379", size = 79078, upload-time = "2024-11-01T14:07:01.431Z" }, + { url = "https://files.pythonhosted.org/packages/d4/57/04edbf5e169cd318d5f07b4766fee38e825d64b6913ca157ca32d1a42267/watchdog-6.0.0-py3-none-manylinux2014_i686.whl", hash = "sha256:82dc3e3143c7e38ec49d61af98d6558288c415eac98486a5c581726e0737c00e", size = 79076, upload-time = "2024-11-01T14:07:02.568Z" }, + { url = 
"https://files.pythonhosted.org/packages/ab/cc/da8422b300e13cb187d2203f20b9253e91058aaf7db65b74142013478e66/watchdog-6.0.0-py3-none-manylinux2014_ppc64.whl", hash = "sha256:212ac9b8bf1161dc91bd09c048048a95ca3a4c4f5e5d4a7d1b1a7d5752a7f96f", size = 79077, upload-time = "2024-11-01T14:07:03.893Z" }, + { url = "https://files.pythonhosted.org/packages/2c/3b/b8964e04ae1a025c44ba8e4291f86e97fac443bca31de8bd98d3263d2fcf/watchdog-6.0.0-py3-none-manylinux2014_ppc64le.whl", hash = "sha256:e3df4cbb9a450c6d49318f6d14f4bbc80d763fa587ba46ec86f99f9e6876bb26", size = 79078, upload-time = "2024-11-01T14:07:05.189Z" }, + { url = "https://files.pythonhosted.org/packages/62/ae/a696eb424bedff7407801c257d4b1afda455fe40821a2be430e173660e81/watchdog-6.0.0-py3-none-manylinux2014_s390x.whl", hash = "sha256:2cce7cfc2008eb51feb6aab51251fd79b85d9894e98ba847408f662b3395ca3c", size = 79077, upload-time = "2024-11-01T14:07:06.376Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e8/dbf020b4d98251a9860752a094d09a65e1b436ad181faf929983f697048f/watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:20ffe5b202af80ab4266dcd3e91aae72bf2da48c0d33bdb15c66658e685e94e2", size = 79078, upload-time = "2024-11-01T14:07:07.547Z" }, + { url = "https://files.pythonhosted.org/packages/07/f6/d0e5b343768e8bcb4cda79f0f2f55051bf26177ecd5651f84c07567461cf/watchdog-6.0.0-py3-none-win32.whl", hash = "sha256:07df1fdd701c5d4c8e55ef6cf55b8f0120fe1aef7ef39a1c6fc6bc2e606d517a", size = 79065, upload-time = "2024-11-01T14:07:09.525Z" }, + { url = "https://files.pythonhosted.org/packages/db/d9/c495884c6e548fce18a8f40568ff120bc3a4b7b99813081c8ac0c936fa64/watchdog-6.0.0-py3-none-win_amd64.whl", hash = "sha256:cbafb470cf848d93b5d013e2ecb245d4aa1c8fd0504e863ccefa32445359d680", size = 79070, upload-time = "2024-11-01T14:07:10.686Z" }, + { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = 
"sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, +] + +[[package]] +name = "wcwidth" +version = "0.2.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, +] + +[[package]] +name = "xxhash" +version = "3.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/84/30869e01909fb37a6cc7e18688ee8bf1e42d57e7e0777636bd47524c43c7/xxhash-3.6.0.tar.gz", hash = "sha256:f0162a78b13a0d7617b2845b90c763339d1f1d82bb04a4b07f4ab535cc5e05d6", size = 85160, upload-time = "2025-10-02T14:37:08.097Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/76/35d05267ac82f53ae9b0e554da7c5e281ee61f3cad44c743f0fcd354f211/xxhash-3.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:599e64ba7f67472481ceb6ee80fa3bd828fd61ba59fb11475572cc5ee52b89ec", size = 32738, upload-time = "2025-10-02T14:34:55.839Z" }, + { url = "https://files.pythonhosted.org/packages/31/a8/3fbce1cd96534a95e35d5120637bf29b0d7f5d8fa2f6374e31b4156dd419/xxhash-3.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7d8b8aaa30fca4f16f0c84a5c8d7ddee0e25250ec2796c973775373257dde8f1", size = 30821, upload-time = "2025-10-02T14:34:57.219Z" }, + { url = 
"https://files.pythonhosted.org/packages/0c/ea/d387530ca7ecfa183cb358027f1833297c6ac6098223fd14f9782cd0015c/xxhash-3.6.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d597acf8506d6e7101a4a44a5e428977a51c0fadbbfd3c39650cca9253f6e5a6", size = 194127, upload-time = "2025-10-02T14:34:59.21Z" }, + { url = "https://files.pythonhosted.org/packages/ba/0c/71435dcb99874b09a43b8d7c54071e600a7481e42b3e3ce1eb5226a5711a/xxhash-3.6.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:858dc935963a33bc33490128edc1c12b0c14d9c7ebaa4e387a7869ecc4f3e263", size = 212975, upload-time = "2025-10-02T14:35:00.816Z" }, + { url = "https://files.pythonhosted.org/packages/84/7a/c2b3d071e4bb4a90b7057228a99b10d51744878f4a8a6dd643c8bd897620/xxhash-3.6.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba284920194615cb8edf73bf52236ce2e1664ccd4a38fdb543506413529cc546", size = 212241, upload-time = "2025-10-02T14:35:02.207Z" }, + { url = "https://files.pythonhosted.org/packages/81/5f/640b6eac0128e215f177df99eadcd0f1b7c42c274ab6a394a05059694c5a/xxhash-3.6.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4b54219177f6c6674d5378bd862c6aedf64725f70dd29c472eaae154df1a2e89", size = 445471, upload-time = "2025-10-02T14:35:03.61Z" }, + { url = "https://files.pythonhosted.org/packages/5e/1e/3c3d3ef071b051cc3abbe3721ffb8365033a172613c04af2da89d5548a87/xxhash-3.6.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:42c36dd7dbad2f5238950c377fcbf6811b1cdb1c444fab447960030cea60504d", size = 193936, upload-time = "2025-10-02T14:35:05.013Z" }, + { url = "https://files.pythonhosted.org/packages/2c/bd/4a5f68381939219abfe1c22a9e3a5854a4f6f6f3c4983a87d255f21f2e5d/xxhash-3.6.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = 
"sha256:f22927652cba98c44639ffdc7aaf35828dccf679b10b31c4ad72a5b530a18eb7", size = 210440, upload-time = "2025-10-02T14:35:06.239Z" }, + { url = "https://files.pythonhosted.org/packages/eb/37/b80fe3d5cfb9faff01a02121a0f4d565eb7237e9e5fc66e73017e74dcd36/xxhash-3.6.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b45fad44d9c5c119e9c6fbf2e1c656a46dc68e280275007bbfd3d572b21426db", size = 197990, upload-time = "2025-10-02T14:35:07.735Z" }, + { url = "https://files.pythonhosted.org/packages/d7/fd/2c0a00c97b9e18f72e1f240ad4e8f8a90fd9d408289ba9c7c495ed7dc05c/xxhash-3.6.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:6f2580ffab1a8b68ef2b901cde7e55fa8da5e4be0977c68f78fc80f3c143de42", size = 210689, upload-time = "2025-10-02T14:35:09.438Z" }, + { url = "https://files.pythonhosted.org/packages/93/86/5dd8076a926b9a95db3206aba20d89a7fc14dd5aac16e5c4de4b56033140/xxhash-3.6.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:40c391dd3cd041ebc3ffe6f2c862f402e306eb571422e0aa918d8070ba31da11", size = 414068, upload-time = "2025-10-02T14:35:11.162Z" }, + { url = "https://files.pythonhosted.org/packages/af/3c/0bb129170ee8f3650f08e993baee550a09593462a5cddd8e44d0011102b1/xxhash-3.6.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f205badabde7aafd1a31e8ca2a3e5a763107a71c397c4481d6a804eb5063d8bd", size = 191495, upload-time = "2025-10-02T14:35:12.971Z" }, + { url = "https://files.pythonhosted.org/packages/e9/3a/6797e0114c21d1725e2577508e24006fd7ff1d8c0c502d3b52e45c1771d8/xxhash-3.6.0-cp313-cp313-win32.whl", hash = "sha256:2577b276e060b73b73a53042ea5bd5203d3e6347ce0d09f98500f418a9fcf799", size = 30620, upload-time = "2025-10-02T14:35:14.129Z" }, + { url = "https://files.pythonhosted.org/packages/86/15/9bc32671e9a38b413a76d24722a2bf8784a132c043063a8f5152d390b0f9/xxhash-3.6.0-cp313-cp313-win_amd64.whl", hash = "sha256:757320d45d2fbcce8f30c42a6b2f47862967aea7bf458b9625b4bbe7ee390392", size = 31542, upload-time = "2025-10-02T14:35:15.21Z" }, + { url = 
"https://files.pythonhosted.org/packages/39/c5/cc01e4f6188656e56112d6a8e0dfe298a16934b8c47a247236549a3f7695/xxhash-3.6.0-cp313-cp313-win_arm64.whl", hash = "sha256:457b8f85dec5825eed7b69c11ae86834a018b8e3df5e77783c999663da2f96d6", size = 27880, upload-time = "2025-10-02T14:35:16.315Z" }, + { url = "https://files.pythonhosted.org/packages/f3/30/25e5321c8732759e930c555176d37e24ab84365482d257c3b16362235212/xxhash-3.6.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:a42e633d75cdad6d625434e3468126c73f13f7584545a9cf34e883aa1710e702", size = 32956, upload-time = "2025-10-02T14:35:17.413Z" }, + { url = "https://files.pythonhosted.org/packages/9f/3c/0573299560d7d9f8ab1838f1efc021a280b5ae5ae2e849034ef3dee18810/xxhash-3.6.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:568a6d743219e717b07b4e03b0a828ce593833e498c3b64752e0f5df6bfe84db", size = 31072, upload-time = "2025-10-02T14:35:18.844Z" }, + { url = "https://files.pythonhosted.org/packages/7a/1c/52d83a06e417cd9d4137722693424885cc9878249beb3a7c829e74bf7ce9/xxhash-3.6.0-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:bec91b562d8012dae276af8025a55811b875baace6af510412a5e58e3121bc54", size = 196409, upload-time = "2025-10-02T14:35:20.31Z" }, + { url = "https://files.pythonhosted.org/packages/e3/8e/c6d158d12a79bbd0b878f8355432075fc82759e356ab5a111463422a239b/xxhash-3.6.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:78e7f2f4c521c30ad5e786fdd6bae89d47a32672a80195467b5de0480aa97b1f", size = 215736, upload-time = "2025-10-02T14:35:21.616Z" }, + { url = "https://files.pythonhosted.org/packages/bc/68/c4c80614716345d55071a396cf03d06e34b5f4917a467faf43083c995155/xxhash-3.6.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3ed0df1b11a79856df5ffcab572cbd6b9627034c1c748c5566fa79df9048a7c5", size = 214833, upload-time = "2025-10-02T14:35:23.32Z" }, + { url = 
"https://files.pythonhosted.org/packages/7e/e9/ae27c8ffec8b953efa84c7c4a6c6802c263d587b9fc0d6e7cea64e08c3af/xxhash-3.6.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0e4edbfc7d420925b0dd5e792478ed393d6e75ff8fc219a6546fb446b6a417b1", size = 448348, upload-time = "2025-10-02T14:35:25.111Z" }, + { url = "https://files.pythonhosted.org/packages/d7/6b/33e21afb1b5b3f46b74b6bd1913639066af218d704cc0941404ca717fc57/xxhash-3.6.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fba27a198363a7ef87f8c0f6b171ec36b674fe9053742c58dd7e3201c1ab30ee", size = 196070, upload-time = "2025-10-02T14:35:26.586Z" }, + { url = "https://files.pythonhosted.org/packages/96/b6/fcabd337bc5fa624e7203aa0fa7d0c49eed22f72e93229431752bddc83d9/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:794fe9145fe60191c6532fa95063765529770edcdd67b3d537793e8004cabbfd", size = 212907, upload-time = "2025-10-02T14:35:28.087Z" }, + { url = "https://files.pythonhosted.org/packages/4b/d3/9ee6160e644d660fcf176c5825e61411c7f62648728f69c79ba237250143/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:6105ef7e62b5ac73a837778efc331a591d8442f8ef5c7e102376506cb4ae2729", size = 200839, upload-time = "2025-10-02T14:35:29.857Z" }, + { url = "https://files.pythonhosted.org/packages/0d/98/e8de5baa5109394baf5118f5e72ab21a86387c4f89b0e77ef3e2f6b0327b/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:f01375c0e55395b814a679b3eea205db7919ac2af213f4a6682e01220e5fe292", size = 213304, upload-time = "2025-10-02T14:35:31.222Z" }, + { url = "https://files.pythonhosted.org/packages/7b/1d/71056535dec5c3177eeb53e38e3d367dd1d16e024e63b1cee208d572a033/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:d706dca2d24d834a4661619dcacf51a75c16d65985718d6a7d73c1eeeb903ddf", size = 416930, upload-time = "2025-10-02T14:35:32.517Z" }, + { url = 
"https://files.pythonhosted.org/packages/dc/6c/5cbde9de2cd967c322e651c65c543700b19e7ae3e0aae8ece3469bf9683d/xxhash-3.6.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5f059d9faeacd49c0215d66f4056e1326c80503f51a1532ca336a385edadd033", size = 193787, upload-time = "2025-10-02T14:35:33.827Z" }, + { url = "https://files.pythonhosted.org/packages/19/fa/0172e350361d61febcea941b0cc541d6e6c8d65d153e85f850a7b256ff8a/xxhash-3.6.0-cp313-cp313t-win32.whl", hash = "sha256:1244460adc3a9be84731d72b8e80625788e5815b68da3da8b83f78115a40a7ec", size = 30916, upload-time = "2025-10-02T14:35:35.107Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e6/e8cf858a2b19d6d45820f072eff1bea413910592ff17157cabc5f1227a16/xxhash-3.6.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b1e420ef35c503869c4064f4a2f2b08ad6431ab7b229a05cce39d74268bca6b8", size = 31799, upload-time = "2025-10-02T14:35:36.165Z" }, + { url = "https://files.pythonhosted.org/packages/56/15/064b197e855bfb7b343210e82490ae672f8bc7cdf3ddb02e92f64304ee8a/xxhash-3.6.0-cp313-cp313t-win_arm64.whl", hash = "sha256:ec44b73a4220623235f67a996c862049f375df3b1052d9899f40a6382c32d746", size = 28044, upload-time = "2025-10-02T14:35:37.195Z" }, +] + +[[package]] +name = "yarl" +version = "1.22.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "multidict" }, + { name = "propcache" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/57/63/0c6ebca57330cd313f6102b16dd57ffaf3ec4c83403dcb45dbd15c6f3ea1/yarl-1.22.0.tar.gz", hash = "sha256:bebf8557577d4401ba8bd9ff33906f1376c877aa78d1fe216ad01b4d6745af71", size = 187169, upload-time = "2025-10-06T14:12:55.963Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/f3/d67de7260456ee105dc1d162d43a019ecad6b91e2f51809d6cddaa56690e/yarl-1.22.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8dee9c25c74997f6a750cd317b8ca63545169c098faee42c84aa5e506c819b53", size = 139980, upload-time = 
"2025-10-06T14:10:14.601Z" }, + { url = "https://files.pythonhosted.org/packages/01/88/04d98af0b47e0ef42597b9b28863b9060bb515524da0a65d5f4db160b2d5/yarl-1.22.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:01e73b85a5434f89fc4fe27dcda2aff08ddf35e4d47bbbea3bdcd25321af538a", size = 93424, upload-time = "2025-10-06T14:10:16.115Z" }, + { url = "https://files.pythonhosted.org/packages/18/91/3274b215fd8442a03975ce6bee5fe6aa57a8326b29b9d3d56234a1dca244/yarl-1.22.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:22965c2af250d20c873cdbee8ff958fb809940aeb2e74ba5f20aaf6b7ac8c70c", size = 93821, upload-time = "2025-10-06T14:10:17.993Z" }, + { url = "https://files.pythonhosted.org/packages/61/3a/caf4e25036db0f2da4ca22a353dfeb3c9d3c95d2761ebe9b14df8fc16eb0/yarl-1.22.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b4f15793aa49793ec8d1c708ab7f9eded1aa72edc5174cae703651555ed1b601", size = 373243, upload-time = "2025-10-06T14:10:19.44Z" }, + { url = "https://files.pythonhosted.org/packages/6e/9e/51a77ac7516e8e7803b06e01f74e78649c24ee1021eca3d6a739cb6ea49c/yarl-1.22.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:e5542339dcf2747135c5c85f68680353d5cb9ffd741c0f2e8d832d054d41f35a", size = 342361, upload-time = "2025-10-06T14:10:21.124Z" }, + { url = "https://files.pythonhosted.org/packages/d4/f8/33b92454789dde8407f156c00303e9a891f1f51a0330b0fad7c909f87692/yarl-1.22.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:5c401e05ad47a75869c3ab3e35137f8468b846770587e70d71e11de797d113df", size = 387036, upload-time = "2025-10-06T14:10:22.902Z" }, + { url = "https://files.pythonhosted.org/packages/d9/9a/c5db84ea024f76838220280f732970aa4ee154015d7f5c1bfb60a267af6f/yarl-1.22.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = 
"sha256:243dda95d901c733f5b59214d28b0120893d91777cb8aa043e6ef059d3cddfe2", size = 397671, upload-time = "2025-10-06T14:10:24.523Z" }, + { url = "https://files.pythonhosted.org/packages/11/c9/cd8538dc2e7727095e0c1d867bad1e40c98f37763e6d995c1939f5fdc7b1/yarl-1.22.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bec03d0d388060058f5d291a813f21c011041938a441c593374da6077fe21b1b", size = 377059, upload-time = "2025-10-06T14:10:26.406Z" }, + { url = "https://files.pythonhosted.org/packages/a1/b9/ab437b261702ced75122ed78a876a6dec0a1b0f5e17a4ac7a9a2482d8abe/yarl-1.22.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:b0748275abb8c1e1e09301ee3cf90c8a99678a4e92e4373705f2a2570d581273", size = 365356, upload-time = "2025-10-06T14:10:28.461Z" }, + { url = "https://files.pythonhosted.org/packages/b2/9d/8e1ae6d1d008a9567877b08f0ce4077a29974c04c062dabdb923ed98e6fe/yarl-1.22.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:47fdb18187e2a4e18fda2c25c05d8251a9e4a521edaed757fef033e7d8498d9a", size = 361331, upload-time = "2025-10-06T14:10:30.541Z" }, + { url = "https://files.pythonhosted.org/packages/ca/5a/09b7be3905962f145b73beb468cdd53db8aa171cf18c80400a54c5b82846/yarl-1.22.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:c7044802eec4524fde550afc28edda0dd5784c4c45f0be151a2d3ba017daca7d", size = 382590, upload-time = "2025-10-06T14:10:33.352Z" }, + { url = "https://files.pythonhosted.org/packages/aa/7f/59ec509abf90eda5048b0bc3e2d7b5099dffdb3e6b127019895ab9d5ef44/yarl-1.22.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:139718f35149ff544caba20fce6e8a2f71f1e39b92c700d8438a0b1d2a631a02", size = 385316, upload-time = "2025-10-06T14:10:35.034Z" }, + { url = "https://files.pythonhosted.org/packages/e5/84/891158426bc8036bfdfd862fabd0e0fa25df4176ec793e447f4b85cf1be4/yarl-1.22.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e1b51bebd221006d3d2f95fbe124b22b247136647ae5dcc8c7acafba66e5ee67", size = 374431, 
upload-time = "2025-10-06T14:10:37.76Z" }, + { url = "https://files.pythonhosted.org/packages/bb/49/03da1580665baa8bef5e8ed34c6df2c2aca0a2f28bf397ed238cc1bbc6f2/yarl-1.22.0-cp313-cp313-win32.whl", hash = "sha256:d3e32536234a95f513bd374e93d717cf6b2231a791758de6c509e3653f234c95", size = 81555, upload-time = "2025-10-06T14:10:39.649Z" }, + { url = "https://files.pythonhosted.org/packages/9a/ee/450914ae11b419eadd067c6183ae08381cfdfcb9798b90b2b713bbebddda/yarl-1.22.0-cp313-cp313-win_amd64.whl", hash = "sha256:47743b82b76d89a1d20b83e60d5c20314cbd5ba2befc9cda8f28300c4a08ed4d", size = 86965, upload-time = "2025-10-06T14:10:41.313Z" }, + { url = "https://files.pythonhosted.org/packages/98/4d/264a01eae03b6cf629ad69bae94e3b0e5344741e929073678e84bf7a3e3b/yarl-1.22.0-cp313-cp313-win_arm64.whl", hash = "sha256:5d0fcda9608875f7d052eff120c7a5da474a6796fe4d83e152e0e4d42f6d1a9b", size = 81205, upload-time = "2025-10-06T14:10:43.167Z" }, + { url = "https://files.pythonhosted.org/packages/88/fc/6908f062a2f77b5f9f6d69cecb1747260831ff206adcbc5b510aff88df91/yarl-1.22.0-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:719ae08b6972befcba4310e49edb1161a88cdd331e3a694b84466bd938a6ab10", size = 146209, upload-time = "2025-10-06T14:10:44.643Z" }, + { url = "https://files.pythonhosted.org/packages/65/47/76594ae8eab26210b4867be6f49129861ad33da1f1ebdf7051e98492bf62/yarl-1.22.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:47d8a5c446df1c4db9d21b49619ffdba90e77c89ec6e283f453856c74b50b9e3", size = 95966, upload-time = "2025-10-06T14:10:46.554Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ce/05e9828a49271ba6b5b038b15b3934e996980dd78abdfeb52a04cfb9467e/yarl-1.22.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:cfebc0ac8333520d2d0423cbbe43ae43c8838862ddb898f5ca68565e395516e9", size = 97312, upload-time = "2025-10-06T14:10:48.007Z" }, + { url = 
"https://files.pythonhosted.org/packages/d1/c5/7dffad5e4f2265b29c9d7ec869c369e4223166e4f9206fc2243ee9eea727/yarl-1.22.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4398557cbf484207df000309235979c79c4356518fd5c99158c7d38203c4da4f", size = 361967, upload-time = "2025-10-06T14:10:49.997Z" }, + { url = "https://files.pythonhosted.org/packages/50/b2/375b933c93a54bff7fc041e1a6ad2c0f6f733ffb0c6e642ce56ee3b39970/yarl-1.22.0-cp313-cp313t-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2ca6fd72a8cd803be290d42f2dec5cdcd5299eeb93c2d929bf060ad9efaf5de0", size = 323949, upload-time = "2025-10-06T14:10:52.004Z" }, + { url = "https://files.pythonhosted.org/packages/66/50/bfc2a29a1d78644c5a7220ce2f304f38248dc94124a326794e677634b6cf/yarl-1.22.0-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ca1f59c4e1ab6e72f0a23c13fca5430f889634166be85dbf1013683e49e3278e", size = 361818, upload-time = "2025-10-06T14:10:54.078Z" }, + { url = "https://files.pythonhosted.org/packages/46/96/f3941a46af7d5d0f0498f86d71275696800ddcdd20426298e572b19b91ff/yarl-1.22.0-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6c5010a52015e7c70f86eb967db0f37f3c8bd503a695a49f8d45700144667708", size = 372626, upload-time = "2025-10-06T14:10:55.767Z" }, + { url = "https://files.pythonhosted.org/packages/c1/42/8b27c83bb875cd89448e42cd627e0fb971fa1675c9ec546393d18826cb50/yarl-1.22.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d7672ecf7557476642c88497c2f8d8542f8e36596e928e9bcba0e42e1e7d71f", size = 341129, upload-time = "2025-10-06T14:10:57.985Z" }, + { url = "https://files.pythonhosted.org/packages/49/36/99ca3122201b382a3cf7cc937b95235b0ac944f7e9f2d5331d50821ed352/yarl-1.22.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = 
"sha256:3b7c88eeef021579d600e50363e0b6ee4f7f6f728cd3486b9d0f3ee7b946398d", size = 346776, upload-time = "2025-10-06T14:10:59.633Z" }, + { url = "https://files.pythonhosted.org/packages/85/b4/47328bf996acd01a4c16ef9dcd2f59c969f495073616586f78cd5f2efb99/yarl-1.22.0-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:f4afb5c34f2c6fecdcc182dfcfc6af6cccf1aa923eed4d6a12e9d96904e1a0d8", size = 334879, upload-time = "2025-10-06T14:11:01.454Z" }, + { url = "https://files.pythonhosted.org/packages/c2/ad/b77d7b3f14a4283bffb8e92c6026496f6de49751c2f97d4352242bba3990/yarl-1.22.0-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:59c189e3e99a59cf8d83cbb31d4db02d66cda5a1a4374e8a012b51255341abf5", size = 350996, upload-time = "2025-10-06T14:11:03.452Z" }, + { url = "https://files.pythonhosted.org/packages/81/c8/06e1d69295792ba54d556f06686cbd6a7ce39c22307100e3fb4a2c0b0a1d/yarl-1.22.0-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:5a3bf7f62a289fa90f1990422dc8dff5a458469ea71d1624585ec3a4c8d6960f", size = 356047, upload-time = "2025-10-06T14:11:05.115Z" }, + { url = "https://files.pythonhosted.org/packages/4b/b8/4c0e9e9f597074b208d18cef227d83aac36184bfbc6eab204ea55783dbc5/yarl-1.22.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:de6b9a04c606978fdfe72666fa216ffcf2d1a9f6a381058d4378f8d7b1e5de62", size = 342947, upload-time = "2025-10-06T14:11:08.137Z" }, + { url = "https://files.pythonhosted.org/packages/e0/e5/11f140a58bf4c6ad7aca69a892bff0ee638c31bea4206748fc0df4ebcb3a/yarl-1.22.0-cp313-cp313t-win32.whl", hash = "sha256:1834bb90991cc2999f10f97f5f01317f99b143284766d197e43cd5b45eb18d03", size = 86943, upload-time = "2025-10-06T14:11:10.284Z" }, + { url = "https://files.pythonhosted.org/packages/31/74/8b74bae38ed7fe6793d0c15a0c8207bbb819cf287788459e5ed230996cdd/yarl-1.22.0-cp313-cp313t-win_amd64.whl", hash = "sha256:ff86011bd159a9d2dfc89c34cfd8aff12875980e3bd6a39ff097887520e60249", size = 93715, upload-time = "2025-10-06T14:11:11.739Z" }, + { url = 
"https://files.pythonhosted.org/packages/69/66/991858aa4b5892d57aef7ee1ba6b4d01ec3b7eb3060795d34090a3ca3278/yarl-1.22.0-cp313-cp313t-win_arm64.whl", hash = "sha256:7861058d0582b847bc4e3a4a4c46828a410bca738673f35a29ba3ca5db0b473b", size = 83857, upload-time = "2025-10-06T14:11:13.586Z" }, + { url = "https://files.pythonhosted.org/packages/73/ae/b48f95715333080afb75a4504487cbe142cae1268afc482d06692d605ae6/yarl-1.22.0-py3-none-any.whl", hash = "sha256:1380560bdba02b6b6c90de54133c81c9f2a453dee9912fe58c1dcced1edb7cff", size = 46814, upload-time = "2025-10-06T14:12:53.872Z" }, +] From 84816951c2a302668733dc3d907d65b597ded35c Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:27:50 +0000 Subject: [PATCH 76/77] format and type fixes --- spd/clustering/ci_dt/attn.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/spd/clustering/ci_dt/attn.py b/spd/clustering/ci_dt/attn.py index 31b3a38c2..82f1f2736 100644 --- a/spd/clustering/ci_dt/attn.py +++ b/spd/clustering/ci_dt/attn.py @@ -25,6 +25,7 @@ # ----------------------- configuration ----------------------- config = CIDTConfig( + wandb_run_path="wandb:goodfire/spd/runs/lxs77xye", batch_size=16, n_batches=4, activation_threshold=0.01, @@ -36,14 +37,12 @@ # %% # ----------------------- load model ----------------------- -wandb_run_path: str = "wandb:goodfire/spd/runs/lxs77xye" - -spd_run: SPDRunInfo = SPDRunInfo.from_path(wandb_run_path) +spd_run: SPDRunInfo = SPDRunInfo.from_path(config.wandb_run_path) model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) model.to(device) cfg: Config = spd_run.config -print(f"Loaded model from {wandb_run_path}") +print(f"Loaded model from {config.wandb_run_path}") # %% # ----------------------- load dataset ----------------------- From b7b6f4b4de266847dd49e49b4b6b8e8905bc43a4 Mon Sep 17 00:00:00 2001 From: Michael Ivanitskiy Date: Wed, 29 Oct 2025 10:30:37 +0000 Subject: [PATCH 77/77] more type fixes --- 
spd/clustering/ci_dt/minimal_run.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/spd/clustering/ci_dt/minimal_run.py b/spd/clustering/ci_dt/minimal_run.py index 9411cf7b7..dea2f76f7 100644 --- a/spd/clustering/ci_dt/minimal_run.py +++ b/spd/clustering/ci_dt/minimal_run.py @@ -118,8 +118,10 @@ models: list[tuple[int, MultiOutputClassifier]] = [] for k in tqdm(range(1, len(layers)), desc="Training"): - X = np.concatenate(layers[:k], axis=1) if k > 0 else np.zeros((layers[0].shape[0], 0), bool) - Y = layers[k] + X_prev_layers_cis = ( + np.concatenate(layers[:k], axis=1) if k > 0 else np.zeros((layers[0].shape[0], 0), bool) + ) + Y_current_layer_cis = layers[k] clf = MultiOutputClassifier( DecisionTreeClassifier( @@ -128,7 +130,7 @@ random_state=RANDOM_STATE, ) ) - clf.fit(X.astype(np.uint8), Y.astype(np.uint8)) + clf.fit(X_prev_layers_cis.astype(np.uint8), Y_current_layer_cis.astype(np.uint8)) models.append((k, clf)) # %% ----------------------- Compute Metrics ----------------------- @@ -139,7 +141,7 @@ def extract_prob_class_1(proba_list: list[np.ndarray], clf: MultiOutputClassifie """Extract P(y=1) for each output.""" result: list[np.ndarray] = [] for i, p in enumerate(proba_list): - estimator = clf.estimators_[i] # type: ignore + estimator = clf.estimators_[i] # pyright: ignore[reportIndexIssue] assert isinstance(estimator, DecisionTreeClassifier) assert len(estimator.classes_) == 2 result.append(p[:, 1]) # P(y=1) @@ -165,11 +167,11 @@ def tree_to_dict(tree: DecisionTreeClassifier) -> dict[str, Any]: print("\nPer-layer metrics:") for layer_idx, clf in models: # Prepare X, Y for this layer - X = np.concatenate(layers[:layer_idx], axis=1) - Y = layers[layer_idx] + X_prev_layers_cis = np.concatenate(layers[:layer_idx], axis=1) + Y_current_layer_cis = layers[layer_idx] # Predict - proba_list = clf.predict_proba(X.astype(np.uint8)) # type: ignore + proba_list = clf.predict_proba(X_prev_layers_cis.astype(np.uint8)) # type: 
ignore P = extract_prob_class_1(proba_list, clf) Y_pred = P >= 0.5 @@ -178,12 +180,12 @@ def tree_to_dict(tree: DecisionTreeClassifier) -> dict[str, Any]: acc_scores: list[float] = [] bacc_scores: list[float] = [] - for j in range(Y.shape[1]): - y_true = Y[:, j].astype(int) + for j in range(Y_current_layer_cis.shape[1]): + y_true = Y_current_layer_cis[:, j].astype(int) y_prob = P[:, j] y_pred = Y_pred[:, j].astype(int) - ap_scores.append(average_precision_score(y_true, y_prob)) + ap_scores.append(average_precision_score(y_true, y_prob)) # pyright: ignore[reportArgumentType] acc_scores.append(accuracy_score(y_true, y_pred)) bacc_scores.append(balanced_accuracy_score(y_true, y_pred)) @@ -194,7 +196,7 @@ def tree_to_dict(tree: DecisionTreeClassifier) -> dict[str, Any]: print(f" Mean BAcc: {np.mean(bacc_scores):.3f}") # Store results with tree structures - trees_data = [tree_to_dict(est) for est in clf.estimators_] # type: ignore + trees_data = [tree_to_dict(est) for est in clf.estimators_] # pyright: ignore[reportArgumentType] results.append( {