diff --git a/.gitignore b/.gitignore index ad6cf6670..64704d4e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ spd/scripts/sweep_params.yaml -spd/scripts/sweep_params.yaml docs/coverage/** +artifacts/** +docs/dep_graph/** +tests/.temp/** **/out/ neuronpedia_outputs/ diff --git a/.vscode/launch.json b/.vscode/launch.json index 75c8edbb2..eb19f182a 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -230,6 +230,38 @@ "--model_path", "wandb:goodfire/spd/runs/ioprgffh" ] + }, + { + "name": "run_clustering example", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/clustering/scripts/run_clustering.py", + "args": [ + "--config", + "${workspaceFolder}/spd/clustering/configs/example.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } + }, + { + "name": "clustering pipeline", + "type": "debugpy", + "request": "launch", + "program": "${workspaceFolder}/spd/clustering/scripts/run_pipeline.py", + "args": [ + "--config", + "${workspaceFolder}/spd/clustering/configs/pipeline_config.yaml", + ], + "python": "${command:python.interpreterPath}", + "console": "integratedTerminal", + "justMyCode": true, + "env": { + "PYDEVD_DISABLE_FILE_VALIDATION": "1" + } } ] } \ No newline at end of file diff --git a/Makefile b/Makefile index ff8ab8955..4cc60fe44 100644 --- a/Makefile +++ b/Makefile @@ -76,10 +76,23 @@ coverage: uv run python -m coverage report -m > $(COVERAGE_DIR)/coverage.txt uv run python -m coverage html --directory=$(COVERAGE_DIR)/html/ + +.PHONY: clean +clean: + @echo "Cleaning Python cache and build artifacts..." + find . -type d -name "__pycache__" -exec rm -rf {} + + find . 
-type d -name "*.egg-info" -exec rm -rf {} + + rm -rf build/ dist/ .ruff_cache/ .pytest_cache/ .coverage + + +.PHONY: clustering-dev +clustering-dev: + uv run spd-cluster --local --config spd/clustering/configs/pipeline-dev-simplestories.yaml + .PHONY: app app: @uv run python app/run_app.py .PHONY: install-app install-app: - (cd app/frontend && npm install) \ No newline at end of file + (cd app/frontend && npm install) diff --git a/TODO.md b/TODO.md new file mode 100644 index 000000000..9e6f14815 --- /dev/null +++ b/TODO.md @@ -0,0 +1,73 @@ +# TODO: Cluster Coactivation Matrix Implementation + +## What Was Changed + +### 1. Added `ClusterActivations` dataclass (`spd/clustering/dashboard/compute_max_act.py`) +- New dataclass to hold vectorized cluster activations for all clusters +- Contains `activations` tensor [n_samples, n_clusters] and `cluster_indices` list + +### 2. Added `compute_all_cluster_activations()` function +- Vectorized computation of all cluster activations at once +- Replaces the per-cluster loop for better performance +- Returns `ClusterActivations` object + +### 3. Added `compute_cluster_coactivations()` function +- Computes coactivation matrix from list of `ClusterActivations` across batches +- Binarizes activations (acts > 0) and computes matrix multiplication: `activation_mask.T @ activation_mask` +- Follows the pattern from `spd/clustering/merge.py:69` +- Returns tuple of (coactivation_matrix, cluster_indices) + +### 4. Modified `compute_max_activations()` function +- Now accumulates `ClusterActivations` from each batch in `all_cluster_activations` list +- Calls `compute_cluster_coactivations()` to compute the matrix +- **Changed return type**: now returns `tuple[DashboardData, np.ndarray, list[int]]` + - Added coactivation matrix and cluster_indices to return value + +### 5. 
Modified `spd/clustering/dashboard/run.py` +- Updated to handle new return value from `compute_max_activations()` +- Saves coactivation matrix as `coactivations.npz` in the dashboard output directory +- NPZ file contains: + - `coactivations`: the [n_clusters, n_clusters] matrix + - `cluster_indices`: array mapping matrix positions to cluster IDs + +## What Needs to be Checked + +### Testing +- [ ] **Run the dashboard pipeline** on a real clustering run to verify: + - Coactivation computation doesn't crash + - Coactivations are saved correctly to NPZ file + - Matrix dimensions are correct + - `cluster_indices` mapping is correct + +### Type Checking +- [ ] Run `make type` to ensure no type errors were introduced +- [ ] Verify jaxtyping annotations are correct + +### Verification +- [ ] Load a saved `coactivations.npz` file and verify: + ```python + data = np.load("coactivations.npz") + coact = data["coactivations"] + cluster_indices = data["cluster_indices"] + # Check: coact should be symmetric + # Check: diagonal should be >= off-diagonal (clusters coactivate with themselves most) + # Check: cluster_indices length should match coact.shape[0] + ``` + +### Performance +- [ ] Check if vectorization actually improved performance +- [ ] Monitor memory usage with large numbers of clusters + +### Edge Cases +- [ ] Test with clusters that have zero activations +- [ ] Test with single-batch runs +- [ ] Test with very large number of clusters + +### Integration +- [ ] Verify the coactivation matrix can be used in downstream analysis +- [ ] Consider if visualization of coactivations should be added to dashboard + +## Notes +- The coactivation matrix is computed over all samples processed (n_batches * batch_size * seq_len samples) +- Binarization threshold is currently hardcoded as `> 0` - may want to make this configurable +- The computation happens in the dashboard pipeline, NOT during the main clustering pipeline diff --git a/pyproject.toml b/pyproject.toml index 
62393b9e7..cd8ab45be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ dependencies = [ # see: https://github.com/huggingface/datasets/issues/6980 https://github.com/huggingface/datasets/pull/6991 (fixed in https://github.com/huggingface/datasets/releases/tag/2.21.0 ) "datasets>=2.21.0", "simple_stories_train @ git+https://github.com/goodfire-ai/simple_stories_train.git@dev", + "scipy>=1.14.1", + "muutils", + "scikit-learn", "fastapi", "uvicorn", ] @@ -40,10 +43,12 @@ dev = [ "ruff", "basedpyright<1.32.0", # pyright and wandb issues, see https://github.com/goodfire-ai/spd/pull/232 "pre-commit", + "nbconvert", ] [project.scripts] spd-run = "spd.scripts.run:cli" +spd-cluster = "spd.clustering.scripts.run_pipeline:cli" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/base_config.py b/spd/base_config.py index c9b488e19..860898907 100644 --- a/spd/base_config.py +++ b/spd/base_config.py @@ -6,6 +6,14 @@ from pydantic import BaseModel, ConfigDict +class FileTypeError(ValueError): + """Error raised when a file has an unsupported type/extension.""" + + +class ConfigValidationError(ValueError): + """Error raised when a config file fails pydantic validation.""" + + class BaseConfig(BaseModel): """Pydantic BaseModel suited for configs. 
@@ -15,6 +23,8 @@ class BaseConfig(BaseModel): model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid", frozen=True) + # TODO: add a "config_type" field, which is set to the class name, so that when loading a config we can check whether the config type matches the expected class + @classmethod def from_file(cls, path: Path | str) -> Self: """Load config from path to a JSON or YAML file.""" @@ -27,9 +37,16 @@ def from_file(cls, path: Path | str) -> Self: case Path() if path.suffix in [".yaml", ".yml"]: data = yaml.safe_load(path.read_text()) case _: - raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + + try: + cfg = cls.model_validate(data) + except Exception as e: + raise ConfigValidationError( + f"Error validating config {cls=} from path `{path.as_posix()}`\n{data = }" + ) from e - return cls.model_validate(data) + return cfg def to_file(self, path: Path | str) -> None: """Save config to file (format inferred from extension).""" @@ -43,4 +60,4 @@ def to_file(self, path: Path | str) -> None: case ".yaml" | ".yml": path.write_text(yaml.dump(self.model_dump(mode="json"))) case _: - raise ValueError(f"Only (.json, .yaml, .yml) files are supported, got {path}") + raise FileTypeError(f"Only (.json, .yaml, .yml) files are supported, got {path}") diff --git a/spd/clustering/__init__.py b/spd/clustering/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/activations.py b/spd/clustering/activations.py new file mode 100644 index 000000000..cd6a2b742 --- /dev/null +++ b/spd/clustering/activations.py @@ -0,0 +1,267 @@ +from dataclasses import dataclass +from functools import cached_property +from typing import Literal, NamedTuple + +import torch +from jaxtyping import Bool, Float, Float16, Int +from torch import Tensor + +from spd.clustering.consts import ( + ActivationsTensor, + BoolActivationsTensor, + 
ClusterCoactivationShaped, + ComponentLabels, +) +from spd.clustering.util import ModuleFilterFunc +from spd.models.component_model import ComponentModel, OutputWithCache + + +def component_activations( + model: ComponentModel, + device: torch.device | str, + batch: Int[Tensor, "batch_size n_ctx"], +) -> dict[str, ActivationsTensor]: + """Get the component activations over a **single** batch.""" + causal_importances: dict[str, ActivationsTensor] + with torch.no_grad(): + model_output: OutputWithCache = model( + batch.to(device), + cache_type="input", + ) + + # TODO: !!!IMPORTANT!!! unclear what the right thing from CIOutputs is + causal_importances = model.calc_causal_importances( + pre_weight_acts=model_output.cache, + sampling="continuous", + detach_inputs=False, + ).upper_leaky + + return causal_importances + + +def compute_coactivatons( + activations: ActivationsTensor | BoolActivationsTensor, +) -> ClusterCoactivationShaped: + """Compute the coactivations matrix from the activations.""" + # TODO: this works for both boolean and continuous activations, + # but we could do better by just using OR for boolean activations + # and maybe even some bitshift hacks. 
but for now, we convert to float16 + activations_f16: Float16[Tensor, "samples C"] = activations.to(torch.float16) + return activations_f16.T @ activations_f16 + + +class FilteredActivations(NamedTuple): + activations: ActivationsTensor + "activations after filtering dead components" + + labels: ComponentLabels + "list of length c with labels for each preserved component" + + dead_components_labels: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + @property + def n_alive(self) -> int: + """Number of alive components after filtering.""" + n_alive: int = len(self.labels) + assert n_alive == self.activations.shape[1], ( + f"{n_alive = } != {self.activations.shape[1] = }" + ) + return n_alive + + @property + def n_dead(self) -> int: + """Number of dead components after filtering.""" + return len(self.dead_components_labels) if self.dead_components_labels else 0 + + +def filter_dead_components( + activations: ActivationsTensor, + labels: ComponentLabels, + filter_dead_threshold: float = 0.01, +) -> FilteredActivations: + """Filter out dead components based on a threshold + + if `filter_dead_threshold` is 0, no filtering is applied. + activations and labels are returned as is, `dead_components_labels` is `None`. + + otherwise, components whose **maximum** activations across all samples is below the threshold + are considered dead and filtered out. The labels of these components are returned in `dead_components_labels`. + `dead_components_labels` will also be `None` if no components were below the threshold. 
+ """ + dead_components_lst: ComponentLabels | None = None + if filter_dead_threshold > 0: + dead_components_lst = ComponentLabels(list()) + max_act: Float[Tensor, " c"] = activations.max(dim=0).values + dead_components: Bool[Tensor, " c"] = max_act < filter_dead_threshold + + if dead_components.any(): + activations = activations[:, ~dead_components] + alive_labels: list[tuple[str, bool]] = [ + (lbl, bool(keep.item())) + for lbl, keep in zip(labels, ~dead_components, strict=False) + ] + # re-assign labels only if we are filtering + labels = ComponentLabels([label for label, keep in alive_labels if keep]) + dead_components_lst = ComponentLabels( + [label for label, keep in alive_labels if not keep] + ) + + return FilteredActivations( + activations=activations, + labels=labels, + dead_components_labels=dead_components_lst if dead_components_lst else None, + ) + + +@dataclass(frozen=True) +class ProcessedActivations: + """Processed activations after filtering and concatenation""" + + activations_raw: dict[str, ActivationsTensor] + "activations after filtering, but prior to concatenation" + + activations: ActivationsTensor + "activations after filtering and concatenation" + + labels: ComponentLabels + "list of length c with labels for each preserved component, format `{module_name}:{component_index}`" + + dead_components_lst: ComponentLabels | None + "list of labels for dead components, or None if no filtering was applied" + + def validate(self) -> None: + """Validate the processed activations""" + # getting this property will also perform a variety of other checks + assert self.n_components_alive > 0 + + @property + def n_components_original(self) -> int: + """Total number of components before filtering. 
equal to the sum of all components in `activations_raw`, or to `n_components_alive + n_components_dead`""" + return sum(act.shape[1] for act in self.activations_raw.values()) + + @property + def n_components_alive(self) -> int: + """Number of alive components after filtering. equal to the length of `labels`""" + n_alive: int = len(self.labels) + assert n_alive + self.n_components_dead == self.n_components_original, ( + f"({n_alive = }) + ({self.n_components_dead = }) != ({self.n_components_original = })" + ) + assert n_alive == self.activations.shape[1], ( + f"{n_alive = } != {self.activations.shape[1] = }" + ) + + return n_alive + + @property + def n_components_dead(self) -> int: + """Number of dead components after filtering. equal to the length of `dead_components_lst` if it is not None, or 0 otherwise""" + return len(self.dead_components_lst) if self.dead_components_lst else 0 + + @cached_property + def label_index(self) -> dict[str, int | None]: + """Create a mapping from label to alive index (`None` if dead)""" + return { + **{label: i for i, label in enumerate(self.labels)}, + **( + {label: None for label in self.dead_components_lst} + if self.dead_components_lst + else {} + ), + } + + def get_label_index(self, label: str) -> int | None: + """Get the index of a label in the activations, or None if it is dead""" + return self.label_index[label] + + def get_label_index_alive(self, label: str) -> int: + """Get the index of a label in the activations, or raise if it is dead""" + idx: int | None = self.get_label_index(label) + if idx is None: + raise ValueError(f"Label '{label}' is dead and has no index in the activations.") + return idx + + @property + def module_keys(self) -> list[str]: + """Get the module keys from the activations_raw""" + return list(self.activations_raw.keys()) + + def get_module_indices(self, module_key: str) -> list[int | None]: + """given a module key, return a list len "num components in that module", with int index in alive components, 
or None if dead""" + num_components: int = self.activations_raw[module_key].shape[1] + return [self.label_index[f"{module_key}:{i}"] for i in range(num_components)] + + +def process_activations( + activations: dict[ + str, # module name to + Float[Tensor, "samples C"] # (sample x component gate activations) + | Float[Tensor, " n_sample n_ctx C"], # (sample x seq index x component gate activations) + ], + filter_dead_threshold: float = 0.01, + seq_mode: Literal["concat", "seq_mean", None] = None, + filter_modules: ModuleFilterFunc | None = None, +) -> ProcessedActivations: + """get back a dict of coactivations, slices, and concated activations + + Args: + activations: Dictionary of activations by module + filter_dead_threshold: Threshold for filtering dead components + seq_mode: How to handle sequence dimension + filter_modules: Function to filter modules + sort_components: Whether to sort components by similarity within each module + """ + + # reshape -- special cases for llms + # ============================================================ + activations_: dict[str, ActivationsTensor] + if seq_mode == "concat": + # Concatenate the sequence dimension into the sample dimension + activations_ = { + key: act.reshape(act.shape[0] * act.shape[1], act.shape[2]) + for key, act in activations.items() + } + elif seq_mode == "seq_mean": + # Take the mean over the sequence dimension + activations_ = { + key: act.mean(dim=1) if act.ndim == 3 else act for key, act in activations.items() + } + else: + # Use the activations as they are + activations_ = activations + + # put the labelled activations into one big matrix and filter them + # ============================================================ + + # filter activations for only the modules we want + if filter_modules is not None: + activations_ = {key: act for key, act in activations_.items() if filter_modules(key)} + + # compute the labels and total component count + total_c: int = 0 + labels: ComponentLabels = 
ComponentLabels(list()) + for key, act in activations_.items(): + c: int = act.shape[-1] + labels.extend([f"{key}:{i}" for i in range(c)]) + total_c += c + + # concat the activations + act_concat: ActivationsTensor = torch.cat([activations_[key] for key in activations_], dim=-1) + + # filter dead components + filtered_components: FilteredActivations = filter_dead_components( + activations=act_concat, + labels=labels, + filter_dead_threshold=filter_dead_threshold, + ) + + assert filtered_components.n_alive + filtered_components.n_dead == total_c, ( + f"({filtered_components.n_alive = }) + ({filtered_components.n_dead = }) != ({total_c = })" + ) + + return ProcessedActivations( + activations_raw=activations_, + activations=filtered_components.activations, + labels=filtered_components.labels, + dead_components_lst=filtered_components.dead_components_labels, + ) diff --git a/spd/clustering/ci_dt/VISUALIZATION_PLAN.md b/spd/clustering/ci_dt/VISUALIZATION_PLAN.md new file mode 100644 index 000000000..9a486e484 --- /dev/null +++ b/spd/clustering/ci_dt/VISUALIZATION_PLAN.md @@ -0,0 +1,958 @@ +# CI Decision Tree Visualization Plan + +## Overview + +This document outlines the complete visualization strategy for causal importance decision trees, including static plots (matplotlib/PDF) and interactive visualizations (HTML/JS). 
+ +--- + +## Part 1: Static Plot Improvements + +### 1.1 Layer Metrics - Distribution Plots + +**Current:** Bar charts for mean AP, accuracy, balanced accuracy per layer + +**New:** Scatter plots with horizontal jitter showing full distribution per layer + +**Implementation:** +- Replace `plot_layer_metrics()` bar charts with jittered scatter plots +- For each layer, show all target component metrics as points with random horizontal jitter +- Add mean/median line overlays +- Better titles explaining metrics in terms of confusion matrix: + +```python +# Accuracy title +r"Accuracy per Target Component\n" + +r"$\text{Accuracy} = \frac{TP + TN}{TP + TN + FP + FN}$" + +# Balanced Accuracy title +r"Balanced Accuracy per Target Component\n" + +r"$\text{Balanced Acc} = \frac{1}{2}\left(\frac{TP}{TP+FN} + \frac{TN}{TN+FP}\right)$" + +# Average Precision title +r"Average Precision per Target Component\n" + +r"$\text{AP} = \sum_n (R_n - R_{n-1}) P_n$" + "\n" + +r"where $P_n = \frac{TP}{TP+FP}$ (precision), $R_n = \frac{TP}{TP+FN}$ (recall)" +``` + +**Rationale:** Shows full distribution of performance across targets, not just means. More informative about variance in tree quality. + +--- + +### 1.2 AP vs Prevalence Plot + +**Current:** Simple scatter plot with alpha=0.6 + +**New Improvements:** +1. **Log x-axis** for prevalence (many rare components) +2. **No marker edges** (set `edgecolors='none'`) +3. **Color by tree depth** using viridis colormap +4. **Enhanced title:** + ```python + r"Average Precision vs Component Prevalence\n" + + r"Prevalence = $\frac{n_\text{active samples}}{n_\text{total samples}}$" + ``` + +**Additional:** Add heatmap version (see 1.3 below) + +--- + +### 1.3 Tree Statistics - New Heatmaps + +**Current:** Has depth vs accuracy, leaf count vs accuracy, depth vs leaf count heatmaps + +**New Addition:** AP vs Prevalence heatmap + +**Implementation:** +- Add new heatmap to `plot_tree_statistics()`: + - x-axis: prevalence bins (log scale, e.g. 
[0.001, 0.01, 0.1, 0.5, 1.0]) + - y-axis: AP bins (linear, 0 to 1) + - color: log10(count + 1) as in existing heatmaps + - title: + ```python + r"Tree Performance vs Component Prevalence\n" + + r"AP = Average Precision, Prev = $\frac{n_\text{active}}{n_\text{total}}$" + ``` + +**Rationale:** Complements the scatter plot; easier to see density patterns. + +--- + +### 1.4 Global Title Improvements + +**Rules:** +- Use LaTeX notation via raw strings: `r"$\text{TP}$"` not unicode "TP" +- Use `\n` for line breaks in long titles +- Explain abbreviations and formulas +- Be explicit about what's plotted + +**Examples:** + +```python +# Before +"Covariance of components (all layers)" + +# After +r"Component Coactivation Matrix\n" + +r"$\text{Cov}(i,j) = \mathbb{E}[(A_i - \mu_i)(A_j - \mu_j)]$\n" + +r"where $A_i$ is binary activation of component $i$" + +# Before +"Tree depth" + +# After +r"Distribution of Decision Tree Depths\n" + +r"(Depth = longest path from root to leaf)" + +# Before +"Activations (True)" + +# After +r"True Binary Activations\n" + +r"$A_{ij} = \mathbb{1}[\text{activation}_{ij} > \theta]$, $\theta = $" + f"{config.activation_threshold}" +``` + +--- + +### 1.5 Activations Plot - Sorting and Diff + +**Current:** Two subplots (true, predicted) with no ordering + +**New Architecture:** + +``` +plot_activations_unsorted(...) # Original style with layer boundaries +plot_activations_sorted(...) 
# New sorted version with diff +``` + +#### 1.5.1 Unsorted Version (Enhanced) + +**Changes:** +- Add layer boundary lines and labels (borrow from `spd/clustering/plotting/activations.py:add_component_labeling()`) +- Show module names on y-axis (component dimension) +- Keep samples unsorted on x-axis +- Two subplots: true, predicted + +**Implementation:** +```python +def plot_activations_unsorted( + layers_true: list[np.ndarray], + layers_pred: list[np.ndarray], + module_keys: list[str], # NEW: need module names +) -> None: + """Show true and predicted activations with layer boundaries.""" + # Concatenate + A_true = np.concatenate(layers_true, axis=1) + A_pred = np.concatenate(layers_pred, axis=1) + + # Create component labels like "blocks.0.attn:0", "blocks.0.attn:1", ... + component_labels = [] + for module_key, layer in zip(module_keys, layers_true): + n_components = layer.shape[1] + component_labels.extend([f"{module_key}:{i}" for i in range(n_components)]) + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 8)) + + # Plot + ax1.imshow(A_true.T, aspect="auto", interpolation="nearest", cmap="Blues") + ax2.imshow(A_pred.T, aspect="auto", interpolation="nearest", cmap="Reds") + + # Add layer boundaries (adapt from spd/clustering/plotting/activations.py) + add_component_labeling(ax1, component_labels, axis='y') + add_component_labeling(ax2, component_labels, axis='y') + + # Titles + ax1.set_title(r"True Binary Activations (Unsorted)\n" + + r"$A_{ij} = \mathbb{1}[\text{act}_{ij} > \theta]$") + ax2.set_title(r"Predicted Binary Activations (Unsorted)\n" + + r"$\hat{A}_{ij} = \mathbb{1}[P(A_{ij}=1) > 0.5]$") +``` + +#### 1.5.2 Sorted Version (New) + +**Sorting Strategy:** + +1. 
**Sample Sorting (Greedy):** + - Compute sample similarity matrix (cosine similarity on true activations) + - Greedy ordering: start from most central sample, add nearest neighbor iteratively + - Apply **same ordering** to predicted activations (so we can compare) + - Reference implementation already exists in `spd/clustering/plotting/activations.py:120-162` + +2. **Component Sorting (Greedy):** + - Compute component similarity matrix (cosine similarity on true activations) + - Same greedy algorithm but on columns instead of rows + - Apply same ordering to both true and predicted + +**Three Subplots:** +1. True activations (samples sorted, components sorted) +2. Predicted activations (same ordering) +3. **Diff plot:** `predicted - true` with RdBu colormap + - Red = False Positive (predicted 1, true 0) + - Blue = False Negative (predicted 0, true 1) + - White = Correct + +**Implementation:** +```python +def plot_activations_sorted( + layers_true: list[np.ndarray], + layers_pred: list[np.ndarray], + module_keys: list[str], +) -> None: + """Show sorted activations with diff plot.""" + A_true = np.concatenate(layers_true, axis=1).astype(float) + A_pred = np.concatenate(layers_pred, axis=1).astype(float) + + # Sort samples (greedy on rows) + sample_order = greedy_sort(A_true, axis=0) # Returns indices + A_true_sorted_samples = A_true[sample_order, :] + A_pred_sorted_samples = A_pred[sample_order, :] + + # Sort components (greedy on columns) + component_order = greedy_sort(A_true_sorted_samples, axis=1) + A_true_sorted = A_true_sorted_samples[:, component_order] + A_pred_sorted = A_pred_sorted_samples[:, component_order] + + # Diff + A_diff = A_pred_sorted - A_true_sorted # Range: [-1, 0, 1] + + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 12)) + + ax1.imshow(A_true_sorted.T, aspect="auto", interpolation="nearest", cmap="Blues") + ax1.set_title(r"True Activations (Sorted)\n" + + r"Samples and components sorted by similarity") + + ax2.imshow(A_pred_sorted.T, 
aspect="auto", interpolation="nearest", cmap="Reds") + ax2.set_title(r"Predicted Activations (Sorted)\n" + + r"Same ordering as true activations") + + # Diff plot with centered colormap + im3 = ax3.imshow(A_diff.T, aspect="auto", interpolation="nearest", + cmap="RdBu_r", vmin=-1, vmax=1) + ax3.set_title(r"Prediction Errors (Predicted - True)\n" + + r"Red = FP ($\hat{A}=1, A=0$), Blue = FN ($\hat{A}=0, A=1$), White = Correct") + plt.colorbar(im3, ax=ax3) + + fig.tight_layout() +``` + +**Helper Function:** +```python +def greedy_sort(A: np.ndarray, axis: int) -> np.ndarray: + """Greedy ordering by similarity. + + Args: + A: 2D array + axis: 0 for rows, 1 for columns + + Returns: + Indices in sorted order + """ + # Transpose if sorting columns + if axis == 1: + A = A.T + + # Compute cosine similarity + norms = np.linalg.norm(A, axis=1, keepdims=True) + norms = np.where(norms > 1e-8, norms, 1.0) + A_normalized = A / norms + similarity = A_normalized @ A_normalized.T + + # Greedy ordering (same as in activations.py) + n = similarity.shape[0] + avg_sim = similarity.mean(axis=1) + start_idx = int(np.argmax(avg_sim)) + + ordered = [start_idx] + remaining = set(range(n)) + remaining.remove(start_idx) + current = start_idx + + while remaining: + sims = [(i, similarity[current, i]) for i in remaining] + best_idx = max(sims, key=lambda x: x[1])[0] + ordered.append(best_idx) + remaining.remove(best_idx) + current = best_idx + + return np.array(ordered) +``` + +--- + +### 1.6 Covariance Matrix - Sorted Version + +**Current:** Single unsorted covariance plot + +**New:** Two versions +1. **Unsorted** with layer boundaries (like activations unsorted) +2. 
**Sorted** using same component ordering from activations + +**Implementation:** +```python +def plot_covariance_unsorted( + layers_true: list[np.ndarray], + module_keys: list[str], +) -> None: + """Covariance with layer boundaries.""" + A = np.concatenate(layers_true, axis=1).astype(float) + C = np.cov(A, rowvar=False) + + component_labels = [...] # Same as activations + + fig, ax = plt.subplots(figsize=(8, 8)) + im = ax.imshow(C, aspect="auto", interpolation="nearest", cmap="RdBu_r") + + # Add layer boundaries on both axes + add_component_labeling(ax, component_labels, axis='x') + add_component_labeling(ax, component_labels, axis='y') + + ax.set_title(r"Component Covariance Matrix (Unsorted)\n" + + r"$\text{Cov}(i,j) = \mathbb{E}[(A_i - \mu_i)(A_j - \mu_j)]$") + plt.colorbar(im) + +def plot_covariance_sorted( + layers_true: list[np.ndarray], + component_order: np.ndarray, # Pass in from activations +) -> None: + """Covariance with sorted components.""" + A = np.concatenate(layers_true, axis=1).astype(float) + A_sorted = A[:, component_order] + C_sorted = np.cov(A_sorted, rowvar=False) + + fig, ax = plt.subplots(figsize=(8, 8)) + im = ax.imshow(C_sorted, aspect="auto", interpolation="nearest", cmap="RdBu_r") + ax.set_title(r"Component Covariance Matrix (Sorted)\n" + + r"Components ordered by similarity") + plt.colorbar(im) +``` + +--- + +## Part 2: Interactive Tree Visualization (HTML/JS) + +### 2.1 High-Level Architecture + +**Export:** Python creates one JSON per tree → **Display:** HTML/JS loads JSON and renders visualizations + +### 2.2 Data to Export (per tree) + +#### Tree Metadata +```json +{ + "layer_index": 1, + "target_component_idx": 5, + "module_key": "blocks.0.mlp.W_gate", + "metrics": { + "ap": 0.85, + "accuracy": 0.92, + "balanced_accuracy": 0.88, + "prevalence": 0.023, + "n_samples": 200, + "n_positive": 46, + "n_negative": 154, + "confusion_matrix": { + "TP": 40, + "TN": 144, + "FP": 10, + "FN": 6 + } + }, + "tree_stats": { + "max_depth": 5, + 
"n_leaves": 12, + "n_nodes": 23 + } +} +``` + +#### Tree Structure +```json +{ + "structure": { + "children_left": [1, -1, 3, 4, -1, ...], + "children_right": [2, -1, 5, 6, -1, ...], + "feature": [7, -2, 12, 3, -2, ...], + "threshold": [0.5, -2, 0.5, 0.5, -2, ...], + "value": [[30, 20], [5, 15], ...], // [n_negative, n_positive] per node + "n_node_samples": [200, 50, 150, ...] + }, + "feature_names": [ + "blocks.0.attn.W_Q:3 (prev=0.15, AP=0.82)", + "blocks.0.attn.W_Q:17 (prev=0.08, AP=0.91)", + "blocks.0.mlp.W_in:5 (prev=0.23, AP=0.76)", + ... + ] +} +``` + +#### Activation Histograms +```json +{ + "true_activations": { + "histogram": { + "bins": [0.0, 0.01, 0.02, ...], // Bin edges + "counts": [120, 45, 23, ...] + } + }, + "predicted_probabilities": { + "histogram": { + "bins": [0.0, 0.1, 0.2, ...], + "counts": [80, 30, 40, ...] + } + } +} +``` + +#### Token-Level Samples + +**Sample Selection Strategies:** + +1. **Stratified by confusion matrix** (recommended): + - 2 True Positives (high confidence, low confidence) + - 2 True Negatives (high confidence, low confidence) + - 2 False Positives (worst errors) + - 2 False Negatives (worst errors) + - Total: 8 samples + +2. **Fallback if categories insufficient:** + - Random samples from each category + - Fill missing categories with "N/A" + +**Data Structure:** +```json +{ + "samples": [ + { + "sample_idx": 42, + "category": "TP", // TP, TN, FP, or FN + "confidence": 0.95, // abs(predicted_prob - 0.5) + "tokens": ["The", "cat", "sat", "on", "the", "mat"], + "true_activations": [0.0, 0.0, 0.82, 0.91, 0.0, 0.0], // Continuous values + "predicted_probabilities": [0.05, 0.1, 0.88, 0.94, 0.02, 0.01], + "true_binary": [0, 0, 1, 1, 0, 0], + "predicted_binary": [0, 0, 1, 1, 0, 0], + "max_true_pos": 2, // Index of max activation in true + "max_pred_pos": 3 // Index of max activation in predicted + }, + // ... 
7 more samples + ] +} +``` + +#### Input Features Summary +```json +{ + "input_features_by_module": { + "blocks.0.attn.W_Q": [3, 17, 42], // Component indices used in tree + "blocks.0.mlp.W_in": [5, 12] + }, + "n_input_features_total": 5, + "n_components_total": 256 // All components in layer 0 +} +``` + +### 2.3 Python Export Implementation + +**New File:** `spd/clustering/ci_dt/export.py` + +```python +"""Export decision tree data to JSON for interactive visualization.""" + +from pathlib import Path +from typing import Any +import json +import numpy as np +from sklearn.tree import DecisionTreeClassifier + +from spd.clustering.ci_dt.core import LayerModel + + +def export_tree_json( + tree: DecisionTreeClassifier, + layer_idx: int, + target_idx: int, + module_key: str, + X: np.ndarray, # Input features (all layer 0 components) + Y_true: np.ndarray, # True binary activations for this target + Y_prob: np.ndarray, # Predicted probabilities + tokens_batch: list[list[str]], # Decoded tokens for all samples + feature_names: list[str], + output_path: Path, +) -> None: + """Export single tree to JSON.""" + + # 1. Compute metrics + Y_pred = (Y_prob >= 0.5).astype(int) + metrics = compute_tree_metrics(Y_true, Y_pred, Y_prob) + + # 2. Serialize tree structure + tree_dict = serialize_tree(tree, feature_names) + + # 3. Create activation histograms + histograms = create_histograms(Y_true, Y_prob) + + # 4. Select and export token samples + samples = select_token_samples( + Y_true, Y_prob, Y_pred, tokens_batch + ) + + # 5. Identify which input features are used + input_features = extract_input_features(tree, module_key) + + # 6. 
Combine into single JSON + data = { + "metadata": { + "layer_index": layer_idx, + "target_component_idx": target_idx, + "module_key": module_key, + "metrics": metrics, + "tree_stats": { + "max_depth": int(tree.tree_.max_depth), + "n_leaves": int(tree.tree_.n_leaves), + "n_nodes": int(tree.tree_.node_count), + } + }, + "tree": tree_dict, + "histograms": histograms, + "samples": samples, + "input_features": input_features, + } + + output_path.parent.mkdir(parents=True, exist_ok=True) + with open(output_path, 'w') as f: + json.dump(data, f, indent=2) + + +def export_all_trees( + models: list[LayerModel], + layers_true: list[np.ndarray], + per_layer_stats: list[dict], + component_acts: dict[str, Tensor], # Original activations (continuous) + batch_data: dict, # From dataloader (has token IDs) + tokenizer, # HuggingFace tokenizer + feature_names: list[list[str]], + output_dir: Path, +) -> None: + """Export all trees and create index.""" + + # Decode all tokens once + tokens_batch = decode_all_tokens(batch_data, tokenizer) + + # Export each tree + tree_index = [] + for layer_idx, model in enumerate(models): + module_key = list(component_acts.keys())[layer_idx] + X = layers_true[0] # Always predict from layer 0 + Y_true = layers_true[layer_idx + 1] # Target layer + + for target_idx, estimator in enumerate(model.model.estimators_): + # Get predictions for this target + Y_prob = estimator.predict_proba(X)[:, 1] + + # Get feature names for this layer's inputs + feat_names = feature_names[layer_idx] if feature_names else None + + # Export + tree_path = output_dir / "data" / f"tree_{layer_idx}_{target_idx}.json" + export_tree_json( + tree=estimator, + layer_idx=layer_idx, + target_idx=target_idx, + module_key=module_key, + X=X, + Y_true=Y_true[:, target_idx], + Y_prob=Y_prob, + tokens_batch=tokens_batch, + feature_names=feat_names, + output_path=tree_path, + ) + + # Add to index + tree_index.append({ + "layer": layer_idx, + "target": target_idx, + "module_key": module_key, + 
"ap": per_layer_stats[layer_idx]["ap"][target_idx], + "file": f"data/tree_{layer_idx}_{target_idx}.json" + }) + + # Write index + index_path = output_dir / "data" / "index.json" + with open(index_path, 'w') as f: + json.dump(tree_index, f, indent=2) + + +def select_token_samples( + Y_true: np.ndarray, + Y_prob: np.ndarray, + Y_pred: np.ndarray, + tokens_batch: list[list[str]], + n_per_category: int = 2, +) -> list[dict]: + """Select stratified samples from confusion matrix categories.""" + + # Categorize samples + TP_mask = (Y_true == 1) & (Y_pred == 1) + TN_mask = (Y_true == 0) & (Y_pred == 0) + FP_mask = (Y_true == 0) & (Y_pred == 1) + FN_mask = (Y_true == 1) & (Y_pred == 0) + + # Confidence = distance from decision boundary + confidence = np.abs(Y_prob - 0.5) + + samples = [] + + for category, mask in [("TP", TP_mask), ("TN", TN_mask), + ("FP", FP_mask), ("FN", FN_mask)]: + indices = np.where(mask)[0] + if len(indices) == 0: + continue + + # Sort by confidence + sorted_indices = indices[np.argsort(confidence[indices])[::-1]] + + # Take high and low confidence + n_take = min(n_per_category, len(sorted_indices)) + if n_take == 2: + selected = [sorted_indices[0], sorted_indices[-1]] # High and low + else: + selected = sorted_indices[:n_take] + + for idx in selected: + samples.append({ + "sample_idx": int(idx), + "category": category, + "confidence": float(confidence[idx]), + "tokens": tokens_batch[idx], + "true_activations": Y_true[idx].tolist(), # Would need continuous version + "predicted_probabilities": [float(Y_prob[idx])] * len(tokens_batch[idx]), + "true_binary": int(Y_true[idx]), + "predicted_binary": int(Y_pred[idx]), + }) + + return samples +``` + +**Integration in `run.py`:** + +```python +# After computing metrics (line ~121) +from spd.clustering.ci_dt.export import export_all_trees + +export_output_dir = Path("./ci_dt_vis") +export_all_trees( + models=models, + layers_true=layers_true, + per_layer_stats=per_layer_stats, + 
component_acts=component_acts_concat, + batch_data=next(iter(dataloader)), # Need to save earlier + tokenizer=cfg.task_config.tokenizer, + feature_names=feature_names, + output_dir=export_output_dir, +) +print(f"Exported tree visualizations to {export_output_dir}") +``` + +### 2.4 HTML/JS Viewer Implementation + +**File Structure:** +``` +ci_dt_vis/ +├── index.html # Main viewer +├── data/ +│ ├── index.json # Tree index +│ ├── tree_1_0.json # Individual trees +│ ├── tree_1_1.json +│ └── ... +├── js/ +│ ├── viewer.js # Main app logic +│ ├── tree-display.js # Tree visualization +│ ├── token-display.js # Token highlighting +│ └── sparklines.js # Histograms +└── css/ + └── style.css +``` + +**`index.html`:** +```html + + + + CI Decision Tree Viewer + + + + +
+
+

Select Tree

+ + +
+ +
+

Tree Metrics

+
+
+ +
+

Decision Tree Structure

+
+
+ +
+

Activation Distributions

+ +
+ +
+

Example Samples

+
+
+
+ + + + + + + +``` + +**`js/viewer.js`:** +```javascript +// Main viewer logic +let treeIndex = []; +let currentTree = null; + +async function init() { + // Load tree index + const response = await fetch('data/index.json'); + treeIndex = await response.json(); + + // Populate layer selector + const layers = [...new Set(treeIndex.map(t => t.layer))]; + const layerSelect = document.getElementById('layer-select'); + layers.forEach(layer => { + const option = document.createElement('option'); + option.value = layer; + option.text = `Layer ${layer}`; + layerSelect.appendChild(option); + }); + + // Event listeners + layerSelect.addEventListener('change', onLayerChange); + document.getElementById('target-select').addEventListener('change', onTargetChange); + + // Load first tree + if (treeIndex.length > 0) { + await loadTree(treeIndex[0].layer, treeIndex[0].target); + } +} + +function onLayerChange() { + const layer = parseInt(document.getElementById('layer-select').value); + const trees = treeIndex.filter(t => t.layer === layer); + + const targetSelect = document.getElementById('target-select'); + targetSelect.innerHTML = ''; + trees.forEach(tree => { + const option = document.createElement('option'); + option.value = tree.target; + option.text = `Target ${tree.target} (AP=${tree.ap.toFixed(3)})`; + targetSelect.appendChild(option); + }); + + if (trees.length > 0) { + loadTree(layer, trees[0].target); + } +} + +async function loadTree(layer, target) { + const response = await fetch(`data/tree_${layer}_${target}.json`); + currentTree = await response.json(); + + displayMetrics(currentTree.metadata); + displayHistograms(currentTree.histograms); + displayTree(currentTree.tree); + displayTokenSamples(currentTree.samples); +} + +function displayMetrics(metadata) { + const m = metadata.metrics; + const cm = m.confusion_matrix; + + const html = ` + + + + + + + + + + +
AP:${m.ap.toFixed(3)}
Accuracy:${m.accuracy.toFixed(3)}
Balanced Acc:${m.balanced_accuracy.toFixed(3)}
Prevalence:${m.prevalence.toFixed(4)}
Confusion Matrix:
TP:${cm.TP}
TN:${cm.TN}
FP:${cm.FP}
FN:${cm.FN}
+ `; + document.getElementById('metrics').innerHTML = html; +} + +function displayHistograms(histograms) { + // Use sparklines.js to render dual histograms + const canvas = document.getElementById('hist-canvas'); + const ctx = canvas.getContext('2d'); + + // Draw true activations (blue) and predicted (red) overlaid + drawHistogram(ctx, histograms.true_activations, 'blue', 0); + drawHistogram(ctx, histograms.predicted_probabilities, 'red', 0); +} + +function displayTree(treeData) { + // Use tree-display.js to render D3 tree + renderDecisionTree('tree-svg', treeData); +} + +function displayTokenSamples(samples) { + const container = document.getElementById('samples-container'); + container.innerHTML = ''; + + samples.forEach(sample => { + const div = document.createElement('div'); + div.className = `sample sample-${sample.category}`; + div.innerHTML = ` +

${sample.category} (confidence: ${sample.confidence.toFixed(3)})

+
${renderTokens(sample)}
+ `; + container.appendChild(div); + }); +} + +function renderTokens(sample) { + // Create dual-color token visualization + // Blue background = true activation, Red = predicted + return sample.tokens.map((token, i) => { + const trueVal = sample.true_activations[i]; + const predVal = sample.predicted_probabilities[i]; + + // Dual gradient or side-by-side bars + return ` + ${token} + `; + }).join(' '); +} + +// Initialize on load +init(); +``` + +**`js/tree-display.js`:** +```javascript +function renderDecisionTree(containerId, treeData) { + const container = document.getElementById(containerId); + container.innerHTML = ''; + + // Simple text-based tree for now + // Can upgrade to D3.js interactive tree later + + const textTree = buildTextTree(treeData.structure, treeData.feature_names); + const pre = document.createElement('pre'); + pre.textContent = textTree; + container.appendChild(pre); +} + +function buildTextTree(structure, featureNames, nodeIdx = 0, depth = 0) { + const indent = ' '.repeat(depth); + + if (structure.children_left[nodeIdx] === -1) { + // Leaf node + const value = structure.value[nodeIdx]; + const prediction = value[1] > value[0] ? 
'ACTIVE' : 'INACTIVE'; + return `${indent}→ ${prediction} (${value[0]}/${value[1]})\n`; + } + + // Internal node + const feature = structure.feature[nodeIdx]; + const threshold = structure.threshold[nodeIdx]; + const featureName = featureNames[feature]; + + let result = `${indent}${featureName} <= ${threshold}?\n`; + result += buildTextTree(structure, featureNames, structure.children_left[nodeIdx], depth + 1); + result += `${indent}else:\n`; + result += buildTextTree(structure, featureNames, structure.children_right[nodeIdx], depth + 1); + + return result; +} +``` + +--- + +## Implementation Checklist + +### Phase 1: Static Plot Improvements +- [ ] Update `plot_layer_metrics()`: scatter with jitter instead of bars +- [ ] Add LaTeX titles to all metrics plots (TP/FP/TN/FN formulas) +- [ ] Update AP vs prevalence: log scale, no edges, color by depth +- [ ] Add AP vs prevalence heatmap to `plot_tree_statistics()` +- [ ] Implement `greedy_sort()` helper function +- [ ] Create `plot_activations_unsorted()` with layer boundaries +- [ ] Create `plot_activations_sorted()` with diff plot +- [ ] Create `plot_covariance_unsorted()` with layer boundaries +- [ ] Create `plot_covariance_sorted()` +- [ ] Update all plot titles with LaTeX and newlines +- [ ] Test with existing `run.py` workflow + +### Phase 2: Data Export +- [ ] Create `spd/clustering/ci_dt/export.py` +- [ ] Implement `export_tree_json()` +- [ ] Implement `export_all_trees()` +- [ ] Implement `select_token_samples()` with stratified sampling +- [ ] Implement `serialize_tree()`, `compute_tree_metrics()`, etc. 
+- [ ] Add export call to `run.py` +- [ ] Test JSON output schema + +### Phase 3: Interactive Viewer +- [ ] Create `ci_dt_vis/` directory structure +- [ ] Implement `index.html` layout +- [ ] Implement `viewer.js` tree selection and loading +- [ ] Implement `tree-display.js` text rendering (D3 optional) +- [ ] Implement `token-display.js` dual-color visualization +- [ ] Implement histogram rendering (reuse or adapt sparklines.js) +- [ ] Add CSS styling +- [ ] Test end-to-end workflow + +### Phase 4: Documentation +- [ ] Update `run.py` docstrings +- [ ] Add README in `ci_dt_vis/` explaining viewer usage +- [ ] Document JSON schema +- [ ] Add example screenshots + +--- + +## Open Questions / Design Decisions + +1. **Token samples per tree:** 8 total (2 per category) seems reasonable. Too many? +2. **Histogram bins:** 50 bins for activations, 20 for probabilities? +3. **D3.js tree or text?** Start with text, add D3 if needed +4. **Component sorting:** Should we also show a version with components sorted by layer, then by similarity within layer? +5. **File size:** Each tree JSON might be 50-200KB. With 1000s of trees, total size could be 50-200MB. Acceptable? +6. **Continuous activations for tokens:** Currently we only have binary. Need to save continuous pre-threshold values? 
+ +--- + +## Success Metrics + +**Static Plots:** +- Plots are immediately interpretable without prior knowledge +- Titles explain abbreviations and formulas +- Layer boundaries visible in unsorted plots +- Sorting reveals structure (coactivation patterns) +- Diff plot clearly shows FP/FN errors + +**Interactive Viewer:** +- Can load and view any tree in <1 second +- Token examples clearly show where component activates +- Confusion matrix category examples are informative +- Tree structure is readable +- Histograms show activation distributions clearly diff --git a/spd/clustering/ci_dt/__init__.py b/spd/clustering/ci_dt/__init__.py new file mode 100644 index 000000000..3f8e91e98 --- /dev/null +++ b/spd/clustering/ci_dt/__init__.py @@ -0,0 +1,31 @@ +"""Causal importance decision tree package.""" + +from spd.clustering.ci_dt.config import CIDTConfig +from spd.clustering.ci_dt.core import ( + LayerModel, + build_xy, + concat_cols, + extract_prob_class_1, + get_estimator_for, + layer_metrics, + predict_all, + predict_k, + proba_for_layer, + train_trees, +) + +__all__ = [ + # Config + "CIDTConfig", + # Core + "LayerModel", + "concat_cols", + "build_xy", + "train_trees", + "extract_prob_class_1", + "predict_k", + "predict_all", + "layer_metrics", + "proba_for_layer", + "get_estimator_for", +] diff --git a/spd/clustering/ci_dt/attn.py b/spd/clustering/ci_dt/attn.py new file mode 100644 index 000000000..82f1f2736 --- /dev/null +++ b/spd/clustering/ci_dt/attn.py @@ -0,0 +1,426 @@ +# %% +"""Attention pattern visualization for CI decision tree analysis.""" + +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from jaxtyping import Float, Int +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm + +from spd.clustering.ci_dt.config import CIDTConfig +from spd.configs import Config +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from 
spd.models.component_model import ComponentModel, SPDRunInfo + +# magic autoreload +# %load_ext autoreload +# %autoreload 2 + +# %% +# ----------------------- configuration ----------------------- + +config = CIDTConfig( + wandb_run_path="wandb:goodfire/spd/runs/lxs77xye", + batch_size=16, + n_batches=4, + activation_threshold=0.01, + max_depth=8, + random_state=42, +) +device: str = "cuda" if torch.cuda.is_available() else "cpu" + +# %% +# ----------------------- load model ----------------------- + +spd_run: SPDRunInfo = SPDRunInfo.from_path(config.wandb_run_path) +model: ComponentModel = ComponentModel.from_pretrained(spd_run.checkpoint_path) +model.to(device) +cfg: Config = spd_run.config + +print(f"Loaded model from {config.wandb_run_path}") + +# %% +# ----------------------- load dataset ----------------------- + +# Create LM dataset and dataloader +assert isinstance(cfg.task_config, LMTaskConfig) +pretrained_model_name = cfg.pretrained_model_name +assert pretrained_model_name is not None + +dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + column_name=cfg.task_config.column_name, + is_tokenized=False, + streaming=False, + seed=0, +) +dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=config.batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=cfg.seed, + ddp_rank=0, + ddp_world_size=1, +) +print(f"Created LM dataset with {cfg.task_config.dataset_name}") + +# %% +# ----------------------- extract attention patterns ----------------------- + + +def extract_attention_patterns_multibatch( + model: ComponentModel, + device: torch.device | str, + dataloader: DataLoader[Any], + n_batches: int, +) -> dict[str, Float[Tensor, "total_samples n_heads seq_len seq_len"]]: + """Extract attention patterns over multiple batches. 
+ + Args: + model: ComponentModel containing the transformer + device: Device to run inference on + dataloader: DataLoader to get batches from + n_batches: Number of batches to process + + Returns: + Dictionary mapping layer names to attention patterns (on CPU) + Format: {layer_name: tensor of shape [total_samples, n_heads, seq_len, seq_len]} + """ + print(f"Extracting attention patterns for {n_batches} batches...") + all_attention_patterns: list[dict[str, Tensor]] = [] + + for _batch_idx in tqdm(range(n_batches), desc="Batches", total=n_batches): + batch_data = next(iter(dataloader)) + input_ids: Int[Tensor, "batch seq_len"] = batch_data["input_ids"].to(device) + + # Get attention patterns on GPU + with torch.no_grad(): + outputs = model.target_model(input_ids, output_attentions=True) + + # Extract attention patterns + # outputs.attentions is a tuple of tensors, one per layer + # Each tensor has shape [batch, n_heads, seq_len, seq_len] + batch_attention: dict[str, Tensor] = {} + if hasattr(outputs, "attentions") and outputs.attentions is not None: + for layer_idx, attn_weights in enumerate(outputs.attentions): + layer_name = f"layer_{layer_idx}" + # Move to CPU immediately + batch_attention[layer_name] = attn_weights.cpu() + + all_attention_patterns.append(batch_attention) + + # Concatenate all batches on CPU + print("Concatenating batches...") + layer_names: list[str] = list(all_attention_patterns[0].keys()) + attention_patterns_concat: dict[str, Tensor] = { + layer_name: torch.cat([batch[layer_name] for batch in all_attention_patterns], dim=0) + for layer_name in layer_names + } + + print(f"Extracted attention patterns for {len(layer_names)} layers") + return attention_patterns_concat + + +# Extract attention patterns +attention_patterns: dict[str, Float[Tensor, "total_samples n_heads seq_len seq_len"]] = ( + extract_attention_patterns_multibatch( + model=model, + device=device, + dataloader=dataloader, + n_batches=config.n_batches, + ) +) + +# Print shapes 
+print("\nAttention pattern shapes:") +for layer_name, attn in attention_patterns.items(): + print(f" {layer_name}: {attn.shape}") + +# %% +# ----------------------- compute attention statistics ----------------------- + + +def compute_attention_stats( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], +) -> dict[str, dict[str, Float[np.ndarray, "..."]]]: + """Compute statistics about attention patterns. + + Args: + attention_patterns: Dictionary of attention patterns per layer + + Returns: + Dictionary with statistics per layer including: + - mean_pattern: Average attention pattern [n_heads, seq_len, seq_len] + - entropy: Entropy of attention distributions [samples, n_heads, seq_len] + - max_attention: Maximum attention value [samples, n_heads, seq_len] + - sparsity: Fraction of attention < 0.01 [samples, n_heads] + """ + stats: dict[str, dict[str, np.ndarray]] = {} + + for layer_name, attn in attention_patterns.items(): + # Convert to numpy for stats + attn_np: np.ndarray = attn.numpy() + + # Mean pattern across all samples + mean_pattern: np.ndarray = attn_np.mean(axis=0) # [n_heads, seq_len, seq_len] + + # Entropy per query position: -sum(p * log(p)) + # Add small epsilon to avoid log(0) + epsilon = 1e-10 + attn_safe = attn_np + epsilon + entropy: np.ndarray = -(attn_safe * np.log(attn_safe)).sum( + axis=-1 + ) # [samples, n_heads, seq_len] + + # Max attention per query position + max_attention: np.ndarray = attn_np.max(axis=-1) # [samples, n_heads, seq_len] + + # Sparsity: fraction of attention weights < 0.01 + sparsity: np.ndarray = (attn_np < 0.01).mean(axis=(2, 3)) # [samples, n_heads] + + stats[layer_name] = { + "mean_pattern": mean_pattern, + "entropy": entropy, + "max_attention": max_attention, + "sparsity": sparsity, + } + + return stats + + +attention_stats = compute_attention_stats(attention_patterns) +print("Computed attention statistics") + +# %% +# ----------------------- plot: average attention patterns per layer 
----------------------- + + +def plot_average_attention_per_layer( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], + max_layers: int | None = None, +) -> None: + """Plot average attention pattern for each layer (averaged over heads and samples). + + Args: + attention_patterns: Dictionary of attention patterns per layer + max_layers: Maximum number of layers to plot (default: all) + """ + layer_names = sorted(attention_patterns.keys()) + if max_layers is not None: + layer_names = layer_names[:max_layers] + + n_layers = len(layer_names) + n_cols = min(4, n_layers) + n_rows = (n_layers + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows)) + if n_layers == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for idx, layer_name in enumerate(layer_names): + attn = attention_patterns[layer_name].numpy() + # Average over samples and heads + avg_attn = attn.mean(axis=(0, 1)) # [seq_len, seq_len] + + ax = axes[idx] + im = ax.imshow(avg_attn, cmap="viridis", aspect="auto") + ax.set_title(f"{layer_name}\n(avg over samples & heads)") + ax.set_xlabel("Key position") + ax.set_ylabel("Query position") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Hide unused subplots + for idx in range(n_layers, len(axes)): + axes[idx].axis("off") + + fig.tight_layout() + + +plot_average_attention_per_layer(attention_patterns, max_layers=None) +print("Average attention per layer plots generated.") + +# %% +# ----------------------- plot: per-head attention for selected layers ----------------------- + + +def plot_per_head_attention( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], + layer_names: list[str] | None = None, +) -> None: + """Plot attention pattern for each head in selected layers. 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + layer_names: List of layer names to plot (default: first layer) + """ + if layer_names is None: + layer_names = [sorted(attention_patterns.keys())[0]] + + for layer_name in layer_names: + if layer_name not in attention_patterns: + print(f"Warning: {layer_name} not found in attention patterns") + continue + + attn = attention_patterns[layer_name].numpy() + # Average over samples + avg_attn = attn.mean(axis=0) # [n_heads, seq_len, seq_len] + n_heads = avg_attn.shape[0] + + n_cols = min(4, n_heads) + n_rows = (n_heads + n_cols - 1) // n_cols + + fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows)) + if n_heads == 1: + axes = np.array([axes]) + axes = axes.flatten() + + for head_idx in range(n_heads): + ax = axes[head_idx] + im = ax.imshow(avg_attn[head_idx], cmap="viridis", aspect="auto") + ax.set_title(f"Head {head_idx}") + ax.set_xlabel("Key position") + ax.set_ylabel("Query position") + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + # Hide unused subplots + for idx in range(n_heads, len(axes)): + axes[idx].axis("off") + + fig.suptitle(f"{layer_name} - Per-Head Attention Patterns", fontsize=14, y=1.00) + fig.tight_layout() + + +# Plot first and last layers +all_layer_names = sorted(attention_patterns.keys()) +layers_to_plot = [all_layer_names[0], all_layer_names[-1]] +plot_per_head_attention(attention_patterns, layer_names=layers_to_plot) +print(f"Per-head attention plots generated for layers: {layers_to_plot}") + +# %% +# ----------------------- plot: attention entropy across layers ----------------------- + + +def plot_attention_entropy( + attention_stats: dict[str, dict[str, np.ndarray]], +) -> None: + """Plot attention entropy statistics across layers. 
+ + Args: + attention_stats: Dictionary of attention statistics per layer + """ + layer_names = sorted(attention_stats.keys()) + + # Collect mean entropy per layer (averaged over samples, heads, and query positions) + mean_entropies: list[float] = [] + for layer_name in layer_names: + entropy = attention_stats[layer_name]["entropy"] # [samples, n_heads, seq_len] + mean_entropies.append(float(entropy.mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(range(len(layer_names)), mean_entropies, marker="o") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Attention Entropy") + ax.set_title("Attention Entropy Across Layers\n(Higher = more uniform attention)") + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_entropy(attention_stats) +print("Attention entropy plot generated.") + +# %% +# ----------------------- plot: attention sparsity across layers ----------------------- + + +def plot_attention_sparsity( + attention_stats: dict[str, dict[str, np.ndarray]], +) -> None: + """Plot attention sparsity across layers. 
+ + Args: + attention_stats: Dictionary of attention statistics per layer + """ + layer_names = sorted(attention_stats.keys()) + + # Collect mean sparsity per layer (averaged over samples and heads) + mean_sparsities: list[float] = [] + for layer_name in layer_names: + sparsity = attention_stats[layer_name]["sparsity"] # [samples, n_heads] + mean_sparsities.append(float(sparsity.mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + ax.plot(range(len(layer_names)), mean_sparsities, marker="o", color="C1") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Sparsity (fraction < 0.01)") + ax.set_title("Attention Sparsity Across Layers\n(Higher = more sparse/focused attention)") + ax.set_xticks(range(len(layer_names))) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.set_ylim(0, 1) + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_sparsity(attention_stats) +print("Attention sparsity plot generated.") + +# %% +# ----------------------- plot: attention to first/last tokens ----------------------- + + +def plot_attention_to_special_positions( + attention_patterns: dict[str, Float[Tensor, "samples n_heads seq_len seq_len"]], +) -> None: + """Plot how much attention each position pays to first and last tokens. 
+ + Args: + attention_patterns: Dictionary of attention patterns per layer + """ + layer_names = sorted(attention_patterns.keys()) + + # Collect attention to first and last tokens + attn_to_first: list[float] = [] + attn_to_last: list[float] = [] + + for layer_name in layer_names: + attn = attention_patterns[layer_name].numpy() + # Average over samples and heads + avg_attn = attn.mean(axis=(0, 1)) # [seq_len, seq_len] + + # Average attention to first token (across all query positions) + attn_to_first.append(float(avg_attn[:, 0].mean())) + + # Average attention to last token (across all query positions) + attn_to_last.append(float(avg_attn[:, -1].mean())) + + # Plot + fig, ax = plt.subplots(figsize=(10, 5)) + x = range(len(layer_names)) + ax.plot(x, attn_to_first, marker="o", label="Attention to first token") + ax.plot(x, attn_to_last, marker="s", label="Attention to last token") + ax.set_xlabel("Layer") + ax.set_ylabel("Mean Attention Weight") + ax.set_title("Attention to Special Token Positions Across Layers") + ax.set_xticks(x) + ax.set_xticklabels(layer_names, rotation=45, ha="right") + ax.legend() + ax.grid(True, alpha=0.3) + fig.tight_layout() + + +plot_attention_to_special_positions(attention_patterns) +print("Attention to special positions plot generated.") + +# %% diff --git a/spd/clustering/ci_dt/config.py b/spd/clustering/ci_dt/config.py new file mode 100644 index 000000000..de0f95cee --- /dev/null +++ b/spd/clustering/ci_dt/config.py @@ -0,0 +1,16 @@ +"""Configuration for causal importance decision tree training.""" + +from dataclasses import dataclass + + +@dataclass +class CIDTConfig: + """Configuration for causal importance decision tree training.""" + + wandb_run_path: str # WandB run path for the SPD model + batch_size: int = 10 # Number of samples per batch for GPU inference + n_batches: int = 25 # Number of batches to process (total samples = batch_size * n_batches) + n_ctx: int = 64 # Context length (sequence length) for tokenization + 
activation_threshold: float = 0.01 # Threshold for boolean conversion + max_depth: int = 8 # Maximum depth for decision trees + random_state: int = 7 # Random state for reproducibility diff --git a/spd/clustering/ci_dt/core.py b/spd/clustering/ci_dt/core.py new file mode 100644 index 000000000..ec9ba585b --- /dev/null +++ b/spd/clustering/ci_dt/core.py @@ -0,0 +1,230 @@ +"""Core library functions for causal importance decision trees.""" + +import warnings +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal + +import numpy as np +from jaxtyping import Bool, Float +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + balanced_accuracy_score, +) +from sklearn.multioutput import MultiOutputClassifier +from sklearn.tree import DecisionTreeClassifier +from tqdm import tqdm + + +@dataclass +class LayerModel: + """Holds a trained per-layer model.""" + + layer_index: int + model: MultiOutputClassifier + feature_dim: int + target_dim: int + + +def concat_cols( + Xs: Sequence[Bool[np.ndarray, "n_samples n_features"]], +) -> Bool[np.ndarray, "n_samples n_concat"]: + """Column-concat a sequence or return empty (n,0).""" + n_samples: int = Xs[0].shape[0] if len(Xs) else 0 + return np.concatenate(Xs, axis=1) if len(Xs) else np.zeros((n_samples, 0), bool) + + +def build_xy( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], +) -> list[ + tuple[ + Bool[np.ndarray, "n_samples n_features"], + Bool[np.ndarray, "n_samples n_targets"], + ] +]: + """Return (X_k,Y_k) for k=1..L-1 with X_k=concat(layers[:k]).""" + XYs: list[tuple[np.ndarray, np.ndarray]] = [] + for k in range(1, len(layers)): + X_k: np.ndarray = concat_cols(layers[:k]) + Y_k: np.ndarray = layers[k] + XYs.append((X_k, Y_k)) + return XYs + + +def train_trees( + layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + *, + max_depth: int | None = None, + min_samples_leaf: int = 1, + random_state: int | None = 0, +) -> 
list[LayerModel]: + """Train one decision tree per component per target layer using previous layers as features.""" + XYs = build_xy(layers) + models: list[LayerModel] = [] + for k, (X_k, Y_k) in tqdm(enumerate(XYs, start=1), total=len(XYs), desc="Training trees"): + base = DecisionTreeClassifier( + max_depth=max_depth, + min_samples_leaf=min_samples_leaf, + random_state=random_state, + ) + model = MultiOutputClassifier(base) + model.fit(X_k.astype(np.uint8), Y_k.astype(np.uint8)) + models.append(LayerModel(k, model, int(X_k.shape[1]), int(Y_k.shape[1]))) + return models + + +def extract_prob_class_1( + proba_list: list[np.ndarray], + model: MultiOutputClassifier, +) -> np.ndarray: + """Extract P(y=1) for each output. + + Assumes constant components are filtered out, so both classes should always be present. + """ + result: list[np.ndarray] = [] + for i, p in enumerate(proba_list): + estimator = model.estimators_[i] # pyright: ignore[reportIndexIssue] + assert isinstance(estimator, DecisionTreeClassifier) + classes = estimator.classes_ + assert len(classes) == 2, f"Expected 2 classes but got {len(classes)} for output {i}" + # Extract P(y=1) from second column + result.append(p[:, 1]) + return np.stack(result, axis=1) + + +def predict_k( + models: Sequence[LayerModel], + prefix_layers: Sequence[Bool[np.ndarray, "n_samples n_components"]], + k: int, + *, + threshold: float = 0.5, +) -> Bool[np.ndarray, "n_samples n_components_k"]: + """Predict layer k activations from layers[:k].""" + lm: LayerModel = next(m for m in models if m.layer_index == k) + X: np.ndarray = concat_cols(prefix_layers) + # dbg_auto(X) + proba = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + # dbg_auto(proba) + # dbg_auto(proba[0]) + P: np.ndarray = extract_prob_class_1(proba, lm.model) + # dbg_auto(P) + Y_hat: np.ndarray = (threshold <= P).astype(bool) + # dbg_auto(Y_hat) + return Y_hat + + +def predict_all( + models: Sequence[LayerModel], + seed_layers: Sequence[Bool[np.ndarray, 
"n_samples n_components"]], + *, + thresholds: Sequence[float] | None = None, +) -> list[Bool[np.ndarray, "n_samples n_components"]]: + """Sequentially predict layers 1.. using layer 0 as seed.""" + out: list[np.ndarray] = [seed_layers[0].copy()] + ths: list[float] = list(thresholds) if thresholds is not None else [] + for i, lm in enumerate(sorted(models, key=lambda m: m.layer_index)): + thr: float = ths[i] if i < len(ths) else 0.5 + out.append(predict_k(models, out, lm.layer_index, threshold=thr)) + return out + + +MetricKey = Literal["ap", "acc", "bacc", "prev", "tpr", "tnr", "precision", "npv", "f1"] + + +def layer_metrics( + Y_true: Bool[np.ndarray, "n t"], + Y_prob: Float[np.ndarray, "n t"], + Y_pred: Bool[np.ndarray, "n t"], +) -> dict[MetricKey, np.ndarray]: + """Return per-target metrics: AP, acc, bacc, prevalence, TPR, TNR, precision, NPV, F1. + + Returns: + Dictionary with keys: + - ap: Average precision + - acc: Accuracy + - bacc: Balanced accuracy + - prev: Prevalence (fraction of positive samples) + - tpr: True Positive Rate (Recall/Sensitivity) + - tnr: True Negative Rate (Specificity) + - precision: Precision (when we predict active, how often are we right?) + - npv: Negative Predictive Value (when we predict inactive, how often are we right?) + - f1: F1 score + + Each value is an array of length T (number of target components). 
+ """ + T: int = Y_true.shape[1] + + ap: Float[np.ndarray, " t"] = np.full(T, np.nan) + acc: Float[np.ndarray, " t"] = np.full(T, np.nan) + bacc: Float[np.ndarray, " t"] = np.full(T, np.nan) + prev: Float[np.ndarray, " t"] = np.full(T, np.nan) + tpr: Float[np.ndarray, " t"] = np.full(T, np.nan) + tnr: Float[np.ndarray, " t"] = np.full(T, np.nan) + precision: Float[np.ndarray, " t"] = np.full(T, np.nan) + npv: Float[np.ndarray, " t"] = np.full(T, np.nan) + f1: Float[np.ndarray, " t"] = np.full(T, np.nan) + + for j in range(T): + y: np.ndarray = Y_true[:, j].astype(int) + p: np.ndarray = Y_prob[:, j] + yhat: np.ndarray = Y_pred[:, j].astype(int) + prev[j] = float(y.mean()) + + # Compute confusion matrix elements + tp: int = int(((y == 1) & (yhat == 1)).sum()) + tn: int = int(((y == 0) & (yhat == 0)).sum()) + fp: int = int(((y == 0) & (yhat == 1)).sum()) + fn: int = int(((y == 1) & (yhat == 0)).sum()) + + # TPR (Recall/Sensitivity) = TP / (TP + FN) + tpr[j] = tp / (tp + fn) + + # TNR (Specificity) = TN / (TN + FP) + tnr[j] = tn / (tn + fp) + + # Precision (PPV) = TP / (TP + FP) - when we predict active, how often are we right? + if (tp + fp) > 0: + precision[j] = tp / (tp + fp) + else: + precision[j] = np.nan + warnings.warn(f"Precision failed: {tp=}, {fp=}, {tp+fp=}", stacklevel=1) + + # Negative Predictive Value = TN / (TN + FN) - when we predict inactive, how often are we right? 
+ npv[j] = tn / (tn + fn) + + # F1 = 2 * (precision * recall) / (precision + recall) + f1[j] = 2 * (precision[j] * tpr[j]) / (precision[j] + tpr[j]) + + # Sklearn metrics + ap[j] = average_precision_score(y, p) + acc[j] = accuracy_score(y, yhat) + bacc[j] = balanced_accuracy_score(y, yhat) + + return { + "ap": ap, + "acc": acc, + "bacc": bacc, + "prev": prev, + "tpr": tpr, + "tnr": tnr, + "precision": precision, + "npv": npv, + "f1": f1, + } + + +def proba_for_layer(lm: LayerModel, X: np.ndarray) -> np.ndarray: + """Return P(y=1) per target column.""" + proba_list = lm.model.predict_proba(X.astype(np.uint8)) # type: ignore + return extract_prob_class_1(proba_list, lm.model) + + +def get_estimator_for( + models: list[LayerModel], layer_idx: int, target_idx: int +) -> DecisionTreeClassifier: + """Fetch the per-output estimator for a given layer and column.""" + lm = next(m for m in models if m.layer_index == layer_idx) + estimator = lm.model.estimators_[target_idx] # pyright: ignore[reportIndexIssue] + assert isinstance(estimator, DecisionTreeClassifier) + return estimator diff --git a/spd/clustering/ci_dt/js/cluster-detail.js b/spd/clustering/ci_dt/js/cluster-detail.js new file mode 100644 index 000000000..83abfb96e --- /dev/null +++ b/spd/clustering/ci_dt/js/cluster-detail.js @@ -0,0 +1,740 @@ +let clusterData = null; +let allClusters = null; +let textSamples = {}; +let activationsArray = null; +let activationsMap = {}; +let currentClusterHash = null; +let modelInfo = {}; +let explanations = {}; + +// Component-level data +let componentActivations = {}; // Map component labels to their activation data +let enabledComponents = new Set(); // Track which components are enabled +let combinationStrategy = 'max'; // How to combine component activations: 'max', 'sum', 'mean' + +async function init() { + // Get cluster hash from URL + const urlParams = new URLSearchParams(window.location.search); + currentClusterHash = urlParams.get('id'); + + if (!currentClusterHash) { + 
const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + loading.textContent = 'No cluster ID specified'; + return; + } + + await loadData(); +} + +async function loadData() { + const progressBar = NOTIF.pbar('Loading cluster data...'); + + try { + progressBar.progress(0.1); + + // Load data in parallel + let clusters, samples, activationsMapResponse, modelInfoResponse; + + const clustersPath = CONFIG.getDataPath('clusters'); + const textSamplesPath = CONFIG.getDataPath('textSamples'); + const activationsMapPath = CONFIG.getDataPath('activationsMap'); + const modelInfoPath = CONFIG.getDataPath('modelInfo'); + const explanationsPath = CONFIG.getDataPath('explanations'); + + try { + [clusters, samples, activationsMapResponse, modelInfoResponse] = await Promise.all([ + loadJSONL(clustersPath, 'cluster_hash').catch(e => { + throw new Error(`Failed to load ${clustersPath}: ${e.message}`); + }), + loadJSONL(textSamplesPath, 'text_hash').catch(e => { + throw new Error(`Failed to load ${textSamplesPath}: ${e.message}`); + }), + fetch(activationsMapPath).catch(e => { + throw new Error(`Failed to load ${activationsMapPath}: ${e.message}`); + }), + fetch(modelInfoPath).catch(e => { + throw new Error(`Failed to load ${modelInfoPath}: ${e.message}`); + }) + ]); + + // Load explanations (non-critical, don't fail if missing) + explanations = await loadJSONL(explanationsPath, 'cluster_id').catch(() => ({})); + } catch (error) { + progressBar.complete(); + NOTIF.error(error.message, error, null); + const loading = document.getElementById('loading'); + if (loading) { + loading.textContent = error.message; + } else { + console.error('loading element not found, cannot display error message'); + } + throw error; + } + + progressBar.progress(0.4); + + if (!activationsMapResponse.ok) { + const msg = `Failed to load ${activationsMapPath} (HTTP 
${activationsMapResponse.status})`; + NOTIF.error(msg, null, null); + throw new Error(msg); + } + if (!modelInfoResponse.ok) { + const msg = `Failed to load ${modelInfoPath} (HTTP ${modelInfoResponse.status})`; + NOTIF.error(msg, null, null); + throw new Error(msg); + } + + allClusters = clusters; + textSamples = samples; + + try { + activationsMap = await activationsMapResponse.json(); + } catch (error) { + const msg = `Failed to parse ${activationsMapPath} (invalid JSON)`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + try { + modelInfo = await modelInfoResponse.json(); + } catch (error) { + const msg = `Failed to parse ${modelInfoPath} (invalid JSON)`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + progressBar.progress(0.6); + + if (!allClusters[currentClusterHash]) { + const msg = 'Cluster not found'; + NOTIF.error(msg, null, null); + const loading = document.getElementById('loading'); + if (loading) { + loading.textContent = msg; + } else { + console.error('loading element not found, cannot display error message'); + } + progressBar.complete(); + return; + } + + clusterData = allClusters[currentClusterHash]; + + // Load activations (float16 compressed npz) + const activationsPath = CONFIG.getDataPath('activations'); + try { + activationsArray = await NDArray.load(activationsPath); + } catch (error) { + const msg = `Failed to load ${activationsPath}`; + NOTIF.error(msg, error, null); + throw new Error(msg); + } + + progressBar.progress(0.9); + + displayCluster(); + progressBar.complete(); + const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + loading.style.display = 'none'; + } catch (error) { + progressBar.complete(); + console.error('Load error:', error); + console.error('Stack:', error.stack); + } +} + +function displayCluster() { + // Update title + const clusterTitle = 
document.getElementById('clusterTitle'); + if (!clusterTitle) { + const msg = 'Fatal error: clusterTitle element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + clusterTitle.textContent = `Cluster ${currentClusterHash}`; + + // Display component count + const componentCount = document.getElementById('componentCount'); + if (!componentCount) { + const msg = 'Fatal error: componentCount element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + componentCount.textContent = clusterData.components.length; + + // Display explanation and setup copy handler + displayExplanation(); + setupCopyHandler(); + + // Initialize component data + initializeComponentData(); + + // Display model visualization + displayModelVisualization(); + + // Setup components table + setupComponentsTable(); + + // Setup hover highlighting between model view and components table + setupModelViewHighlighting(); + + // Display histogram plots + displayHistograms(); + + // Display token activation stats if available + if (clusterData.stats && clusterData.stats.token_activations) { + displayTokenActivations(); + } + + // Display samples + displaySamples(); +} + +function displayExplanation() { + const explanationSpan = document.getElementById('clusterExplanation'); + if (!explanationSpan) return; + + const explanationData = explanations[currentClusterHash]; + if (explanationData && explanationData.explanation) { + explanationSpan.textContent = explanationData.explanation; + explanationSpan.style.fontStyle = 'normal'; + explanationSpan.style.color = '#000'; + } else { + explanationSpan.textContent = 'No explanation'; + explanationSpan.style.fontStyle = 'italic'; + explanationSpan.style.color = '#666'; + } +} + +function setupCopyHandler() { + const copyBtn = document.getElementById('copyTemplateBtn'); + if (!copyBtn) return; + + copyBtn.addEventListener('click', async () => { + const template = JSON.stringify({ + 
cluster_id: currentClusterHash, + explanation: "" + }) + '\n'; + + try { + await navigator.clipboard.writeText(template); + NOTIF.success('Template copied to clipboard!'); + } catch (err) { + // Fallback for older browsers + const textArea = document.createElement('textarea'); + textArea.value = template; + textArea.style.position = 'fixed'; + textArea.style.left = '-999999px'; + document.body.appendChild(textArea); + textArea.select(); + try { + document.execCommand('copy'); + NOTIF.success('Template copied to clipboard!'); + } catch (e) { + NOTIF.error('Failed to copy template', e, null); + } + document.body.removeChild(textArea); + } + }); +} + +function initializeComponentData() { + // Load component activations if available + if (clusterData.component_activations) { + componentActivations = clusterData.component_activations; + } + + // Enable all components by default + enabledComponents.clear(); + clusterData.components.forEach(comp => { + enabledComponents.add(comp.label); + }); +} + +function displayModelVisualization() { + const modelViewDiv = document.getElementById('modelView'); + if (!modelViewDiv) { + const msg = 'Fatal error: modelView element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + renderModelView(modelViewDiv, currentClusterHash, allClusters, modelInfo, CONFIG.visualization.colormap, CONFIG.visualization.modelViewCellSize); +} + +function displayHistograms() { + const stats = clusterData.stats; + if (!stats) return; + + const histogramPlots = document.getElementById('histogramPlots'); + if (!histogramPlots) { + const msg = 'Fatal error: histogramPlots element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + histogramPlots.innerHTML = ''; + + // Color mapping for different histogram types + const statColors = { + 'all_activations': '#4169E1', + 'max_activation-max-16': '#DC143C', + 'max_activation-max-32': '#DC143C', + 'mean_activation-max-16': '#228B22', + 
'median_activation-max-16': '#FF8C00', + 'min_activation-max-16': '#9370DB', + 'max_activation_position': '#FF6347' + }; + + // Discover all histogram stats + const histogramStats = []; + for (const [key, value] of Object.entries(stats)) { + if (value && typeof value === 'object' && 'bin_counts' in value && 'bin_edges' in value) { + histogramStats.push(key); + } + } + + // Create a plot for each histogram stat + histogramStats.forEach(statKey => { + const histData = stats[statKey]; + const color = statColors[statKey] || '#808080'; + const label = statKey.replace(/-/g, ' ').replace(/_/g, ' ') + .split(' ') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' '); + + // Create container for this plot + const plotContainer = document.createElement('div'); + plotContainer.style.display = 'flex'; + plotContainer.style.flexDirection = 'column'; + plotContainer.style.alignItems = 'center'; + plotContainer.style.minWidth = '250px'; + + // Add label + const plotLabel = document.createElement('div'); + plotLabel.textContent = label; + plotLabel.style.fontSize = '12px'; + plotLabel.style.fontWeight = 'bold'; + plotLabel.style.marginBottom = '5px'; + plotLabel.style.textAlign = 'center'; + plotContainer.appendChild(plotLabel); + + // Create sparkline + const sparklineContainer = document.createElement('div'); + sparklineContainer.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = histData.bin_edges[0]; + const max = histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? 
[0, 1] : null; + + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth || 200, + height: CONFIG.visualization.sparklineHeight || 60, + color: color, + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin || 35} + }); + + sparklineContainer.innerHTML = svg; + + // Add tooltip with statistics + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + const totalCount = histData.bin_counts.reduce((a, b) => a + b, 0); + sparklineContainer.title = `${label} (n=${totalCount})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + plotContainer.appendChild(sparklineContainer); + histogramPlots.appendChild(plotContainer); + }); +} + +function displayTokenActivations() { + const tokenStats = clusterData.stats.token_activations; + + // Show the section + const tokenActivations = document.getElementById('tokenActivations'); + if (!tokenActivations) { + const msg = 'Fatal error: tokenActivations element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + tokenActivations.style.display = 'block'; + + // Setup top tokens table + if (tokenStats.top_tokens && tokenStats.top_tokens.length > 0) { + const tableData = tokenStats.top_tokens.map((item, idx) => ({ + rank: idx + 1, + token: item.token, + count: item.count, + percentage: ((item.count / tokenStats.total_activations) * 100) + })); + + const maxPercentage = tableData.length > 0 ? 
tableData[0].percentage : 0; + + const tableConfig = { + data: tableData, + columns: [ + { + key: 'rank', + label: '#', + type: 'number', + width: '40px', + align: 'right' + }, + { + key: 'token', + label: 'Token', + type: 'string', + width: '120px', + renderer: (value) => { + // Show token in a monospace box with visual formatting + const tokenDisplay = value.replace(/ /g, '·').replace(/\n/g, '↵'); + return `${tokenDisplay}`; + } + }, + { + key: 'percentage', + label: '%', + type: 'number', + width: '70px', + align: 'right', + renderer: (value) => { + const percentageValue = value; + const percentage = percentageValue.toFixed(1); + + // Color based on percentage (normalized by max percentage) + const normalizedPct = maxPercentage > 0 ? percentageValue / maxPercentage : 0; + const intensity = Math.floor((1 - normalizedPct) * 255); + const bgColor = `rgb(255, ${intensity}, ${intensity})`; + + const span = document.createElement('span'); + span.textContent = `${percentage}%`; + span.style.backgroundColor = bgColor; + span.style.padding = '2px 4px'; + span.style.borderRadius = '2px'; + + return span; + }, + infoFunction: () => { + return `Unique: ${tokenStats.total_unique_tokens.toLocaleString()} | Total: ${tokenStats.total_activations.toLocaleString()} | Entropy: ${tokenStats.entropy.toFixed(2)} | Conc: ${(tokenStats.concentration_ratio * 100).toFixed(1)}%`; + } + } + ], + pageSize: 10, + showFilters: false, + showInfo: true + }; + + new DataTable('#topTokensTable', tableConfig); + } +} + +function setupComponentsTable() { + const tableData = clusterData.components.map(comp => ({ + label: comp.label, + module: comp.module, + index: comp.index, + enabled: enabledComponents.has(comp.label) + })); + + const tableConfig = { + data: tableData, + columns: [ + { + key: 'enabled', + label: '✓', + type: 'boolean', + width: '40px', + align: 'center', + renderer: (value, row) => { + const checkbox = document.createElement('input'); + checkbox.type = 'checkbox'; + 
checkbox.checked = value; + checkbox.style.cursor = 'pointer'; + checkbox.addEventListener('change', (e) => { + onComponentToggle(row.label, e.target.checked); + }); + return checkbox; + }, + filterable: false + }, + { + key: 'module', + label: 'Module', + type: 'string', + width: '250px' + }, + { + key: 'index', + label: 'Index', + type: 'number', + width: '80px', + align: 'right' + } + ], + pageSize: CONFIG.clusterPage.pageSize, + showFilters: false + }; + + new DataTable('#componentsTable', tableConfig); +} + +function onComponentToggle(componentLabel, isEnabled) { + if (isEnabled) { + enabledComponents.add(componentLabel); + } else { + enabledComponents.delete(componentLabel); + } + + // Recompute and redisplay activations + recomputeDisplayedActivations(); +} + +async function recomputeDisplayedActivations() { + // If no components are enabled or component activations not available, use cluster-level + if (enabledComponents.size === 0 || !componentActivations || Object.keys(componentActivations).length === 0) { + // Just redisplay with cluster-level activations (default) + displaySamples(); + return; + } + + // If all components are enabled, use cluster-level activations (faster) + if (enabledComponents.size === clusterData.components.length) { + displaySamples(); + return; + } + + // Recompute activations based on enabled components + displaySamples(); +} + +function combineComponentActivations(componentActsList, strategy) { + // componentActsList: array of activation arrays [n_ctx] + // Returns: combined activation array [n_ctx] + + if (componentActsList.length === 0) { + return null; + } + + if (componentActsList.length === 1) { + return componentActsList[0]; + } + + const n_ctx = componentActsList[0].length; + const combined = new Array(n_ctx).fill(0); + + if (strategy === 'max') { + for (let i = 0; i < n_ctx; i++) { + let maxVal = componentActsList[0][i]; + for (let j = 1; j < componentActsList.length; j++) { + if (componentActsList[j][i] > maxVal) { + 
maxVal = componentActsList[j][i]; + } + } + combined[i] = maxVal; + } + } else if (strategy === 'sum') { + for (let i = 0; i < n_ctx; i++) { + let sum = 0; + for (let j = 0; j < componentActsList.length; j++) { + sum += componentActsList[j][i]; + } + combined[i] = sum; + } + } else if (strategy === 'mean') { + for (let i = 0; i < n_ctx; i++) { + let sum = 0; + for (let j = 0; j < componentActsList.length; j++) { + sum += componentActsList[j][i]; + } + combined[i] = sum / componentActsList.length; + } + } + + return combined; +} + +function setupModelViewHighlighting() { + // Get all model view cells + const modelViewCells = document.querySelectorAll('.modelview-module-cell'); + + // Get components table + const componentsTable = document.querySelector('#componentsTable'); + if (!componentsTable) return; + + modelViewCells.forEach(cell => { + cell.addEventListener('mouseenter', (e) => { + const moduleName = e.target.dataset.module; + if (!moduleName) return; + + // Find and highlight all rows in the components table that match this module + const tableRows = componentsTable.querySelectorAll('.tablejs-data-row'); + tableRows.forEach(row => { + const cells = row.querySelectorAll('td'); + if (cells.length > 1) { + const moduleCell = cells[1]; // Second column is module name (first is checkbox) + if (moduleCell && moduleCell.textContent === moduleName) { + row.style.backgroundColor = '#fff3cd'; // Light yellow highlight + } + } + }); + }); + + cell.addEventListener('mouseleave', () => { + // Remove highlighting from all rows + const tableRows = componentsTable.querySelectorAll('.tablejs-data-row'); + tableRows.forEach(row => { + row.style.backgroundColor = ''; + }); + }); + }); +} + +function displaySamples() { + const tbody = document.getElementById('samplesTableBody'); + if (!tbody) { + const msg = 'Fatal error: samplesTableBody element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + tbody.innerHTML = ''; + + // Get the main 
criterion samples (max_activation) + const criterionKey = Object.keys(clusterData.criterion_samples)[0]; + if (!criterionKey) { + tbody.innerHTML = 'No samples available'; + return; + } + + const sampleHashes = clusterData.criterion_samples[criterionKey]; + const samplesToShow = Math.min(CONFIG.clusterPage.maxSamplesPerCluster, sampleHashes.length); + + // Check if we need to use component-level activations + const useComponentActivations = componentActivations && + Object.keys(componentActivations).length > 0 && + enabledComponents.size < clusterData.components.length; + + for (let i = 0; i < samplesToShow; i++) { + const textHash = sampleHashes[i]; + const textSample = textSamples[textHash]; + + if (!textSample) { + console.warn(`Text sample not found for hash: ${textHash}`); + continue; + } + + let activationsData; + + if (useComponentActivations) { + // Compute combined activations from enabled components + const componentActsList = []; + + for (const comp of clusterData.components) { + if (enabledComponents.has(comp.label) && componentActivations[comp.label]) { + const compData = componentActivations[comp.label]; + // Find the activation for this text sample + const hashIdx = compData.activation_sample_hashes.indexOf(`${currentClusterHash}:${comp.label}:${textHash}`); + if (hashIdx !== -1) { + const activationIdx = compData.activation_indices[hashIdx]; + if (activationIdx !== undefined && activationsArray) { + const compActivations = activationsArray.get(activationIdx); + componentActsList.push(Array.from(compActivations.data)); + } + } + } + } + + if (componentActsList.length > 0) { + activationsData = combineComponentActivations(componentActsList, combinationStrategy); + } + } + + // Fall back to cluster-level activations if component activations not available + if (!activationsData) { + const fullHash = `${currentClusterHash}:${textHash}`; + const activationIdx = activationsMap[fullHash]; + + if (activationIdx !== undefined && activationsArray) { + const 
activations = activationsArray.get(activationIdx); + activationsData = Array.from(activations.data); + } + } + + let tokenViz; + if (activationsData) { + // Find max position + const maxPosition = activationsData.indexOf(Math.max(...activationsData)); + + // Use the proper token visualization with coloring and tooltips + tokenViz = createTokenVisualizationWithTooltip( + textSample.tokens, + activationsData, + maxPosition + ); + } else { + // Fallback to simple visualization if no activations + console.warn(`No activations found for sample ${i}`); + tokenViz = createSimpleTokenViz(textSample.tokens); + } + + const tr = document.createElement('tr'); + tr.innerHTML = ` + ${i + 1} + + `; + + // Add token visualization to last cell + tr.lastElementChild.appendChild(tokenViz); + + tbody.appendChild(tr); + } + + if (sampleHashes.length > CONFIG.clusterPage.maxSamplesPerCluster) { + const tr = document.createElement('tr'); + tr.innerHTML = ` + ... and ${sampleHashes.length - CONFIG.clusterPage.maxSamplesPerCluster} more samples + `; + tbody.appendChild(tr); + } +} + +function createSimpleTokenViz(tokens) { + const container = document.createElement('div'); + container.className = 'token-container'; + container.textContent = tokens.join(' '); + return container; +} + +// Initialize config and load data on page load +(async () => { + await initConfig(); + init(); +})(); \ No newline at end of file diff --git a/spd/clustering/ci_dt/js/cluster-selection.js b/spd/clustering/ci_dt/js/cluster-selection.js new file mode 100644 index 000000000..6a5ce1142 --- /dev/null +++ b/spd/clustering/ci_dt/js/cluster-selection.js @@ -0,0 +1,841 @@ +let clusterData = {}; +let modelInfo = {}; +let dataTable = null; +let explanations = {}; + +// Alpine.js data component for model info +const modelInfoData = { + data: {}, + hasData: false, + + async loadData() { + try { + const response = await fetch(CONFIG.getDataPath('modelInfo')); + this.data = await response.json(); + this.hasData = 
Object.keys(this.data).length > 0; + + // Also populate global modelInfo for DataTable renderers + modelInfo = this.data; + + console.log('Model info loaded:', this.hasData, Object.keys(this.data)); + } catch (error) { + console.error('Failed to load model info:', error); + this.hasData = false; + } + }, + + formatParameters(totalParams) { + if (!totalParams) return '-'; + if (totalParams >= 1000000) return (totalParams / 1000000).toFixed(1) + 'M'; + if (totalParams >= 1000) return (totalParams / 1000).toFixed(1) + 'K'; + return totalParams.toString(); + }, + + formatWandBLink(path) { + if (!path) return '-'; + + // Remove "wandb:" prefix if present + const cleanPath = path.replace(/^wandb:/, ''); + + // Convert to WandB URL + const url = `https://wandb.ai/${cleanPath}`; + + // Show shortened path in link text + const displayText = cleanPath.length > 60 + ? cleanPath.substring(0, 57) + '...' + : cleanPath; + + return `${displayText}`; + } +}; + +// Custom column renderers +const columnRenderers = { + modelView: function(value, row, col) { + const container = document.createElement('div'); + container.className = 'modelview-cell'; + + renderModelView(container, row.clusterHash, clusterData, modelInfo, CONFIG.visualization.colormap, CONFIG.visualization.modelViewCellSizeTable); + + return container; + }, + + modulesSummary: function(value, row, col) { + const modules = row.modules; + const container = document.createElement('div'); + container.className = 'module-summary'; + + if (modules.length === 1) { + const parts = modules[0].split('.'); + container.textContent = parts.length > 2 ? parts.slice(-2).join('.') : modules[0]; + } else if (modules.length <= 3) { + container.textContent = modules.map(m => { + const parts = m.split('.'); + return parts.length > 2 ? 
parts.slice(-2).join('.') : m; + }).join(', '); + } else { + container.textContent = `${modules.length} modules`; + } + + container.title = modules.join('\n'); + return container; + }, + + activationHistogram: function(value, row, col) { + const histData = row.stats.all_activations; + if (!histData) { + return 'No data'; + } + + const container = document.createElement('div'); + container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = row.stats.min_activation; + const max = row.stats.max_activation; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: '#4169E1', + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const mean = row.stats.mean_activation; + const median = calculateHistogramMedian(histData); + const n = row.stats.n_tokens; + + container.title = `All Activations Histogram (n=${n})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }, + + maxActivationDistribution: function(value, row, col) { + const histData = row.stats['max_activation-max-16']; + if (!histData) { + return 'No data'; + } + + const container = document.createElement('div'); + container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + const min = histData.bin_edges[0]; + const max = 
histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: '#DC143C', + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const n = row.stats.n_samples; + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + + container.title = `Max Activation Distribution (n=${n} samples)\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }, + + clusterLink: function(value, row, col) { + return `View →`; + }, + + explanation: function(value, row, col) { + if (!value) { + return ''; + } + // Truncate long explanations + const maxLength = 60; + if (value.length > maxLength) { + const truncated = value.substring(0, maxLength) + '...'; + const span = document.createElement('span'); + span.textContent = truncated; + span.title = value; // Show full text on hover + return span; + } + return value; + }, + + tokenEntropy: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if (!tokenStats) { + return 'N/A'; + } + return tokenStats.entropy.toFixed(2); + }, + + tokenConcentration: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if (!tokenStats) { + return 'N/A'; + } + return (tokenStats.concentration_ratio * 100).toFixed(1) + '%'; + }, + + topToken: function(value, row, col) { + const tokenStats = row.stats.token_activations; + if 
(!tokenStats || !tokenStats.top_tokens || tokenStats.top_tokens.length === 0) { + return 'N/A'; + } + + const container = document.createElement('div'); + container.style.fontFamily = 'monospace'; + container.style.fontSize = '11px'; + container.style.lineHeight = '1.4'; + + const topN = Math.min(5, tokenStats.top_tokens.length); + const maxPercentage = tokenStats.top_tokens.length > 0 + ? ((tokenStats.top_tokens[0].count / tokenStats.total_activations) * 100) + : 0; + + for (let i = 0; i < topN; i++) { + const token = tokenStats.top_tokens[i]; + const tokenDisplay = token.token.replace(/ /g, '·').replace(/\n/g, '↵'); + const percentageValue = ((token.count / tokenStats.total_activations) * 100); + const percentage = percentageValue.toFixed(1); + + // Color based on percentage (normalized by max percentage) + const normalizedPct = maxPercentage > 0 ? percentageValue / maxPercentage : 0; + const intensity = Math.floor((1 - normalizedPct) * 255); + const bgColor = `rgb(255, ${intensity}, ${intensity})`; + + const line = document.createElement('div'); + line.style.display = 'flex'; + line.style.justifyContent = 'space-between'; + line.style.gap = '8px'; + + const tokenSpan = document.createElement('span'); + tokenSpan.innerHTML = `${tokenDisplay}`; + tokenSpan.style.textAlign = 'left'; + + const pctSpan = document.createElement('span'); + pctSpan.textContent = `${percentage}%`; + pctSpan.style.textAlign = 'right'; + pctSpan.style.backgroundColor = bgColor; + pctSpan.style.padding = '2px 4px'; + pctSpan.style.borderRadius = '2px'; + + line.appendChild(tokenSpan); + line.appendChild(pctSpan); + container.appendChild(line); + } + + return container; + }, + + // Generic histogram renderer for any BinnedData stat + genericHistogram: function(statKey, color, title) { + return function(value, row, col) { + const histData = row.stats[statKey]; + if (!histData || !histData.bin_counts) { + return 'No data'; + } + + const container = document.createElement('div'); + 
container.className = 'sparkline-cell'; + + // Calculate bin centers for x-axis + const binCenters = calculateBinCenters(histData.bin_edges); + + // Calculate statistics of underlying data + const min = histData.bin_edges[0]; + const max = histData.bin_edges[histData.bin_edges.length - 1]; + + // Set x-axis limits to [0, 1] if data is in that range + const xlims = (min >= 0 && max <= 1) ? [0, 1] : null; + + // Pass bin centers as x-values and counts as y-values + const svg = sparkbars(binCenters, histData.bin_counts, { + width: CONFIG.visualization.sparklineWidth, + height: CONFIG.visualization.sparklineHeight, + color: color, + shading: true, + lineWidth: 0, + markers: '', + margin: 2, + xlims: xlims, + ylims: [0, null], + logScale: true, + xAxis: {line: true, ticks: true, label_margin: 10}, + yAxis: {line: true, ticks: true, label_margin: CONFIG.visualization.sparklineYAxisMargin} + }); + + container.innerHTML = svg; + + const mean = calculateHistogramMean(histData); + const median = calculateHistogramMedian(histData); + const totalCount = histData.bin_counts.reduce((a, b) => a + b, 0); + + container.title = `${title} (n=${totalCount})\n\nMin: ${min.toFixed(4)}\nMax: ${max.toFixed(4)}\nMean: ${mean.toFixed(4)}\nMedian: ${median.toFixed(4)}`; + + return container; + }; + } +}; + +// ============================================================================ +// Helper Functions for Filtering and Sorting +// ============================================================================ + +/** + * Create a filter function for module arrays that supports wildcards, multiple patterns, and negation + * @param {string} filterValue - The filter pattern (supports * wildcards, , for OR, & for AND, @ for all-match, ! 
for negation) + * @returns {Function|null} Filter function or null if invalid + */ +function createModuleFilter(filterValue) { + if (!filterValue || !filterValue.trim()) return null; + + // Split by comma for OR groups + const orGroups = filterValue.split(',').map(g => g.trim()).filter(g => g); + + // Parse each OR group (which may contain & for AND) + const parsedOrGroups = orGroups.map(group => { + // Split by & for AND conditions within this OR group + const andConditions = group.split('&').map(c => c.trim()).filter(c => c); + + return andConditions.map(condition => { + let mode = 'some'; // default: at least one module matches + let negate = false; + let pattern = condition.toLowerCase(); + + // Check for @ prefix (all modules must match) + if (pattern.startsWith('@')) { + mode = 'every'; + pattern = pattern.substring(1); + } + // Check for ! prefix (no modules can match) + else if (pattern.startsWith('!')) { + negate = true; + pattern = pattern.substring(1); + } + + const regex = pattern.includes('*') + ? new RegExp('^' + pattern.replace(/\*/g, '.*') + '$') + : null; + + return { mode, negate, pattern, regex }; + }); + }); + + return (cellValue) => { + // cellValue is the modules array + if (!Array.isArray(cellValue)) return false; + + // OR logic across groups + return parsedOrGroups.some(andGroup => { + // AND logic within group + return andGroup.every(condition => { + const matchFn = (module) => { + const moduleLower = module.toLowerCase(); + return condition.regex + ? condition.regex.test(moduleLower) + : moduleLower.includes(condition.pattern); + }; + + if (condition.mode === 'every') { + // ALL modules must match + const result = cellValue.every(matchFn); + return condition.negate ? !result : result; + } else { + // At least ONE module must match (or none if negated) + const result = cellValue.some(matchFn); + return condition.negate ? 
!result : result; + } + }); + }); + }; +} + +/** + * Sort function for module arrays + * Primary: number of modules (ascending) + * Secondary: alphabetically by first module name + * @param {Array} modules - Array of module names + * @returns {string} Sortable string representation + */ +function sortModules(modules) { + if (!Array.isArray(modules) || modules.length === 0) return ''; + + // Pad module count for proper numeric sorting, then add first module name + const count = modules.length.toString().padStart(5, '0'); + const firstName = modules[0].toLowerCase(); + return `${count}_${firstName}`; +} + +/** + * Parse extended histogram filter syntax (e.g., "mean>0.5", "max<10", "mean>0.5, max<10") + * @param {string} filterValue - The filter string (can be comma-separated for multiple conditions) + * @returns {Array|null} Array of parsed filters [{ statType, operator, value }] or null if plain numeric + */ +function parseHistogramFilter(filterValue) { + const trimmed = filterValue.trim(); + if (!trimmed) return null; + + // Split by comma to support multiple conditions + const conditions = trimmed.split(',').map(c => c.trim()); + const parsedConditions = []; + + for (const condition of conditions) { + // Match pattern: statType operator value (e.g., "mean>0.5", "median<=0.2") + const match = condition.match(/^(mean|median|max|min|range|sum)\s*(==|!=|>=|<=|>|<)\s*(-?\d+\.?\d*)$/i); + + if (match) { + parsedConditions.push({ + statType: match[1].toLowerCase(), + operator: match[2], + value: parseFloat(match[3]) + }); + } else { + // If any condition doesn't match, return null to use default filter + return null; + } + } + + // Return array of conditions, or null if none were found + return parsedConditions.length > 0 ? 
parsedConditions : null; +} + +/** + * Create a filter function for histogram columns with extended syntax + * Supports multiple comma-separated conditions (AND logic) + * @param {string} statKey - The statistics key + * @param {string} filterValue - The filter string (e.g., "mean>0.5, max<10") + * @returns {Function|null} Filter function or null to use default + */ +function createHistogramFilter(statKey, filterValue) { + const parsedConditions = parseHistogramFilter(filterValue); + + if (!parsedConditions) { + // Return null to let default numeric filter handle it + // Default will filter on the sort value (mean by default) + return null; + } + + return (cellValue, row) => { + // All conditions must be satisfied (AND logic) + for (const condition of parsedConditions) { + const { statType, operator, value } = condition; + const histData = row.stats[statKey]; + + if (!histData || !histData.bin_counts || !histData.bin_edges) return false; + + // Calculate the requested statistic + let statValue; + switch (statType) { + case 'mean': + // For all_activations, use precomputed mean + if (statKey === 'all_activations' && row.stats.mean_activation !== undefined) { + statValue = row.stats.mean_activation; + } else { + statValue = calculateHistogramMean(histData); + } + break; + case 'median': + statValue = calculateHistogramMedian(histData); + break; + case 'max': + statValue = histData.bin_edges[histData.bin_edges.length - 1]; + break; + case 'min': + statValue = histData.bin_edges[0]; + break; + case 'range': + statValue = histData.bin_edges[histData.bin_edges.length - 1] - histData.bin_edges[0]; + break; + case 'sum': + statValue = histData.bin_counts.reduce((a, b) => a + b, 0); + break; + default: + return false; + } + + if (statValue === null || statValue === undefined) return false; + + let conditionMet = false; + switch (operator) { + case '==': conditionMet = Math.abs(statValue - value) < 0.0001; break; + case '!=': conditionMet = Math.abs(statValue - value) >= 
0.0001; break; + case '>': conditionMet = statValue > value; break; + case '<': conditionMet = statValue < value; break; + case '>=': conditionMet = statValue >= value; break; + case '<=': conditionMet = statValue <= value; break; + default: conditionMet = false; + } + + // If any condition fails, return false + if (!conditionMet) return false; + } + + // All conditions passed + return true; + }; +} + +/** + * Get the top token string for sorting + * @param {object} value - Cell value (stats object) + * @param {object} row - The data row + * @returns {string} The top token string for sorting + */ +function sortTopToken(value, row) { + const tokenStats = row.stats.token_activations; + if (!tokenStats || !tokenStats.top_tokens || tokenStats.top_tokens.length === 0) { + return ''; + } + return tokenStats.top_tokens[0].token.toLowerCase(); +} + +/** + * Create a filter function for top tokens + * @param {string} filterValue - The filter string + * @returns {Function|null} Filter function or null if invalid + */ +function createTopTokenFilter(filterValue) { + if (!filterValue || !filterValue.trim()) return null; + + const pattern = filterValue.toLowerCase().trim(); + + return (cellValue, row) => { + const tokenStats = row.stats.token_activations; + if (!tokenStats || !tokenStats.top_tokens) return false; + + // Search in top 10 tokens + const topN = Math.min(10, tokenStats.top_tokens.length); + for (let i = 0; i < topN; i++) { + const token = tokenStats.top_tokens[i].token.toLowerCase(); + if (token.includes(pattern)) { + return true; + } + } + return false; + }; +} + +/** + * Create a filter function for numeric comparisons with operators + * @param {string} filterValue - The filter string (e.g., ">2.5", "<=0.8") + * @param {Function} valueExtractor - Function to extract numeric value from cellValue + * @returns {Function|null} Filter function or null if invalid + */ +function createNumericFilter(filterValue, valueExtractor) { + if (!filterValue || !filterValue.trim()) 
return null; + + const trimmed = filterValue.trim(); + + // Match pattern: operator value (e.g., ">2.5", "<=0.8") + const match = trimmed.match(/^(==|!=|>=|<=|>|<)\s*(-?\d+\.?\d*)$/); + + if (!match) { + // Try plain number (defaults to ==) + const plainNum = parseFloat(trimmed); + if (!isNaN(plainNum)) { + return (cellValue, row) => { + const value = valueExtractor(cellValue); + if (value === null || value === undefined) return false; + return Math.abs(value - plainNum) < 0.0001; + }; + } + return null; + } + + const operator = match[1]; + const targetValue = parseFloat(match[2]); + + return (cellValue, row) => { + const value = valueExtractor(cellValue); + if (value === null || value === undefined) return false; + + switch (operator) { + case '==': return Math.abs(value - targetValue) < 0.0001; + case '!=': return Math.abs(value - targetValue) >= 0.0001; + case '>': return value > targetValue; + case '<': return value < targetValue; + case '>=': return value >= targetValue; + case '<=': return value <= targetValue; + default: return false; + } + }; +} + +function processClusterData() { + const tableData = []; + + for (const [clusterHash, cluster] of Object.entries(clusterData)) { + const modules = new Set(); + cluster.components.forEach(comp => { + modules.add(comp.module); + }); + + const stats = cluster.stats; + + // Extract cluster ID from hash (format: "runid-iteration-clusteridx") + const parts = clusterHash.split('-'); + const clusterId = parseInt(parts[parts.length - 1]); + + // Get explanation for this cluster + const explanationData = explanations[clusterHash]; + const explanation = explanationData ? 
explanationData.explanation : null; + + tableData.push({ + id: clusterId, + clusterHash: clusterHash, + componentCount: cluster.components.length, + modules: Array.from(modules), + stats: stats, + explanation: explanation + }); + } + + return tableData; +} + +async function loadData() { + // Load cluster data (model info is handled by Alpine.js) + const clusters = await loadJSONL(CONFIG.getDataPath('clusters'), 'cluster_hash'); + + clusterData = clusters; + + // Load explanations (non-critical, don't fail if missing) + explanations = await loadJSONL(CONFIG.getDataPath('explanations'), 'cluster_id').catch(() => ({})); + + const tableData = processClusterData(); + + // Discover histogram stats from first cluster + const firstCluster = Object.values(clusterData)[0]; + const histogramStats = []; + if (firstCluster && firstCluster.stats) { + for (const [key, value] of Object.entries(firstCluster.stats)) { + if (value && typeof value === 'object' && 'bin_counts' in value && 'bin_edges' in value) { + histogramStats.push(key); + } + } + } + + // Base columns + const columns = [ + { + key: 'id', + label: 'ID', + type: 'number', + width: '10px', + align: 'center' + }, + { + key: 'componentCount', + label: 'Comps', + type: 'number', + width: '10px', + align: 'right' + }, + { + key: 'modules', + label: 'Model View', + type: 'string', + width: '21px', + align: 'center', + renderer: columnRenderers.modelView, + sortFunction: (modules) => sortModules(modules), + filterFunction: (filterValue) => createModuleFilter(filterValue), + filterTooltip: 'Filter by module. Separate with , (OR) or & (AND). Use * for wildcards. Prefix @ for all-match, ! to exclude. 
Examples: *mlp*,*attn* (OR), *mlp*&*attn* (AND), @*proj* (all), !*o_proj* (exclude)' + }, + { + key: 'modules', + label: 'Modules', + type: 'string', + width: '10px', + renderer: columnRenderers.modulesSummary, + sortFunction: (modules) => sortModules(modules), + filterFunction: (filterValue) => createModuleFilter(filterValue), + filterTooltip: 'Filter by module. Separate with , (OR) or & (AND). Use * for wildcards. Prefix @ for all-match, ! to exclude. Examples: *mlp*,*attn* (OR), *mlp*&*attn* (AND), @*proj* (all), !*o_proj* (exclude)' + } + ]; + + // Add histogram columns dynamically + const statColors = { + 'all_activations': '#4169E1', + 'max_activation-max-16': '#DC143C', + 'max_activation-max-32': '#DC143C', + 'mean_activation-max-16': '#228B22', + 'median_activation-max-16': '#FF8C00', + 'min_activation-max-16': '#9370DB', + 'max_activation_position': '#FF6347' + }; + + histogramStats.forEach(statKey => { + const color = statColors[statKey] || '#808080'; + const label = statKey.replace(/-/g, ' ').replace(/_/g, ' ') + .split(' ') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' '); + + columns.push({ + id: 'histogram_' + statKey, + key: 'stats', + label: label, + type: 'number', + width: '200px', + align: 'center', + renderer: columnRenderers.genericHistogram(statKey, color, label), + sortFunction: (value, row) => { + const histData = row.stats[statKey]; + if (!histData || !histData.bin_counts || !histData.bin_edges) return -Infinity; + // For all_activations, use precomputed mean + if (statKey === 'all_activations' && row.stats.mean_activation !== undefined) { + return row.stats.mean_activation; + } + // Otherwise calculate mean from histogram + return calculateHistogramMean(histData); + }, + filterFunction: (filterValue) => createHistogramFilter(statKey, filterValue), + filterTooltip: 'Filter by statistics. Use: mean>0.5, median<0.2, max>=1.0, min>-0.1, range<5, sum>100. 
Combine with commas (e.g., mean>0.5, max<10)' + }); + }); + + // Token activation columns + columns.push({ + id: 'top_tokens', + key: 'stats', + label: 'Top Tokens', + type: 'string', + width: '150px', + align: 'left', + renderer: columnRenderers.topToken, + sortFunction: (value, row) => sortTopToken(value, row), + filterFunction: (filterValue) => createTopTokenFilter(filterValue), + filterTooltip: 'Search for tokens (case-insensitive substring match)' + }); + + columns.push({ + id: 'token_entropy', + key: 'stats', + label: 'Token Entropy', + type: 'number', + width: '60px', + align: 'right', + renderer: columnRenderers.tokenEntropy, + sortFunction: (value, row) => { + const tokenStats = row.stats.token_activations; + return tokenStats ? tokenStats.entropy : -Infinity; + }, + filterFunction: (filterValue) => createNumericFilter(filterValue, (stats) => { + const tokenStats = stats?.token_activations; + return tokenStats ? tokenStats.entropy : null; + }), + filterTooltip: 'Filter by entropy. Use operators: >, <, >=, <=, ==, != (e.g., >2.5)' + }); + + columns.push({ + id: 'token_concentration', + key: 'stats', + label: 'Token Conc.', + type: 'number', + width: '60px', + align: 'right', + renderer: columnRenderers.tokenConcentration, + sortFunction: (value, row) => { + const tokenStats = row.stats.token_activations; + return tokenStats ? tokenStats.concentration_ratio : -Infinity; + }, + filterFunction: (filterValue) => createNumericFilter(filterValue, (stats) => { + const tokenStats = stats?.token_activations; + return tokenStats ? tokenStats.concentration_ratio : null; + }), + filterTooltip: 'Filter by concentration (0-1). 
Use operators: >, <, >=, <=, ==, != (e.g., >0.5)' + }); + + // Explanation column + columns.push({ + key: 'explanation', + label: 'Explanation', + type: 'string', + width: '200px', + align: 'left', + renderer: columnRenderers.explanation, + filterTooltip: 'Filter by explanation text (case-insensitive substring match)' + }); + + // Actions column + columns.push({ + key: 'id', + label: 'Actions', + type: 'string', + width: '20px', + align: 'center', + renderer: columnRenderers.clusterLink, + filterable: false + }); + + const tableConfig = { + data: tableData, + columns: columns, + pageSize: CONFIG.indexPage.pageSize, + pageSizeOptions: CONFIG.indexPage.pageSizeOptions, + showFilters: CONFIG.indexPage.showFilters + }; + + dataTable = new DataTable('#clusterTableContainer', tableConfig); + + const loading = document.getElementById('loading'); + if (!loading) { + const msg = 'Fatal error: loading element not found in HTML'; + NOTIF.error(msg, null, null); + console.error(msg); + return; + } + loading.style.display = 'none'; +} + +document.addEventListener('DOMContentLoaded', async () => { + await initConfig(); + + // Check if Alpine.js loaded + if (typeof Alpine === 'undefined') { + const msg = 'Fatal error: Alpine.js failed to load. 
Check your internet connection or CDN.'; + NOTIF.error(msg, null, null); + console.error(msg); + } else { + // Manually trigger Alpine component's loadData now that CONFIG is ready + const modelInfoEl = document.getElementById('modelInfo'); + if (modelInfoEl && Alpine.$data(modelInfoEl)) { + Alpine.$data(modelInfoEl).loadData(); + } + } + + // Load cluster data and render table + loadData(); +}); diff --git a/spd/clustering/ci_dt/js/model-visualization.js b/spd/clustering/ci_dt/js/model-visualization.js new file mode 100644 index 000000000..f42e55922 --- /dev/null +++ b/spd/clustering/ci_dt/js/model-visualization.js @@ -0,0 +1,222 @@ +// Self-contained utilities for model visualization +// No global variables, all functions take necessary data as parameters + +function getClusterModuleStats(clusterId, clusterData) { + if (!clusterData || !clusterData[clusterId]) return {}; + + const cluster = clusterData[clusterId]; + const moduleStats = {}; + + // Count components per module for this specific cluster + cluster.components.forEach(comp => { + const module = comp.module; + if (!moduleStats[module]) { + moduleStats[module] = { + componentCount: 0, + components: [] + }; + } + moduleStats[module].componentCount++; + moduleStats[module].components.push(comp); + }); + + return moduleStats; +} + +function getModuleOrder(moduleName) { + if (moduleName.includes('q_proj')) return 0; + if (moduleName.includes('k_proj')) return 1; + if (moduleName.includes('v_proj')) return 2; + if (moduleName.includes('o_proj')) return 3; + if (moduleName.includes('gate_proj')) return 10; + if (moduleName.includes('up_proj')) return 11; + if (moduleName.includes('down_proj')) return 12; + return 999; +} + +function renderModelArchitecture(clusterId, clusterData, modelInfo, colormap = 'blues') { + if (!modelInfo || !modelInfo.module_list) { + throw new Error('Model info not loaded'); + } + + const moduleStats = clusterData && clusterData[clusterId] ? 
getClusterModuleStats(clusterId, clusterData) : {}; + const maxComponents = Math.max(...Object.values(moduleStats).map(s => s.componentCount), 1); + + // Group ALL modules from model_info by layer and type + const layerGroups = {}; + + modelInfo.module_list.forEach(moduleName => { + const parts = moduleName.split('.'); + let layerNum = -1; + let moduleType = 'other'; + + for (let i = 0; i < parts.length; i++) { + if (parts[i] === 'layers' && i + 1 < parts.length) { + layerNum = parseInt(parts[i + 1]); + } + } + + if (moduleName.includes('self_attn')) { + moduleType = 'attention'; + } else if (moduleName.includes('mlp')) { + moduleType = 'mlp'; + } + + if (!layerGroups[layerNum]) { + layerGroups[layerNum] = { attention: [], mlp: [], other: [] }; + } + + const count = moduleStats[moduleName] ? moduleStats[moduleName].componentCount : 0; + const components = moduleStats[moduleName] ? moduleStats[moduleName].components : []; + + layerGroups[layerNum][moduleType].push({ + name: moduleName, + count: count, + components: components + }); + }); + + // Sort modules within each group by desired order + Object.values(layerGroups).forEach(layer => { + layer.attention.sort((a, b) => getModuleOrder(a.name) - getModuleOrder(b.name)); + layer.mlp.sort((a, b) => getModuleOrder(a.name) - getModuleOrder(b.name)); + }); + + const sortedLayers = Object.keys(layerGroups).sort((a, b) => a - b); + const cellSize = 12; + + const moduleElements = []; + + sortedLayers.forEach(layerNum => { + const layer = layerGroups[layerNum]; + const layerElements = []; + + // Attention row (above MLP) + if (layer.attention.length > 0) { + const attentionRow = layer.attention.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: attentionRow }); + } + + // MLP row (below attention) + if 
(layer.mlp.length > 0) { + const mlpRow = layer.mlp.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: mlpRow }); + } + + // Other modules + if (layer.other.length > 0) { + const otherRow = layer.other.map(module => ({ + type: 'cell', + module: module.name, + count: module.count, + components: module.components.map(c => c.index).join(','), + color: getColorForValue(module.count, maxComponents, colormap), + size: cellSize + })); + layerElements.push({ type: 'row', cells: otherRow }); + } + + if (layerElements.length > 0) { + moduleElements.push({ type: 'layer', rows: layerElements }); + } + }); + + return { + elements: moduleElements, + maxComponents: maxComponents + }; +} + +function renderToHTML(architecture) { + let html = ''; + + architecture.elements.forEach(layer => { + html += '
'; + layer.rows.forEach(row => { + html += '
'; + row.cells.forEach(cell => { + html += `
`; + }); + html += '
'; + }); + html += '
'; + }); + + return html; +} + +// Consolidated tooltip setup - works for all model visualizations +function setupTooltips(containerElement) { + const tooltip = document.getElementById('tooltip'); + if (!tooltip) return; + + const cells = containerElement.querySelectorAll('.modelview-module-cell'); + + cells.forEach(cell => { + cell.addEventListener('mouseenter', (e) => { + const module = e.target.dataset.module; + const count = e.target.dataset.count; + const components = e.target.dataset.components; + + if (module) { + tooltip.textContent = `${module}\nComponents: ${count}\nIndices: ${components || 'none'}`; + tooltip.style.display = 'block'; + tooltip.style.left = (e.pageX + 10) + 'px'; + tooltip.style.top = (e.pageY + 10) + 'px'; + } + }); + + cell.addEventListener('mouseleave', () => { + tooltip.style.display = 'none'; + }); + + cell.addEventListener('mousemove', (e) => { + tooltip.style.left = (e.pageX + 10) + 'px'; + tooltip.style.top = (e.pageY + 10) + 'px'; + }); + }); +} + +// Consolidated render function - creates model visualization in a container +function renderModelView(containerElement, clusterHash, clusterData, modelInfo, colormap = 'blues', cellSize = null) { + if (!modelInfo || !modelInfo.module_list) { + containerElement.innerHTML = 'Model info loading...'; + return; + } + + if (!clusterData || !clusterData[clusterHash]) { + containerElement.innerHTML = 'Cluster data missing'; + return; + } + + try { + const architecture = renderModelArchitecture(clusterHash, clusterData, modelInfo, colormap); + const html = renderToHTML(architecture); + containerElement.innerHTML = html; + + // Apply cell size from config if provided + if (cellSize !== null) { + containerElement.style.setProperty('--modelview-cell-size', cellSize + 'px'); + } + + // Setup tooltips after a brief delay to ensure DOM is ready + setTimeout(() => setupTooltips(containerElement), 0); + } catch (error) { + console.error('Failed to render model visualization:', error); + 
containerElement.innerHTML = 'Model visualization error'; + } +} \ No newline at end of file diff --git a/spd/clustering/ci_dt/js/pkg/jszip.js b/spd/clustering/ci_dt/js/pkg/jszip.js new file mode 100644 index 000000000..60fbb41a6 --- /dev/null +++ b/spd/clustering/ci_dt/js/pkg/jszip.js @@ -0,0 +1,11577 @@ +/*! + +JSZip v3.10.1 - A JavaScript class for generating and reading zip files + + +(c) 2009-2016 Stuart Knightley +Dual licenced under the MIT license or GPLv3. See https://raw.github.com/Stuk/jszip/main/LICENSE.markdown. + +JSZip uses the library pako released under the MIT license : +https://github.com/nodeca/pako/blob/main/LICENSE +*/ + +(function(f){if(typeof exports==="object"&&typeof module!=="undefined"){module.exports=f()}else if(typeof define==="function"&&define.amd){define([],f)}else{var g;if(typeof window!=="undefined"){g=window}else if(typeof global!=="undefined"){g=global}else if(typeof self!=="undefined"){g=self}else{g=this}g.JSZip = f()}})(function(){var define,module,exports;return (function e(t,n,r){function s(o,u){if(!n[o]){if(!t[o]){var a=typeof require=="function"&&require;if(!u&&a)return a(o,!0);if(i)return i(o,!0);var f=new Error("Cannot find module '"+o+"'");throw f.code="MODULE_NOT_FOUND",f}var l=n[o]={exports:{}};t[o][0].call(l.exports,function(e){var n=t[o][1][e];return s(n?n:e)},l,l.exports,e,t,n,r)}return n[o].exports}var i=typeof require=="function"&&require;for(var o=0;o> 2; + enc2 = ((chr1 & 3) << 4) | (chr2 >> 4); + enc3 = remainingBytes > 1 ? (((chr2 & 15) << 2) | (chr3 >> 6)) : 64; + enc4 = remainingBytes > 2 ? 
(chr3 & 63) : 64; + + output.push(_keyStr.charAt(enc1) + _keyStr.charAt(enc2) + _keyStr.charAt(enc3) + _keyStr.charAt(enc4)); + + } + + return output.join(""); +}; + +// public method for decoding +exports.decode = function(input) { + var chr1, chr2, chr3; + var enc1, enc2, enc3, enc4; + var i = 0, resultIndex = 0; + + var dataUrlPrefix = "data:"; + + if (input.substr(0, dataUrlPrefix.length) === dataUrlPrefix) { + // This is a common error: people give a data url + // (data:image/png;base64,iVBOR...) with a {base64: true} and + // wonders why things don't work. + // We can detect that the string input looks like a data url but we + // *can't* be sure it is one: removing everything up to the comma would + // be too dangerous. + throw new Error("Invalid base64 input, it looks like a data url."); + } + + input = input.replace(/[^A-Za-z0-9+/=]/g, ""); + + var totalLength = input.length * 3 / 4; + if(input.charAt(input.length - 1) === _keyStr.charAt(64)) { + totalLength--; + } + if(input.charAt(input.length - 2) === _keyStr.charAt(64)) { + totalLength--; + } + if (totalLength % 1 !== 0) { + // totalLength is not an integer, the length does not match a valid + // base64 content. 
That can happen if: + // - the input is not a base64 content + // - the input is *almost* a base64 content, with a extra chars at the + // beginning or at the end + // - the input uses a base64 variant (base64url for example) + throw new Error("Invalid base64 input, bad content length."); + } + var output; + if (support.uint8array) { + output = new Uint8Array(totalLength|0); + } else { + output = new Array(totalLength|0); + } + + while (i < input.length) { + + enc1 = _keyStr.indexOf(input.charAt(i++)); + enc2 = _keyStr.indexOf(input.charAt(i++)); + enc3 = _keyStr.indexOf(input.charAt(i++)); + enc4 = _keyStr.indexOf(input.charAt(i++)); + + chr1 = (enc1 << 2) | (enc2 >> 4); + chr2 = ((enc2 & 15) << 4) | (enc3 >> 2); + chr3 = ((enc3 & 3) << 6) | enc4; + + output[resultIndex++] = chr1; + + if (enc3 !== 64) { + output[resultIndex++] = chr2; + } + if (enc4 !== 64) { + output[resultIndex++] = chr3; + } + + } + + return output; +}; + +},{"./support":30,"./utils":32}],2:[function(require,module,exports){ +"use strict"; + +var external = require("./external"); +var DataWorker = require("./stream/DataWorker"); +var Crc32Probe = require("./stream/Crc32Probe"); +var DataLengthProbe = require("./stream/DataLengthProbe"); + +/** + * Represent a compressed object, with everything needed to decompress it. + * @constructor + * @param {number} compressedSize the size of the data compressed. + * @param {number} uncompressedSize the size of the data after decompression. + * @param {number} crc32 the crc32 of the decompressed file. + * @param {object} compression the type of compression, see lib/compressions.js. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the compressed data. 
+ */ +function CompressedObject(compressedSize, uncompressedSize, crc32, compression, data) { + this.compressedSize = compressedSize; + this.uncompressedSize = uncompressedSize; + this.crc32 = crc32; + this.compression = compression; + this.compressedContent = data; +} + +CompressedObject.prototype = { + /** + * Create a worker to get the uncompressed content. + * @return {GenericWorker} the worker. + */ + getContentWorker: function () { + var worker = new DataWorker(external.Promise.resolve(this.compressedContent)) + .pipe(this.compression.uncompressWorker()) + .pipe(new DataLengthProbe("data_length")); + + var that = this; + worker.on("end", function () { + if (this.streamInfo["data_length"] !== that.uncompressedSize) { + throw new Error("Bug : uncompressed data size mismatch"); + } + }); + return worker; + }, + /** + * Create a worker to get the compressed content. + * @return {GenericWorker} the worker. + */ + getCompressedWorker: function () { + return new DataWorker(external.Promise.resolve(this.compressedContent)) + .withStreamInfo("compressedSize", this.compressedSize) + .withStreamInfo("uncompressedSize", this.uncompressedSize) + .withStreamInfo("crc32", this.crc32) + .withStreamInfo("compression", this.compression) + ; + } +}; + +/** + * Chain the given worker with other workers to compress the content with the + * given compression. + * @param {GenericWorker} uncompressedWorker the worker to pipe. + * @param {Object} compression the compression object. + * @param {Object} compressionOptions the options to use when compressing. + * @return {GenericWorker} the new worker compressing the content. 
+ */ +CompressedObject.createWorkerFrom = function (uncompressedWorker, compression, compressionOptions) { + return uncompressedWorker + .pipe(new Crc32Probe()) + .pipe(new DataLengthProbe("uncompressedSize")) + .pipe(compression.compressWorker(compressionOptions)) + .pipe(new DataLengthProbe("compressedSize")) + .withStreamInfo("compression", compression); +}; + +module.exports = CompressedObject; + +},{"./external":6,"./stream/Crc32Probe":25,"./stream/DataLengthProbe":26,"./stream/DataWorker":27}],3:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./stream/GenericWorker"); + +exports.STORE = { + magic: "\x00\x00", + compressWorker : function () { + return new GenericWorker("STORE compression"); + }, + uncompressWorker : function () { + return new GenericWorker("STORE decompression"); + } +}; +exports.DEFLATE = require("./flate"); + +},{"./flate":7,"./stream/GenericWorker":28}],4:[function(require,module,exports){ +"use strict"; + +var utils = require("./utils"); + +/** + * The following functions come from pako, from pako/lib/zlib/crc32.js + * released under the MIT license, see pako https://github.com/nodeca/pako/ + */ + +// Use ordinary array, since untyped makes no boost here +function makeTable() { + var c, table = []; + + for(var n =0; n < 256; n++){ + c = n; + for(var k =0; k < 8; k++){ + c = ((c&1) ? (0xEDB88320 ^ (c >>> 1)) : (c >>> 1)); + } + table[n] = c; + } + + return table; +} + +// Create table on load. Just 255 signed longs. Not a problem. +var crcTable = makeTable(); + + +function crc32(crc, buf, len, pos) { + var t = crcTable, end = pos + len; + + crc = crc ^ (-1); + + for (var i = pos; i < end; i++ ) { + crc = (crc >>> 8) ^ t[(crc ^ buf[i]) & 0xFF]; + } + + return (crc ^ (-1)); // >>> 0; +} + +// That's all for the pako functions. + +/** + * Compute the crc32 of a string. + * This is almost the same as the function crc32, but for strings. 
Using the + * same function for the two use cases leads to horrible performances. + * @param {Number} crc the starting value of the crc. + * @param {String} str the string to use. + * @param {Number} len the length of the string. + * @param {Number} pos the starting position for the crc32 computation. + * @return {Number} the computed crc32. + */ +function crc32str(crc, str, len, pos) { + var t = crcTable, end = pos + len; + + crc = crc ^ (-1); + + for (var i = pos; i < end; i++ ) { + crc = (crc >>> 8) ^ t[(crc ^ str.charCodeAt(i)) & 0xFF]; + } + + return (crc ^ (-1)); // >>> 0; +} + +module.exports = function crc32wrapper(input, crc) { + if (typeof input === "undefined" || !input.length) { + return 0; + } + + var isArray = utils.getTypeOf(input) !== "string"; + + if(isArray) { + return crc32(crc|0, input, input.length, 0); + } else { + return crc32str(crc|0, input, input.length, 0); + } +}; + +},{"./utils":32}],5:[function(require,module,exports){ +"use strict"; +exports.base64 = false; +exports.binary = false; +exports.dir = false; +exports.createFolders = true; +exports.date = null; +exports.compression = null; +exports.compressionOptions = null; +exports.comment = null; +exports.unixPermissions = null; +exports.dosPermissions = null; + +},{}],6:[function(require,module,exports){ +"use strict"; + +// load the global object first: +// - it should be better integrated in the system (unhandledRejection in node) +// - the environment may have a custom Promise implementation (see zone.js) +var ES6Promise = null; +if (typeof Promise !== "undefined") { + ES6Promise = Promise; +} else { + ES6Promise = require("lie"); +} + +/** + * Let the user use/change some implementations. 
+ */ +module.exports = { + Promise: ES6Promise +}; + +},{"lie":37}],7:[function(require,module,exports){ +"use strict"; +var USE_TYPEDARRAY = (typeof Uint8Array !== "undefined") && (typeof Uint16Array !== "undefined") && (typeof Uint32Array !== "undefined"); + +var pako = require("pako"); +var utils = require("./utils"); +var GenericWorker = require("./stream/GenericWorker"); + +var ARRAY_TYPE = USE_TYPEDARRAY ? "uint8array" : "array"; + +exports.magic = "\x08\x00"; + +/** + * Create a worker that uses pako to inflate/deflate. + * @constructor + * @param {String} action the name of the pako function to call : either "Deflate" or "Inflate". + * @param {Object} options the options to use when (de)compressing. + */ +function FlateWorker(action, options) { + GenericWorker.call(this, "FlateWorker/" + action); + + this._pako = null; + this._pakoAction = action; + this._pakoOptions = options; + // the `meta` object from the last chunk received + // this allow this worker to pass around metadata + this.meta = {}; +} + +utils.inherits(FlateWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +FlateWorker.prototype.processChunk = function (chunk) { + this.meta = chunk.meta; + if (this._pako === null) { + this._createPako(); + } + this._pako.push(utils.transformTo(ARRAY_TYPE, chunk.data), false); +}; + +/** + * @see GenericWorker.flush + */ +FlateWorker.prototype.flush = function () { + GenericWorker.prototype.flush.call(this); + if (this._pako === null) { + this._createPako(); + } + this._pako.push([], true); +}; +/** + * @see GenericWorker.cleanUp + */ +FlateWorker.prototype.cleanUp = function () { + GenericWorker.prototype.cleanUp.call(this); + this._pako = null; +}; + +/** + * Create the _pako object. + * TODO: lazy-loading this object isn't the best solution but it's the + * quickest. The best solution is to lazy-load the worker list. See also the + * issue #446. 
+ */ +FlateWorker.prototype._createPako = function () { + this._pako = new pako[this._pakoAction]({ + raw: true, + level: this._pakoOptions.level || -1 // default compression + }); + var self = this; + this._pako.onData = function(data) { + self.push({ + data : data, + meta : self.meta + }); + }; +}; + +exports.compressWorker = function (compressionOptions) { + return new FlateWorker("Deflate", compressionOptions); +}; +exports.uncompressWorker = function () { + return new FlateWorker("Inflate", {}); +}; + +},{"./stream/GenericWorker":28,"./utils":32,"pako":38}],8:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("../stream/GenericWorker"); +var utf8 = require("../utf8"); +var crc32 = require("../crc32"); +var signature = require("../signature"); + +/** + * Transform an integer into a string in hexadecimal. + * @private + * @param {number} dec the number to convert. + * @param {number} bytes the number of bytes to generate. + * @returns {string} the result. + */ +var decToHex = function(dec, bytes) { + var hex = "", i; + for (i = 0; i < bytes; i++) { + hex += String.fromCharCode(dec & 0xff); + dec = dec >>> 8; + } + return hex; +}; + +/** + * Generate the UNIX part of the external file attributes. + * @param {Object} unixPermissions the unix permissions or null. + * @param {Boolean} isDir true if the entry is a directory, false otherwise. + * @return {Number} a 32 bit integer. + * + * adapted from http://unix.stackexchange.com/questions/14705/the-zip-formats-external-file-attribute : + * + * TTTTsstrwxrwxrwx0000000000ADVSHR + * ^^^^____________________________ file type, see zipinfo.c (UNX_*) + * ^^^_________________________ setuid, setgid, sticky + * ^^^^^^^^^________________ permissions + * ^^^^^^^^^^______ not used ? 
+ * ^^^^^^ DOS attribute bits : Archive, Directory, Volume label, System file, Hidden, Read only + */ +var generateUnixExternalFileAttr = function (unixPermissions, isDir) { + + var result = unixPermissions; + if (!unixPermissions) { + // I can't use octal values in strict mode, hence the hexa. + // 040775 => 0x41fd + // 0100664 => 0x81b4 + result = isDir ? 0x41fd : 0x81b4; + } + return (result & 0xFFFF) << 16; +}; + +/** + * Generate the DOS part of the external file attributes. + * @param {Object} dosPermissions the dos permissions or null. + * @param {Boolean} isDir true if the entry is a directory, false otherwise. + * @return {Number} a 32 bit integer. + * + * Bit 0 Read-Only + * Bit 1 Hidden + * Bit 2 System + * Bit 3 Volume Label + * Bit 4 Directory + * Bit 5 Archive + */ +var generateDosExternalFileAttr = function (dosPermissions) { + // the dir flag is already set for compatibility + return (dosPermissions || 0) & 0x3F; +}; + +/** + * Generate the various parts used in the construction of the final zip file. + * @param {Object} streamInfo the hash with information about the compressed file. + * @param {Boolean} streamedContent is the content streamed ? + * @param {Boolean} streamingEnded is the stream finished ? + * @param {number} offset the current offset from the start of the zip file. + * @param {String} platform let's pretend we are this platform (change platform dependents fields) + * @param {Function} encodeFileName the function to encode the file name / comment. + * @return {Object} the zip parts. 
+ */ +var generateZipParts = function(streamInfo, streamedContent, streamingEnded, offset, platform, encodeFileName) { + var file = streamInfo["file"], + compression = streamInfo["compression"], + useCustomEncoding = encodeFileName !== utf8.utf8encode, + encodedFileName = utils.transformTo("string", encodeFileName(file.name)), + utfEncodedFileName = utils.transformTo("string", utf8.utf8encode(file.name)), + comment = file.comment, + encodedComment = utils.transformTo("string", encodeFileName(comment)), + utfEncodedComment = utils.transformTo("string", utf8.utf8encode(comment)), + useUTF8ForFileName = utfEncodedFileName.length !== file.name.length, + useUTF8ForComment = utfEncodedComment.length !== comment.length, + dosTime, + dosDate, + extraFields = "", + unicodePathExtraField = "", + unicodeCommentExtraField = "", + dir = file.dir, + date = file.date; + + + var dataInfo = { + crc32 : 0, + compressedSize : 0, + uncompressedSize : 0 + }; + + // if the content is streamed, the sizes/crc32 are only available AFTER + // the end of the stream. + if (!streamedContent || streamingEnded) { + dataInfo.crc32 = streamInfo["crc32"]; + dataInfo.compressedSize = streamInfo["compressedSize"]; + dataInfo.uncompressedSize = streamInfo["uncompressedSize"]; + } + + var bitflag = 0; + if (streamedContent) { + // Bit 3: the sizes/crc32 are set to zero in the local header. + // The correct values are put in the data descriptor immediately + // following the compressed data. + bitflag |= 0x0008; + } + if (!useCustomEncoding && (useUTF8ForFileName || useUTF8ForComment)) { + // Bit 11: Language encoding flag (EFS). 
+ bitflag |= 0x0800; + } + + + var extFileAttr = 0; + var versionMadeBy = 0; + if (dir) { + // dos or unix, we set the dos dir flag + extFileAttr |= 0x00010; + } + if(platform === "UNIX") { + versionMadeBy = 0x031E; // UNIX, version 3.0 + extFileAttr |= generateUnixExternalFileAttr(file.unixPermissions, dir); + } else { // DOS or other, fallback to DOS + versionMadeBy = 0x0014; // DOS, version 2.0 + extFileAttr |= generateDosExternalFileAttr(file.dosPermissions, dir); + } + + // date + // @see http://www.delorie.com/djgpp/doc/rbinter/it/52/13.html + // @see http://www.delorie.com/djgpp/doc/rbinter/it/65/16.html + // @see http://www.delorie.com/djgpp/doc/rbinter/it/66/16.html + + dosTime = date.getUTCHours(); + dosTime = dosTime << 6; + dosTime = dosTime | date.getUTCMinutes(); + dosTime = dosTime << 5; + dosTime = dosTime | date.getUTCSeconds() / 2; + + dosDate = date.getUTCFullYear() - 1980; + dosDate = dosDate << 4; + dosDate = dosDate | (date.getUTCMonth() + 1); + dosDate = dosDate << 5; + dosDate = dosDate | date.getUTCDate(); + + if (useUTF8ForFileName) { + // set the unicode path extra field. unzip needs at least one extra + // field to correctly handle unicode path, so using the path is as good + // as any other information. This could improve the situation with + // other archive managers too. + // This field is usually used without the utf8 flag, with a non + // unicode path in the header (winrar, winzip). This helps (a bit) + // with the messy Windows' default compressed folders feature but + // breaks on p7zip which doesn't seek the unicode path extra field. + // So for now, UTF-8 everywhere ! 
+ unicodePathExtraField = + // Version + decToHex(1, 1) + + // NameCRC32 + decToHex(crc32(encodedFileName), 4) + + // UnicodeName + utfEncodedFileName; + + extraFields += + // Info-ZIP Unicode Path Extra Field + "\x75\x70" + + // size + decToHex(unicodePathExtraField.length, 2) + + // content + unicodePathExtraField; + } + + if(useUTF8ForComment) { + + unicodeCommentExtraField = + // Version + decToHex(1, 1) + + // CommentCRC32 + decToHex(crc32(encodedComment), 4) + + // UnicodeName + utfEncodedComment; + + extraFields += + // Info-ZIP Unicode Path Extra Field + "\x75\x63" + + // size + decToHex(unicodeCommentExtraField.length, 2) + + // content + unicodeCommentExtraField; + } + + var header = ""; + + // version needed to extract + header += "\x0A\x00"; + // general purpose bit flag + header += decToHex(bitflag, 2); + // compression method + header += compression.magic; + // last mod file time + header += decToHex(dosTime, 2); + // last mod file date + header += decToHex(dosDate, 2); + // crc-32 + header += decToHex(dataInfo.crc32, 4); + // compressed size + header += decToHex(dataInfo.compressedSize, 4); + // uncompressed size + header += decToHex(dataInfo.uncompressedSize, 4); + // file name length + header += decToHex(encodedFileName.length, 2); + // extra field length + header += decToHex(extraFields.length, 2); + + + var fileRecord = signature.LOCAL_FILE_HEADER + header + encodedFileName + extraFields; + + var dirRecord = signature.CENTRAL_FILE_HEADER + + // version made by (00: DOS) + decToHex(versionMadeBy, 2) + + // file header (common to file and central directory) + header + + // file comment length + decToHex(encodedComment.length, 2) + + // disk number start + "\x00\x00" + + // internal file attributes TODO + "\x00\x00" + + // external file attributes + decToHex(extFileAttr, 4) + + // relative offset of local header + decToHex(offset, 4) + + // file name + encodedFileName + + // extra field + extraFields + + // file comment + encodedComment; + + return 
{ + fileRecord: fileRecord, + dirRecord: dirRecord + }; +}; + +/** + * Generate the EOCD record. + * @param {Number} entriesCount the number of entries in the zip file. + * @param {Number} centralDirLength the length (in bytes) of the central dir. + * @param {Number} localDirLength the length (in bytes) of the local dir. + * @param {String} comment the zip file comment as a binary string. + * @param {Function} encodeFileName the function to encode the comment. + * @return {String} the EOCD record. + */ +var generateCentralDirectoryEnd = function (entriesCount, centralDirLength, localDirLength, comment, encodeFileName) { + var dirEnd = ""; + var encodedComment = utils.transformTo("string", encodeFileName(comment)); + + // end of central dir signature + dirEnd = signature.CENTRAL_DIRECTORY_END + + // number of this disk + "\x00\x00" + + // number of the disk with the start of the central directory + "\x00\x00" + + // total number of entries in the central directory on this disk + decToHex(entriesCount, 2) + + // total number of entries in the central directory + decToHex(entriesCount, 2) + + // size of the central directory 4 bytes + decToHex(centralDirLength, 4) + + // offset of start of central directory with respect to the starting disk number + decToHex(localDirLength, 4) + + // .ZIP file comment length + decToHex(encodedComment.length, 2) + + // .ZIP file comment + encodedComment; + + return dirEnd; +}; + +/** + * Generate data descriptors for a file entry. + * @param {Object} streamInfo the hash generated by a worker, containing information + * on the file entry. + * @return {String} the data descriptors. 
+ */ +var generateDataDescriptors = function (streamInfo) { + var descriptor = ""; + descriptor = signature.DATA_DESCRIPTOR + + // crc-32 4 bytes + decToHex(streamInfo["crc32"], 4) + + // compressed size 4 bytes + decToHex(streamInfo["compressedSize"], 4) + + // uncompressed size 4 bytes + decToHex(streamInfo["uncompressedSize"], 4); + + return descriptor; +}; + + +/** + * A worker to concatenate other workers to create a zip file. + * @param {Boolean} streamFiles `true` to stream the content of the files, + * `false` to accumulate it. + * @param {String} comment the comment to use. + * @param {String} platform the platform to use, "UNIX" or "DOS". + * @param {Function} encodeFileName the function to encode file names and comments. + */ +function ZipFileWorker(streamFiles, comment, platform, encodeFileName) { + GenericWorker.call(this, "ZipFileWorker"); + // The number of bytes written so far. This doesn't count accumulated chunks. + this.bytesWritten = 0; + // The comment of the zip file + this.zipComment = comment; + // The platform "generating" the zip file. + this.zipPlatform = platform; + // the function to encode file names and comments. + this.encodeFileName = encodeFileName; + // Should we stream the content of the files ? + this.streamFiles = streamFiles; + // If `streamFiles` is false, we will need to accumulate the content of the + // files to calculate sizes / crc32 (and write them *before* the content). + // This boolean indicates if we are accumulating chunks (it will change a lot + // during the lifetime of this worker). + this.accumulate = false; + // The buffer receiving chunks when accumulating content. + this.contentBuffer = []; + // The list of generated directory records. + this.dirRecords = []; + // The offset (in bytes) from the beginning of the zip file for the current source. + this.currentSourceOffset = 0; + // The total number of entries in this zip file. 
+ this.entriesCount = 0; + // the name of the file currently being added, null when handling the end of the zip file. + // Used for the emitted metadata. + this.currentFile = null; + + + + this._sources = []; +} +utils.inherits(ZipFileWorker, GenericWorker); + +/** + * @see GenericWorker.push + */ +ZipFileWorker.prototype.push = function (chunk) { + + var currentFilePercent = chunk.meta.percent || 0; + var entriesCount = this.entriesCount; + var remainingFiles = this._sources.length; + + if(this.accumulate) { + this.contentBuffer.push(chunk); + } else { + this.bytesWritten += chunk.data.length; + + GenericWorker.prototype.push.call(this, { + data : chunk.data, + meta : { + currentFile : this.currentFile, + percent : entriesCount ? (currentFilePercent + 100 * (entriesCount - remainingFiles - 1)) / entriesCount : 100 + } + }); + } +}; + +/** + * The worker started a new source (an other worker). + * @param {Object} streamInfo the streamInfo object from the new source. + */ +ZipFileWorker.prototype.openedSource = function (streamInfo) { + this.currentSourceOffset = this.bytesWritten; + this.currentFile = streamInfo["file"].name; + + var streamedContent = this.streamFiles && !streamInfo["file"].dir; + + // don't stream folders (because they don't have any content) + if(streamedContent) { + var record = generateZipParts(streamInfo, streamedContent, false, this.currentSourceOffset, this.zipPlatform, this.encodeFileName); + this.push({ + data : record.fileRecord, + meta : {percent:0} + }); + } else { + // we need to wait for the whole file before pushing anything + this.accumulate = true; + } +}; + +/** + * The worker finished a source (an other worker). + * @param {Object} streamInfo the streamInfo object from the finished source. 
+ */ +ZipFileWorker.prototype.closedSource = function (streamInfo) { + this.accumulate = false; + var streamedContent = this.streamFiles && !streamInfo["file"].dir; + var record = generateZipParts(streamInfo, streamedContent, true, this.currentSourceOffset, this.zipPlatform, this.encodeFileName); + + this.dirRecords.push(record.dirRecord); + if(streamedContent) { + // after the streamed file, we put data descriptors + this.push({ + data : generateDataDescriptors(streamInfo), + meta : {percent:100} + }); + } else { + // the content wasn't streamed, we need to push everything now + // first the file record, then the content + this.push({ + data : record.fileRecord, + meta : {percent:0} + }); + while(this.contentBuffer.length) { + this.push(this.contentBuffer.shift()); + } + } + this.currentFile = null; +}; + +/** + * @see GenericWorker.flush + */ +ZipFileWorker.prototype.flush = function () { + + var localDirLength = this.bytesWritten; + for(var i = 0; i < this.dirRecords.length; i++) { + this.push({ + data : this.dirRecords[i], + meta : {percent:100} + }); + } + var centralDirLength = this.bytesWritten - localDirLength; + + var dirEnd = generateCentralDirectoryEnd(this.dirRecords.length, centralDirLength, localDirLength, this.zipComment, this.encodeFileName); + + this.push({ + data : dirEnd, + meta : {percent:100} + }); +}; + +/** + * Prepare the next source to be read. 
+ */ +ZipFileWorker.prototype.prepareNextSource = function () { + this.previous = this._sources.shift(); + this.openedSource(this.previous.streamInfo); + if (this.isPaused) { + this.previous.pause(); + } else { + this.previous.resume(); + } +}; + +/** + * @see GenericWorker.registerPrevious + */ +ZipFileWorker.prototype.registerPrevious = function (previous) { + this._sources.push(previous); + var self = this; + + previous.on("data", function (chunk) { + self.processChunk(chunk); + }); + previous.on("end", function () { + self.closedSource(self.previous.streamInfo); + if(self._sources.length) { + self.prepareNextSource(); + } else { + self.end(); + } + }); + previous.on("error", function (e) { + self.error(e); + }); + return this; +}; + +/** + * @see GenericWorker.resume + */ +ZipFileWorker.prototype.resume = function () { + if(!GenericWorker.prototype.resume.call(this)) { + return false; + } + + if (!this.previous && this._sources.length) { + this.prepareNextSource(); + return true; + } + if (!this.previous && !this._sources.length && !this.generatedError) { + this.end(); + return true; + } +}; + +/** + * @see GenericWorker.error + */ +ZipFileWorker.prototype.error = function (e) { + var sources = this._sources; + if(!GenericWorker.prototype.error.call(this, e)) { + return false; + } + for(var i = 0; i < sources.length; i++) { + try { + sources[i].error(e); + } catch(e) { + // the `error` exploded, nothing to do + } + } + return true; +}; + +/** + * @see GenericWorker.lock + */ +ZipFileWorker.prototype.lock = function () { + GenericWorker.prototype.lock.call(this); + var sources = this._sources; + for(var i = 0; i < sources.length; i++) { + sources[i].lock(); + } +}; + +module.exports = ZipFileWorker; + +},{"../crc32":4,"../signature":23,"../stream/GenericWorker":28,"../utf8":31,"../utils":32}],9:[function(require,module,exports){ +"use strict"; + +var compressions = require("../compressions"); +var ZipFileWorker = require("./ZipFileWorker"); + +/** + * Find the 
compression to use. + * @param {String} fileCompression the compression defined at the file level, if any. + * @param {String} zipCompression the compression defined at the load() level. + * @return {Object} the compression object to use. + */ +var getCompression = function (fileCompression, zipCompression) { + + var compressionName = fileCompression || zipCompression; + var compression = compressions[compressionName]; + if (!compression) { + throw new Error(compressionName + " is not a valid compression method !"); + } + return compression; +}; + +/** + * Create a worker to generate a zip file. + * @param {JSZip} zip the JSZip instance at the right root level. + * @param {Object} options to generate the zip file. + * @param {String} comment the comment to use. + */ +exports.generateWorker = function (zip, options, comment) { + + var zipFileWorker = new ZipFileWorker(options.streamFiles, comment, options.platform, options.encodeFileName); + var entriesCount = 0; + try { + + zip.forEach(function (relativePath, file) { + entriesCount++; + var compression = getCompression(file.options.compression, options.compression); + var compressionOptions = file.options.compressionOptions || options.compressionOptions || {}; + var dir = file.dir, date = file.date; + + file._compressWorker(compression, compressionOptions) + .withStreamInfo("file", { + name : relativePath, + dir : dir, + date : date, + comment : file.comment || "", + unixPermissions : file.unixPermissions, + dosPermissions : file.dosPermissions + }) + .pipe(zipFileWorker); + }); + zipFileWorker.entriesCount = entriesCount; + } catch (e) { + zipFileWorker.error(e); + } + + return zipFileWorker; +}; + +},{"../compressions":3,"./ZipFileWorker":8}],10:[function(require,module,exports){ +"use strict"; + +/** + * Representation a of zip file in js + * @constructor + */ +function JSZip() { + // if this constructor is used without `new`, it adds `new` before itself: + if(!(this instanceof JSZip)) { + return new JSZip(); + 
} + + if(arguments.length) { + throw new Error("The constructor with parameters has been removed in JSZip 3.0, please check the upgrade guide."); + } + + // object containing the files : + // { + // "folder/" : {...}, + // "folder/data.txt" : {...} + // } + // NOTE: we use a null prototype because we do not + // want filenames like "toString" coming from a zip file + // to overwrite methods and attributes in a normal Object. + this.files = Object.create(null); + + this.comment = null; + + // Where we are in the hierarchy + this.root = ""; + this.clone = function() { + var newObj = new JSZip(); + for (var i in this) { + if (typeof this[i] !== "function") { + newObj[i] = this[i]; + } + } + return newObj; + }; +} +JSZip.prototype = require("./object"); +JSZip.prototype.loadAsync = require("./load"); +JSZip.support = require("./support"); +JSZip.defaults = require("./defaults"); + +// TODO find a better way to handle this version, +// a require('package.json').version doesn't work with webpack, see #327 +JSZip.version = "3.10.1"; + +JSZip.loadAsync = function (content, options) { + return new JSZip().loadAsync(content, options); +}; + +JSZip.external = require("./external"); +module.exports = JSZip; + +},{"./defaults":5,"./external":6,"./load":11,"./object":15,"./support":30}],11:[function(require,module,exports){ +"use strict"; +var utils = require("./utils"); +var external = require("./external"); +var utf8 = require("./utf8"); +var ZipEntries = require("./zipEntries"); +var Crc32Probe = require("./stream/Crc32Probe"); +var nodejsUtils = require("./nodejsUtils"); + +/** + * Check the CRC32 of an entry. + * @param {ZipEntry} zipEntry the zip entry to check. + * @return {Promise} the result. 
+ */ +function checkEntryCRC32(zipEntry) { + return new external.Promise(function (resolve, reject) { + var worker = zipEntry.decompressed.getContentWorker().pipe(new Crc32Probe()); + worker.on("error", function (e) { + reject(e); + }) + .on("end", function () { + if (worker.streamInfo.crc32 !== zipEntry.decompressed.crc32) { + reject(new Error("Corrupted zip : CRC32 mismatch")); + } else { + resolve(); + } + }) + .resume(); + }); +} + +module.exports = function (data, options) { + var zip = this; + options = utils.extend(options || {}, { + base64: false, + checkCRC32: false, + optimizedBinaryString: false, + createFolders: false, + decodeFileName: utf8.utf8decode + }); + + if (nodejsUtils.isNode && nodejsUtils.isStream(data)) { + return external.Promise.reject(new Error("JSZip can't accept a stream when loading a zip file.")); + } + + return utils.prepareContent("the loaded zip file", data, true, options.optimizedBinaryString, options.base64) + .then(function (data) { + var zipEntries = new ZipEntries(options); + zipEntries.load(data); + return zipEntries; + }).then(function checkCRC32(zipEntries) { + var promises = [external.Promise.resolve(zipEntries)]; + var files = zipEntries.files; + if (options.checkCRC32) { + for (var i = 0; i < files.length; i++) { + promises.push(checkEntryCRC32(files[i])); + } + } + return external.Promise.all(promises); + }).then(function addFiles(results) { + var zipEntries = results.shift(); + var files = zipEntries.files; + for (var i = 0; i < files.length; i++) { + var input = files[i]; + + var unsafeName = input.fileNameStr; + var safeName = utils.resolve(input.fileNameStr); + + zip.file(safeName, input.decompressed, { + binary: true, + optimizedBinaryString: true, + date: input.date, + dir: input.dir, + comment: input.fileCommentStr.length ? 
input.fileCommentStr : null, + unixPermissions: input.unixPermissions, + dosPermissions: input.dosPermissions, + createFolders: options.createFolders + }); + if (!input.dir) { + zip.file(safeName).unsafeOriginalName = unsafeName; + } + } + if (zipEntries.zipComment.length) { + zip.comment = zipEntries.zipComment; + } + + return zip; + }); +}; + +},{"./external":6,"./nodejsUtils":14,"./stream/Crc32Probe":25,"./utf8":31,"./utils":32,"./zipEntries":33}],12:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("../stream/GenericWorker"); + +/** + * A worker that use a nodejs stream as source. + * @constructor + * @param {String} filename the name of the file entry for this stream. + * @param {Readable} stream the nodejs stream. + */ +function NodejsStreamInputAdapter(filename, stream) { + GenericWorker.call(this, "Nodejs stream input adapter for " + filename); + this._upstreamEnded = false; + this._bindStream(stream); +} + +utils.inherits(NodejsStreamInputAdapter, GenericWorker); + +/** + * Prepare the stream and bind the callbacks on it. + * Do this ASAP on node 0.10 ! A lazy binding doesn't always work. + * @param {Stream} stream the nodejs stream to use. 
+ */ +NodejsStreamInputAdapter.prototype._bindStream = function (stream) { + var self = this; + this._stream = stream; + stream.pause(); + stream + .on("data", function (chunk) { + self.push({ + data: chunk, + meta : { + percent : 0 + } + }); + }) + .on("error", function (e) { + if(self.isPaused) { + this.generatedError = e; + } else { + self.error(e); + } + }) + .on("end", function () { + if(self.isPaused) { + self._upstreamEnded = true; + } else { + self.end(); + } + }); +}; +NodejsStreamInputAdapter.prototype.pause = function () { + if(!GenericWorker.prototype.pause.call(this)) { + return false; + } + this._stream.pause(); + return true; +}; +NodejsStreamInputAdapter.prototype.resume = function () { + if(!GenericWorker.prototype.resume.call(this)) { + return false; + } + + if(this._upstreamEnded) { + this.end(); + } else { + this._stream.resume(); + } + + return true; +}; + +module.exports = NodejsStreamInputAdapter; + +},{"../stream/GenericWorker":28,"../utils":32}],13:[function(require,module,exports){ +"use strict"; + +var Readable = require("readable-stream").Readable; + +var utils = require("../utils"); +utils.inherits(NodejsStreamOutputAdapter, Readable); + +/** +* A nodejs stream using a worker as source. +* @see the SourceWrapper in http://nodejs.org/api/stream.html +* @constructor +* @param {StreamHelper} helper the helper wrapping the worker +* @param {Object} options the nodejs stream options +* @param {Function} updateCb the update callback. 
+*/ +function NodejsStreamOutputAdapter(helper, options, updateCb) { + Readable.call(this, options); + this._helper = helper; + + var self = this; + helper.on("data", function (data, meta) { + if (!self.push(data)) { + self._helper.pause(); + } + if(updateCb) { + updateCb(meta); + } + }) + .on("error", function(e) { + self.emit("error", e); + }) + .on("end", function () { + self.push(null); + }); +} + + +NodejsStreamOutputAdapter.prototype._read = function() { + this._helper.resume(); +}; + +module.exports = NodejsStreamOutputAdapter; + +},{"../utils":32,"readable-stream":16}],14:[function(require,module,exports){ +"use strict"; + +module.exports = { + /** + * True if this is running in Nodejs, will be undefined in a browser. + * In a browser, browserify won't include this file and the whole module + * will be resolved an empty object. + */ + isNode : typeof Buffer !== "undefined", + /** + * Create a new nodejs Buffer from an existing content. + * @param {Object} data the data to pass to the constructor. + * @param {String} encoding the encoding to use. + * @return {Buffer} a new Buffer. + */ + newBufferFrom: function(data, encoding) { + if (Buffer.from && Buffer.from !== Uint8Array.from) { + return Buffer.from(data, encoding); + } else { + if (typeof data === "number") { + // Safeguard for old Node.js versions. On newer versions, + // Buffer.from(number) / Buffer(number, encoding) already throw. + throw new Error("The \"data\" argument must not be a number"); + } + return new Buffer(data, encoding); + } + }, + /** + * Create a new nodejs Buffer with the specified size. + * @param {Integer} size the size of the buffer. + * @return {Buffer} a new Buffer. + */ + allocBuffer: function (size) { + if (Buffer.alloc) { + return Buffer.alloc(size); + } else { + var buf = new Buffer(size); + buf.fill(0); + return buf; + } + }, + /** + * Find out if an object is a Buffer. + * @param {Object} b the object to test. 
+ * @return {Boolean} true if the object is a Buffer, false otherwise. + */ + isBuffer : function(b){ + return Buffer.isBuffer(b); + }, + + isStream : function (obj) { + return obj && + typeof obj.on === "function" && + typeof obj.pause === "function" && + typeof obj.resume === "function"; + } +}; + +},{}],15:[function(require,module,exports){ +"use strict"; +var utf8 = require("./utf8"); +var utils = require("./utils"); +var GenericWorker = require("./stream/GenericWorker"); +var StreamHelper = require("./stream/StreamHelper"); +var defaults = require("./defaults"); +var CompressedObject = require("./compressedObject"); +var ZipObject = require("./zipObject"); +var generate = require("./generate"); +var nodejsUtils = require("./nodejsUtils"); +var NodejsStreamInputAdapter = require("./nodejs/NodejsStreamInputAdapter"); + + +/** + * Add a file in the current folder. + * @private + * @param {string} name the name of the file + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the data of the file + * @param {Object} originalOptions the options of the file + * @return {Object} the new file. + */ +var fileAdd = function(name, data, originalOptions) { + // be sure sub folders exist + var dataType = utils.getTypeOf(data), + parent; + + + /* + * Correct options. 
+ */ + + var o = utils.extend(originalOptions || {}, defaults); + o.date = o.date || new Date(); + if (o.compression !== null) { + o.compression = o.compression.toUpperCase(); + } + + if (typeof o.unixPermissions === "string") { + o.unixPermissions = parseInt(o.unixPermissions, 8); + } + + // UNX_IFDIR 0040000 see zipinfo.c + if (o.unixPermissions && (o.unixPermissions & 0x4000)) { + o.dir = true; + } + // Bit 4 Directory + if (o.dosPermissions && (o.dosPermissions & 0x0010)) { + o.dir = true; + } + + if (o.dir) { + name = forceTrailingSlash(name); + } + if (o.createFolders && (parent = parentFolder(name))) { + folderAdd.call(this, parent, true); + } + + var isUnicodeString = dataType === "string" && o.binary === false && o.base64 === false; + if (!originalOptions || typeof originalOptions.binary === "undefined") { + o.binary = !isUnicodeString; + } + + + var isCompressedEmpty = (data instanceof CompressedObject) && data.uncompressedSize === 0; + + if (isCompressedEmpty || o.dir || !data || data.length === 0) { + o.base64 = false; + o.binary = true; + data = ""; + o.compression = "STORE"; + dataType = "string"; + } + + /* + * Convert content to fit. + */ + + var zipObjectContent = null; + if (data instanceof CompressedObject || data instanceof GenericWorker) { + zipObjectContent = data; + } else if (nodejsUtils.isNode && nodejsUtils.isStream(data)) { + zipObjectContent = new NodejsStreamInputAdapter(name, data); + } else { + zipObjectContent = utils.prepareContent(name, data, o.binary, o.optimizedBinaryString, o.base64); + } + + var object = new ZipObject(name, zipObjectContent, o); + this.files[name] = object; + /* + TODO: we can't throw an exception because we have async promises + (we can have a promise of a Date() for example) but returning a + promise is useless because file(name, data) returns the JSZip + object for chaining. Should we break that to allow the user + to catch the error ? 
+ + return external.Promise.resolve(zipObjectContent) + .then(function () { + return object; + }); + */ +}; + +/** + * Find the parent folder of the path. + * @private + * @param {string} path the path to use + * @return {string} the parent folder, or "" + */ +var parentFolder = function (path) { + if (path.slice(-1) === "/") { + path = path.substring(0, path.length - 1); + } + var lastSlash = path.lastIndexOf("/"); + return (lastSlash > 0) ? path.substring(0, lastSlash) : ""; +}; + +/** + * Returns the path with a slash at the end. + * @private + * @param {String} path the path to check. + * @return {String} the path with a trailing slash. + */ +var forceTrailingSlash = function(path) { + // Check the name ends with a / + if (path.slice(-1) !== "/") { + path += "/"; // IE doesn't like substr(-1) + } + return path; +}; + +/** + * Add a (sub) folder in the current folder. + * @private + * @param {string} name the folder's name + * @param {boolean=} [createFolders] If true, automatically create sub + * folders. Defaults to false. + * @return {Object} the new folder. + */ +var folderAdd = function(name, createFolders) { + createFolders = (typeof createFolders !== "undefined") ? createFolders : defaults.createFolders; + + name = forceTrailingSlash(name); + + // Does this folder already exist? 
+ if (!this.files[name]) { + fileAdd.call(this, name, null, { + dir: true, + createFolders: createFolders + }); + } + return this.files[name]; +}; + +/** +* Cross-window, cross-Node-context regular expression detection +* @param {Object} object Anything +* @return {Boolean} true if the object is a regular expression, +* false otherwise +*/ +function isRegExp(object) { + return Object.prototype.toString.call(object) === "[object RegExp]"; +} + +// return the actual prototype of JSZip +var out = { + /** + * @see loadAsync + */ + load: function() { + throw new Error("This method has been removed in JSZip 3.0, please check the upgrade guide."); + }, + + + /** + * Call a callback function for each entry at this folder level. + * @param {Function} cb the callback function: + * function (relativePath, file) {...} + * It takes 2 arguments : the relative path and the file. + */ + forEach: function(cb) { + var filename, relativePath, file; + // ignore warning about unwanted properties because this.files is a null prototype object + /* eslint-disable-next-line guard-for-in */ + for (filename in this.files) { + file = this.files[filename]; + relativePath = filename.slice(this.root.length, filename.length); + if (relativePath && filename.slice(0, this.root.length) === this.root) { // the file is in the current root + cb(relativePath, file); // TODO reverse the parameters ? need to be clean AND consistent with the filter search fn... + } + } + }, + + /** + * Filter nested files/folders with the specified function. + * @param {Function} search the predicate to use : + * function (relativePath, file) {...} + * It takes 2 arguments : the relative path and the file. + * @return {Array} An array of matching elements. 
+ */ + filter: function(search) { + var result = []; + this.forEach(function (relativePath, entry) { + if (search(relativePath, entry)) { // the file matches the function + result.push(entry); + } + + }); + return result; + }, + + /** + * Add a file to the zip file, or search a file. + * @param {string|RegExp} name The name of the file to add (if data is defined), + * the name of the file to find (if no data) or a regex to match files. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data The file data, either raw or base64 encoded + * @param {Object} o File options + * @return {JSZip|Object|Array} this JSZip object (when adding a file), + * a file (when searching by string) or an array of files (when searching by regex). + */ + file: function(name, data, o) { + if (arguments.length === 1) { + if (isRegExp(name)) { + var regexp = name; + return this.filter(function(relativePath, file) { + return !file.dir && regexp.test(relativePath); + }); + } + else { // text + var obj = this.files[this.root + name]; + if (obj && !obj.dir) { + return obj; + } else { + return null; + } + } + } + else { // more than one argument : we have data ! + name = this.root + name; + fileAdd.call(this, name, data, o); + } + return this; + }, + + /** + * Add a directory to the zip file, or search. + * @param {String|RegExp} arg The name of the directory to add, or a regex to search folders. + * @return {JSZip} an object with the new directory as the root, or an array containing matching folders. 
+ */ + folder: function(arg) { + if (!arg) { + return this; + } + + if (isRegExp(arg)) { + return this.filter(function(relativePath, file) { + return file.dir && arg.test(relativePath); + }); + } + + // else, name is a new folder + var name = this.root + arg; + var newFolder = folderAdd.call(this, name); + + // Allow chaining by returning a new object with this folder as the root + var ret = this.clone(); + ret.root = newFolder.name; + return ret; + }, + + /** + * Delete a file, or a directory and all sub-files, from the zip + * @param {string} name the name of the file to delete + * @return {JSZip} this JSZip object + */ + remove: function(name) { + name = this.root + name; + var file = this.files[name]; + if (!file) { + // Look for any folders + if (name.slice(-1) !== "/") { + name += "/"; + } + file = this.files[name]; + } + + if (file && !file.dir) { + // file + delete this.files[name]; + } else { + // maybe a folder, delete recursively + var kids = this.filter(function(relativePath, file) { + return file.name.slice(0, name.length) === name; + }); + for (var i = 0; i < kids.length; i++) { + delete this.files[kids[i].name]; + } + } + + return this; + }, + + /** + * @deprecated This method has been removed in JSZip 3.0, please check the upgrade guide. + */ + generate: function() { + throw new Error("This method has been removed in JSZip 3.0, please check the upgrade guide."); + }, + + /** + * Generate the complete zip file as an internal stream. + * @param {Object} options the options to generate the zip file : + * - compression, "STORE" by default. + * - type, "base64" by default. Values are : string, base64, uint8array, arraybuffer, blob. + * @return {StreamHelper} the streamed zip file. 
+ */ + generateInternalStream: function(options) { + var worker, opts = {}; + try { + opts = utils.extend(options || {}, { + streamFiles: false, + compression: "STORE", + compressionOptions : null, + type: "", + platform: "DOS", + comment: null, + mimeType: "application/zip", + encodeFileName: utf8.utf8encode + }); + + opts.type = opts.type.toLowerCase(); + opts.compression = opts.compression.toUpperCase(); + + // "binarystring" is preferred but the internals use "string". + if(opts.type === "binarystring") { + opts.type = "string"; + } + + if (!opts.type) { + throw new Error("No output type specified."); + } + + utils.checkSupport(opts.type); + + // accept nodejs `process.platform` + if( + opts.platform === "darwin" || + opts.platform === "freebsd" || + opts.platform === "linux" || + opts.platform === "sunos" + ) { + opts.platform = "UNIX"; + } + if (opts.platform === "win32") { + opts.platform = "DOS"; + } + + var comment = opts.comment || this.comment || ""; + worker = generate.generateWorker(this, opts, comment); + } catch (e) { + worker = new GenericWorker("error"); + worker.error(e); + } + return new StreamHelper(worker, opts.type || "string", opts.mimeType); + }, + /** + * Generate the complete zip file asynchronously. + * @see generateInternalStream + */ + generateAsync: function(options, onUpdate) { + return this.generateInternalStream(options).accumulate(onUpdate); + }, + /** + * Generate the complete zip file asynchronously. 
+ * @see generateInternalStream + */ + generateNodeStream: function(options, onUpdate) { + options = options || {}; + if (!options.type) { + options.type = "nodebuffer"; + } + return this.generateInternalStream(options).toNodejsStream(onUpdate); + } +}; +module.exports = out; + +},{"./compressedObject":2,"./defaults":5,"./generate":9,"./nodejs/NodejsStreamInputAdapter":12,"./nodejsUtils":14,"./stream/GenericWorker":28,"./stream/StreamHelper":29,"./utf8":31,"./utils":32,"./zipObject":35}],16:[function(require,module,exports){ +"use strict"; +/* + * This file is used by module bundlers (browserify/webpack/etc) when + * including a stream implementation. We use "readable-stream" to get a + * consistent behavior between nodejs versions but bundlers often have a shim + * for "stream". Using this shim greatly improve the compatibility and greatly + * reduce the final size of the bundle (only one stream implementation, not + * two). + */ +module.exports = require("stream"); + +},{"stream":undefined}],17:[function(require,module,exports){ +"use strict"; +var DataReader = require("./DataReader"); +var utils = require("../utils"); + +function ArrayReader(data) { + DataReader.call(this, data); + for(var i = 0; i < this.data.length; i++) { + data[i] = data[i] & 0xFF; + } +} +utils.inherits(ArrayReader, DataReader); +/** + * @see DataReader.byteAt + */ +ArrayReader.prototype.byteAt = function(i) { + return this.data[this.zero + i]; +}; +/** + * @see DataReader.lastIndexOfSignature + */ +ArrayReader.prototype.lastIndexOfSignature = function(sig) { + var sig0 = sig.charCodeAt(0), + sig1 = sig.charCodeAt(1), + sig2 = sig.charCodeAt(2), + sig3 = sig.charCodeAt(3); + for (var i = this.length - 4; i >= 0; --i) { + if (this.data[i] === sig0 && this.data[i + 1] === sig1 && this.data[i + 2] === sig2 && this.data[i + 3] === sig3) { + return i - this.zero; + } + } + + return -1; +}; +/** + * @see DataReader.readAndCheckSignature + */ +ArrayReader.prototype.readAndCheckSignature = function 
(sig) { + var sig0 = sig.charCodeAt(0), + sig1 = sig.charCodeAt(1), + sig2 = sig.charCodeAt(2), + sig3 = sig.charCodeAt(3), + data = this.readData(4); + return sig0 === data[0] && sig1 === data[1] && sig2 === data[2] && sig3 === data[3]; +}; +/** + * @see DataReader.readData + */ +ArrayReader.prototype.readData = function(size) { + this.checkOffset(size); + if(size === 0) { + return []; + } + var result = this.data.slice(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = ArrayReader; + +},{"../utils":32,"./DataReader":18}],18:[function(require,module,exports){ +"use strict"; +var utils = require("../utils"); + +function DataReader(data) { + this.data = data; // type : see implementation + this.length = data.length; + this.index = 0; + this.zero = 0; +} +DataReader.prototype = { + /** + * Check that the offset will not go too far. + * @param {string} offset the additional offset to check. + * @throws {Error} an Error if the offset is out of bounds. + */ + checkOffset: function(offset) { + this.checkIndex(this.index + offset); + }, + /** + * Check that the specified index will not be too far. + * @param {string} newIndex the index to check. + * @throws {Error} an Error if the index is out of bounds. + */ + checkIndex: function(newIndex) { + if (this.length < this.zero + newIndex || newIndex < 0) { + throw new Error("End of data reached (data length = " + this.length + ", asked index = " + (newIndex) + "). Corrupted zip ?"); + } + }, + /** + * Change the index. + * @param {number} newIndex The new index. + * @throws {Error} if the new index is out of the data. + */ + setIndex: function(newIndex) { + this.checkIndex(newIndex); + this.index = newIndex; + }, + /** + * Skip the next n bytes. + * @param {number} n the number of bytes to skip. + * @throws {Error} if the new index is out of the data. + */ + skip: function(n) { + this.setIndex(this.index + n); + }, + /** + * Get the byte at the specified index. 
+ * @param {number} i the index to use. + * @return {number} a byte. + */ + byteAt: function() { + // see implementations + }, + /** + * Get the next number with a given byte size. + * @param {number} size the number of bytes to read. + * @return {number} the corresponding number. + */ + readInt: function(size) { + var result = 0, + i; + this.checkOffset(size); + for (i = this.index + size - 1; i >= this.index; i--) { + result = (result << 8) + this.byteAt(i); + } + this.index += size; + return result; + }, + /** + * Get the next string with a given byte size. + * @param {number} size the number of bytes to read. + * @return {string} the corresponding string. + */ + readString: function(size) { + return utils.transformTo("string", this.readData(size)); + }, + /** + * Get raw data without conversion, bytes. + * @param {number} size the number of bytes to read. + * @return {Object} the raw data, implementation specific. + */ + readData: function() { + // see implementations + }, + /** + * Find the last occurrence of a zip signature (4 bytes). + * @param {string} sig the signature to find. + * @return {number} the index of the last occurrence, -1 if not found. + */ + lastIndexOfSignature: function() { + // see implementations + }, + /** + * Read the signature (4 bytes) at the current position and compare it with sig. + * @param {string} sig the expected signature + * @return {boolean} true if the signature matches, false otherwise. + */ + readAndCheckSignature: function() { + // see implementations + }, + /** + * Get the next date. + * @return {Date} the date. 
+ */ + readDate: function() { + var dostime = this.readInt(4); + return new Date(Date.UTC( + ((dostime >> 25) & 0x7f) + 1980, // year + ((dostime >> 21) & 0x0f) - 1, // month + (dostime >> 16) & 0x1f, // day + (dostime >> 11) & 0x1f, // hour + (dostime >> 5) & 0x3f, // minute + (dostime & 0x1f) << 1)); // second + } +}; +module.exports = DataReader; + +},{"../utils":32}],19:[function(require,module,exports){ +"use strict"; +var Uint8ArrayReader = require("./Uint8ArrayReader"); +var utils = require("../utils"); + +function NodeBufferReader(data) { + Uint8ArrayReader.call(this, data); +} +utils.inherits(NodeBufferReader, Uint8ArrayReader); + +/** + * @see DataReader.readData + */ +NodeBufferReader.prototype.readData = function(size) { + this.checkOffset(size); + var result = this.data.slice(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = NodeBufferReader; + +},{"../utils":32,"./Uint8ArrayReader":21}],20:[function(require,module,exports){ +"use strict"; +var DataReader = require("./DataReader"); +var utils = require("../utils"); + +function StringReader(data) { + DataReader.call(this, data); +} +utils.inherits(StringReader, DataReader); +/** + * @see DataReader.byteAt + */ +StringReader.prototype.byteAt = function(i) { + return this.data.charCodeAt(this.zero + i); +}; +/** + * @see DataReader.lastIndexOfSignature + */ +StringReader.prototype.lastIndexOfSignature = function(sig) { + return this.data.lastIndexOf(sig) - this.zero; +}; +/** + * @see DataReader.readAndCheckSignature + */ +StringReader.prototype.readAndCheckSignature = function (sig) { + var data = this.readData(4); + return sig === data; +}; +/** + * @see DataReader.readData + */ +StringReader.prototype.readData = function(size) { + this.checkOffset(size); + // this will work because the constructor applied the "& 0xff" mask. 
+ var result = this.data.slice(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = StringReader; + +},{"../utils":32,"./DataReader":18}],21:[function(require,module,exports){ +"use strict"; +var ArrayReader = require("./ArrayReader"); +var utils = require("../utils"); + +function Uint8ArrayReader(data) { + ArrayReader.call(this, data); +} +utils.inherits(Uint8ArrayReader, ArrayReader); +/** + * @see DataReader.readData + */ +Uint8ArrayReader.prototype.readData = function(size) { + this.checkOffset(size); + if(size === 0) { + // in IE10, when using subarray(idx, idx), we get the array [0x00] instead of []. + return new Uint8Array(0); + } + var result = this.data.subarray(this.zero + this.index, this.zero + this.index + size); + this.index += size; + return result; +}; +module.exports = Uint8ArrayReader; + +},{"../utils":32,"./ArrayReader":17}],22:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var support = require("../support"); +var ArrayReader = require("./ArrayReader"); +var StringReader = require("./StringReader"); +var NodeBufferReader = require("./NodeBufferReader"); +var Uint8ArrayReader = require("./Uint8ArrayReader"); + +/** + * Create a reader adapted to the data. + * @param {String|ArrayBuffer|Uint8Array|Buffer} data the data to read. + * @return {DataReader} the data reader. 
+ */ +module.exports = function (data) { + var type = utils.getTypeOf(data); + utils.checkSupport(type); + if (type === "string" && !support.uint8array) { + return new StringReader(data); + } + if (type === "nodebuffer") { + return new NodeBufferReader(data); + } + if (support.uint8array) { + return new Uint8ArrayReader(utils.transformTo("uint8array", data)); + } + return new ArrayReader(utils.transformTo("array", data)); +}; + +},{"../support":30,"../utils":32,"./ArrayReader":17,"./NodeBufferReader":19,"./StringReader":20,"./Uint8ArrayReader":21}],23:[function(require,module,exports){ +"use strict"; +exports.LOCAL_FILE_HEADER = "PK\x03\x04"; +exports.CENTRAL_FILE_HEADER = "PK\x01\x02"; +exports.CENTRAL_DIRECTORY_END = "PK\x05\x06"; +exports.ZIP64_CENTRAL_DIRECTORY_LOCATOR = "PK\x06\x07"; +exports.ZIP64_CENTRAL_DIRECTORY_END = "PK\x06\x06"; +exports.DATA_DESCRIPTOR = "PK\x07\x08"; + +},{}],24:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./GenericWorker"); +var utils = require("../utils"); + +/** + * A worker which convert chunks to a specified type. + * @constructor + * @param {String} destType the destination type. + */ +function ConvertWorker(destType) { + GenericWorker.call(this, "ConvertWorker to " + destType); + this.destType = destType; +} +utils.inherits(ConvertWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +ConvertWorker.prototype.processChunk = function (chunk) { + this.push({ + data : utils.transformTo(this.destType, chunk.data), + meta : chunk.meta + }); +}; +module.exports = ConvertWorker; + +},{"../utils":32,"./GenericWorker":28}],25:[function(require,module,exports){ +"use strict"; + +var GenericWorker = require("./GenericWorker"); +var crc32 = require("../crc32"); +var utils = require("../utils"); + +/** + * A worker which calculate the crc32 of the data flowing through. 
+ * @constructor + */ +function Crc32Probe() { + GenericWorker.call(this, "Crc32Probe"); + this.withStreamInfo("crc32", 0); +} +utils.inherits(Crc32Probe, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +Crc32Probe.prototype.processChunk = function (chunk) { + this.streamInfo.crc32 = crc32(chunk.data, this.streamInfo.crc32 || 0); + this.push(chunk); +}; +module.exports = Crc32Probe; + +},{"../crc32":4,"../utils":32,"./GenericWorker":28}],26:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("./GenericWorker"); + +/** + * A worker which calculate the total length of the data flowing through. + * @constructor + * @param {String} propName the name used to expose the length + */ +function DataLengthProbe(propName) { + GenericWorker.call(this, "DataLengthProbe for " + propName); + this.propName = propName; + this.withStreamInfo(propName, 0); +} +utils.inherits(DataLengthProbe, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +DataLengthProbe.prototype.processChunk = function (chunk) { + if(chunk) { + var length = this.streamInfo[this.propName] || 0; + this.streamInfo[this.propName] = length + chunk.data.length; + } + GenericWorker.prototype.processChunk.call(this, chunk); +}; +module.exports = DataLengthProbe; + + +},{"../utils":32,"./GenericWorker":28}],27:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var GenericWorker = require("./GenericWorker"); + +// the size of the generated chunks +// TODO expose this as a public variable +var DEFAULT_BLOCK_SIZE = 16 * 1024; + +/** + * A worker that reads a content and emits chunks. 
+ * @constructor + * @param {Promise} dataP the promise of the data to split + */ +function DataWorker(dataP) { + GenericWorker.call(this, "DataWorker"); + var self = this; + this.dataIsReady = false; + this.index = 0; + this.max = 0; + this.data = null; + this.type = ""; + + this._tickScheduled = false; + + dataP.then(function (data) { + self.dataIsReady = true; + self.data = data; + self.max = data && data.length || 0; + self.type = utils.getTypeOf(data); + if(!self.isPaused) { + self._tickAndRepeat(); + } + }, function (e) { + self.error(e); + }); +} + +utils.inherits(DataWorker, GenericWorker); + +/** + * @see GenericWorker.cleanUp + */ +DataWorker.prototype.cleanUp = function () { + GenericWorker.prototype.cleanUp.call(this); + this.data = null; +}; + +/** + * @see GenericWorker.resume + */ +DataWorker.prototype.resume = function () { + if(!GenericWorker.prototype.resume.call(this)) { + return false; + } + + if (!this._tickScheduled && this.dataIsReady) { + this._tickScheduled = true; + utils.delay(this._tickAndRepeat, [], this); + } + return true; +}; + +/** + * Trigger a tick a schedule an other call to this function. + */ +DataWorker.prototype._tickAndRepeat = function() { + this._tickScheduled = false; + if(this.isPaused || this.isFinished) { + return; + } + this._tick(); + if(!this.isFinished) { + utils.delay(this._tickAndRepeat, [], this); + this._tickScheduled = true; + } +}; + +/** + * Read and push a chunk. 
+ */ +DataWorker.prototype._tick = function() { + + if(this.isPaused || this.isFinished) { + return false; + } + + var size = DEFAULT_BLOCK_SIZE; + var data = null, nextIndex = Math.min(this.max, this.index + size); + if (this.index >= this.max) { + // EOF + return this.end(); + } else { + switch(this.type) { + case "string": + data = this.data.substring(this.index, nextIndex); + break; + case "uint8array": + data = this.data.subarray(this.index, nextIndex); + break; + case "array": + case "nodebuffer": + data = this.data.slice(this.index, nextIndex); + break; + } + this.index = nextIndex; + return this.push({ + data : data, + meta : { + percent : this.max ? this.index / this.max * 100 : 0 + } + }); + } +}; + +module.exports = DataWorker; + +},{"../utils":32,"./GenericWorker":28}],28:[function(require,module,exports){ +"use strict"; + +/** + * A worker that does nothing but passing chunks to the next one. This is like + * a nodejs stream but with some differences. On the good side : + * - it works on IE 6-9 without any issue / polyfill + * - it weights less than the full dependencies bundled with browserify + * - it forwards errors (no need to declare an error handler EVERYWHERE) + * + * A chunk is an object with 2 attributes : `meta` and `data`. The former is an + * object containing anything (`percent` for example), see each worker for more + * details. The latter is the real data (String, Uint8Array, etc). 
+ * + * @constructor + * @param {String} name the name of the stream (mainly used for debugging purposes) + */ +function GenericWorker(name) { + // the name of the worker + this.name = name || "default"; + // an object containing metadata about the workers chain + this.streamInfo = {}; + // an error which happened when the worker was paused + this.generatedError = null; + // an object containing metadata to be merged by this worker into the general metadata + this.extraStreamInfo = {}; + // true if the stream is paused (and should not do anything), false otherwise + this.isPaused = true; + // true if the stream is finished (and should not do anything), false otherwise + this.isFinished = false; + // true if the stream is locked to prevent further structure updates (pipe), false otherwise + this.isLocked = false; + // the event listeners + this._listeners = { + "data":[], + "end":[], + "error":[] + }; + // the previous worker, if any + this.previous = null; +} + +GenericWorker.prototype = { + /** + * Push a chunk to the next workers. + * @param {Object} chunk the chunk to push + */ + push : function (chunk) { + this.emit("data", chunk); + }, + /** + * End the stream. + * @return {Boolean} true if this call ended the worker, false otherwise. + */ + end : function () { + if (this.isFinished) { + return false; + } + + this.flush(); + try { + this.emit("end"); + this.cleanUp(); + this.isFinished = true; + } catch (e) { + this.emit("error", e); + } + return true; + }, + /** + * End the stream with an error. + * @param {Error} e the error which caused the premature end. + * @return {Boolean} true if this call ended the worker with an error, false otherwise. 
+ */ + error : function (e) { + if (this.isFinished) { + return false; + } + + if(this.isPaused) { + this.generatedError = e; + } else { + this.isFinished = true; + + this.emit("error", e); + + // in the workers chain exploded in the middle of the chain, + // the error event will go downward but we also need to notify + // workers upward that there has been an error. + if(this.previous) { + this.previous.error(e); + } + + this.cleanUp(); + } + return true; + }, + /** + * Add a callback on an event. + * @param {String} name the name of the event (data, end, error) + * @param {Function} listener the function to call when the event is triggered + * @return {GenericWorker} the current object for chainability + */ + on : function (name, listener) { + this._listeners[name].push(listener); + return this; + }, + /** + * Clean any references when a worker is ending. + */ + cleanUp : function () { + this.streamInfo = this.generatedError = this.extraStreamInfo = null; + this._listeners = []; + }, + /** + * Trigger an event. This will call registered callback with the provided arg. + * @param {String} name the name of the event (data, end, error) + * @param {Object} arg the argument to call the callback with. + */ + emit : function (name, arg) { + if (this._listeners[name]) { + for(var i = 0; i < this._listeners[name].length; i++) { + this._listeners[name][i].call(this, arg); + } + } + }, + /** + * Chain a worker with an other. + * @param {Worker} next the worker receiving events from the current one. + * @return {worker} the next worker for chainability + */ + pipe : function (next) { + return next.registerPrevious(this); + }, + /** + * Same as `pipe` in the other direction. + * Using an API with `pipe(next)` is very easy. + * Implementing the API with the point of view of the next one registering + * a source is easier, see the ZipFileWorker. 
+ * @param {Worker} previous the previous worker, sending events to this one + * @return {Worker} the current worker for chainability + */ + registerPrevious : function (previous) { + if (this.isLocked) { + throw new Error("The stream '" + this + "' has already been used."); + } + + // sharing the streamInfo... + this.streamInfo = previous.streamInfo; + // ... and adding our own bits + this.mergeStreamInfo(); + this.previous = previous; + var self = this; + previous.on("data", function (chunk) { + self.processChunk(chunk); + }); + previous.on("end", function () { + self.end(); + }); + previous.on("error", function (e) { + self.error(e); + }); + return this; + }, + /** + * Pause the stream so it doesn't send events anymore. + * @return {Boolean} true if this call paused the worker, false otherwise. + */ + pause : function () { + if(this.isPaused || this.isFinished) { + return false; + } + this.isPaused = true; + + if(this.previous) { + this.previous.pause(); + } + return true; + }, + /** + * Resume a paused stream. + * @return {Boolean} true if this call resumed the worker, false otherwise. + */ + resume : function () { + if(!this.isPaused || this.isFinished) { + return false; + } + this.isPaused = false; + + // if true, the worker tried to resume but failed + var withError = false; + if(this.generatedError) { + this.error(this.generatedError); + withError = true; + } + if(this.previous) { + this.previous.resume(); + } + + return !withError; + }, + /** + * Flush any remaining bytes as the stream is ending. + */ + flush : function () {}, + /** + * Process a chunk. This is usually the method overridden. + * @param {Object} chunk the chunk to process. + */ + processChunk : function(chunk) { + this.push(chunk); + }, + /** + * Add a key/value to be added in the workers chain streamInfo once activated. 
+ * @param {String} key the key to use + * @param {Object} value the associated value + * @return {Worker} the current worker for chainability + */ + withStreamInfo : function (key, value) { + this.extraStreamInfo[key] = value; + this.mergeStreamInfo(); + return this; + }, + /** + * Merge this worker's streamInfo into the chain's streamInfo. + */ + mergeStreamInfo : function () { + for(var key in this.extraStreamInfo) { + if (!Object.prototype.hasOwnProperty.call(this.extraStreamInfo, key)) { + continue; + } + this.streamInfo[key] = this.extraStreamInfo[key]; + } + }, + + /** + * Lock the stream to prevent further updates on the workers chain. + * After calling this method, all calls to pipe will fail. + */ + lock: function () { + if (this.isLocked) { + throw new Error("The stream '" + this + "' has already been used."); + } + this.isLocked = true; + if (this.previous) { + this.previous.lock(); + } + }, + + /** + * + * Pretty print the workers chain. + */ + toString : function () { + var me = "Worker " + this.name; + if (this.previous) { + return this.previous + " -> " + me; + } else { + return me; + } + } +}; + +module.exports = GenericWorker; + +},{}],29:[function(require,module,exports){ +"use strict"; + +var utils = require("../utils"); +var ConvertWorker = require("./ConvertWorker"); +var GenericWorker = require("./GenericWorker"); +var base64 = require("../base64"); +var support = require("../support"); +var external = require("../external"); + +var NodejsStreamOutputAdapter = null; +if (support.nodestream) { + try { + NodejsStreamOutputAdapter = require("../nodejs/NodejsStreamOutputAdapter"); + } catch(e) { + // ignore + } +} + +/** + * Apply the final transformation of the data. If the user wants a Blob for + * example, it's easier to work with an U8intArray and finally do the + * ArrayBuffer/Blob conversion. 
+ * @param {String} type the name of the final type + * @param {String|Uint8Array|Buffer} content the content to transform + * @param {String} mimeType the mime type of the content, if applicable. + * @return {String|Uint8Array|ArrayBuffer|Buffer|Blob} the content in the right format. + */ +function transformZipOutput(type, content, mimeType) { + switch(type) { + case "blob" : + return utils.newBlob(utils.transformTo("arraybuffer", content), mimeType); + case "base64" : + return base64.encode(content); + default : + return utils.transformTo(type, content); + } +} + +/** + * Concatenate an array of data of the given type. + * @param {String} type the type of the data in the given array. + * @param {Array} dataArray the array containing the data chunks to concatenate + * @return {String|Uint8Array|Buffer} the concatenated data + * @throws Error if the asked type is unsupported + */ +function concat (type, dataArray) { + var i, index = 0, res = null, totalLength = 0; + for(i = 0; i < dataArray.length; i++) { + totalLength += dataArray[i].length; + } + switch(type) { + case "string": + return dataArray.join(""); + case "array": + return Array.prototype.concat.apply([], dataArray); + case "uint8array": + res = new Uint8Array(totalLength); + for(i = 0; i < dataArray.length; i++) { + res.set(dataArray[i], index); + index += dataArray[i].length; + } + return res; + case "nodebuffer": + return Buffer.concat(dataArray); + default: + throw new Error("concat : unsupported type '" + type + "'"); + } +} + +/** + * Listen a StreamHelper, accumulate its content and concatenate it into a + * complete block. + * @param {StreamHelper} helper the helper to use. + * @param {Function} updateCallback a callback called on each update. Called + * with one arg : + * - the metadata linked to the update received. + * @return Promise the promise for the accumulation. 
+ */ +function accumulate(helper, updateCallback) { + return new external.Promise(function (resolve, reject){ + var dataArray = []; + var chunkType = helper._internalType, + resultType = helper._outputType, + mimeType = helper._mimeType; + helper + .on("data", function (data, meta) { + dataArray.push(data); + if(updateCallback) { + updateCallback(meta); + } + }) + .on("error", function(err) { + dataArray = []; + reject(err); + }) + .on("end", function (){ + try { + var result = transformZipOutput(resultType, concat(chunkType, dataArray), mimeType); + resolve(result); + } catch (e) { + reject(e); + } + dataArray = []; + }) + .resume(); + }); +} + +/** + * An helper to easily use workers outside of JSZip. + * @constructor + * @param {Worker} worker the worker to wrap + * @param {String} outputType the type of data expected by the use + * @param {String} mimeType the mime type of the content, if applicable. + */ +function StreamHelper(worker, outputType, mimeType) { + var internalType = outputType; + switch(outputType) { + case "blob": + case "arraybuffer": + internalType = "uint8array"; + break; + case "base64": + internalType = "string"; + break; + } + + try { + // the type used internally + this._internalType = internalType; + // the type used to output results + this._outputType = outputType; + // the mime type + this._mimeType = mimeType; + utils.checkSupport(internalType); + this._worker = worker.pipe(new ConvertWorker(internalType)); + // the last workers can be rewired without issues but we need to + // prevent any updates on previous workers. + worker.lock(); + } catch(e) { + this._worker = new GenericWorker("error"); + this._worker.error(e); + } +} + +StreamHelper.prototype = { + /** + * Listen a StreamHelper, accumulate its content and concatenate it into a + * complete block. + * @param {Function} updateCb the update callback. + * @return Promise the promise for the accumulation. 
+ */ + accumulate : function (updateCb) { + return accumulate(this, updateCb); + }, + /** + * Add a listener on an event triggered on a stream. + * @param {String} evt the name of the event + * @param {Function} fn the listener + * @return {StreamHelper} the current helper. + */ + on : function (evt, fn) { + var self = this; + + if(evt === "data") { + this._worker.on(evt, function (chunk) { + fn.call(self, chunk.data, chunk.meta); + }); + } else { + this._worker.on(evt, function () { + utils.delay(fn, arguments, self); + }); + } + return this; + }, + /** + * Resume the flow of chunks. + * @return {StreamHelper} the current helper. + */ + resume : function () { + utils.delay(this._worker.resume, [], this._worker); + return this; + }, + /** + * Pause the flow of chunks. + * @return {StreamHelper} the current helper. + */ + pause : function () { + this._worker.pause(); + return this; + }, + /** + * Return a nodejs stream for this helper. + * @param {Function} updateCb the update callback. + * @return {NodejsStreamOutputAdapter} the nodejs stream. + */ + toNodejsStream : function (updateCb) { + utils.checkSupport("nodestream"); + if (this._outputType !== "nodebuffer") { + // an object stream containing blob/arraybuffer/uint8array/string + // is strange and I don't know if it would be useful. + // I you find this comment and have a good usecase, please open a + // bug report ! 
+ throw new Error(this._outputType + " is not supported by this method"); + } + + return new NodejsStreamOutputAdapter(this, { + objectMode : this._outputType !== "nodebuffer" + }, updateCb); + } +}; + + +module.exports = StreamHelper; + +},{"../base64":1,"../external":6,"../nodejs/NodejsStreamOutputAdapter":13,"../support":30,"../utils":32,"./ConvertWorker":24,"./GenericWorker":28}],30:[function(require,module,exports){ +"use strict"; + +exports.base64 = true; +exports.array = true; +exports.string = true; +exports.arraybuffer = typeof ArrayBuffer !== "undefined" && typeof Uint8Array !== "undefined"; +exports.nodebuffer = typeof Buffer !== "undefined"; +// contains true if JSZip can read/generate Uint8Array, false otherwise. +exports.uint8array = typeof Uint8Array !== "undefined"; + +if (typeof ArrayBuffer === "undefined") { + exports.blob = false; +} +else { + var buffer = new ArrayBuffer(0); + try { + exports.blob = new Blob([buffer], { + type: "application/zip" + }).size === 0; + } + catch (e) { + try { + var Builder = self.BlobBuilder || self.WebKitBlobBuilder || self.MozBlobBuilder || self.MSBlobBuilder; + var builder = new Builder(); + builder.append(buffer); + exports.blob = builder.getBlob("application/zip").size === 0; + } + catch (e) { + exports.blob = false; + } + } +} + +try { + exports.nodestream = !!require("readable-stream").Readable; +} catch(e) { + exports.nodestream = false; +} + +},{"readable-stream":16}],31:[function(require,module,exports){ +"use strict"; + +var utils = require("./utils"); +var support = require("./support"); +var nodejsUtils = require("./nodejsUtils"); +var GenericWorker = require("./stream/GenericWorker"); + +/** + * The following functions come from pako, from pako/lib/utils/strings + * released under the MIT license, see pako https://github.com/nodeca/pako/ + */ + +// Table with utf8 lengths (calculated by first byte of sequence) +// Note, that 5 & 6-byte values and some 4-byte values can not be represented in JS, +// 
because max possible codepoint is 0x10ffff

// Per-first-byte UTF-8 sequence lengths: 1 for ASCII and continuation
// bytes, 2 for 0xC0-0xDF, 3 for 0xE0-0xEF, 4 for 0xF0-0xF7 and 5/6 for
// the historical (now invalid) 0xF8+ ranges.
var _utf8len = new Array(256);
for (var i=0; i<256; i++) {
    _utf8len[i] = (i >= 252 ? 6 : i >= 248 ? 5 : i >= 240 ? 4 : i >= 224 ? 3 : i >= 192 ? 2 : 1);
}
// NOTE(review): this assigns index 254 twice; presumably both 0xFE and
// 0xFF (invalid start bytes) were meant to be marked. The same line exists
// in upstream pako — confirm there before changing it, since buf2string
// also guards against c_len > 4.
_utf8len[254]=_utf8len[254]=1; // Invalid sequence start

// convert string to array (typed, when possible)
// Encodes a JS (UTF-16) string into UTF-8 bytes, combining surrogate
// pairs into a single codepoint before sizing and encoding.
var string2buf = function (str) {
    var buf, c, c2, m_pos, i, str_len = str.length, buf_len = 0;

    // count binary size
    // First pass: compute the exact byte length so the buffer is
    // allocated only once.
    for (m_pos = 0; m_pos < str_len; m_pos++) {
        c = str.charCodeAt(m_pos);
        // High surrogate followed by a low surrogate: fold both halves
        // into one codepoint and skip the second half.
        if ((c & 0xfc00) === 0xd800 && (m_pos+1 < str_len)) {
            c2 = str.charCodeAt(m_pos+1);
            if ((c2 & 0xfc00) === 0xdc00) {
                c = 0x10000 + ((c - 0xd800) << 10) + (c2 - 0xdc00);
                m_pos++;
            }
        }
        buf_len += c < 0x80 ? 1 : c < 0x800 ? 2 : c < 0x10000 ? 3 : 4;
    }

    // allocate buffer
    // Typed array when the platform supports it, plain array otherwise.
    if (support.uint8array) {
        buf = new Uint8Array(buf_len);
    } else {
        buf = new Array(buf_len);
    }

    // convert
    // Second pass: same surrogate folding, now emitting the UTF-8 bytes.
    for (i=0, m_pos = 0; i < buf_len; m_pos++) {
        c = str.charCodeAt(m_pos);
        if ((c & 0xfc00) === 0xd800 && (m_pos+1 < str_len)) {
            c2 = str.charCodeAt(m_pos+1);
            if ((c2 & 0xfc00) === 0xdc00) {
                c = 0x10000 + ((c - 0xd800) << 10) + (c2 - 0xdc00);
                m_pos++;
            }
        }
        if (c < 0x80) {
            /* one byte */
            buf[i++] = c;
        } else if (c < 0x800) {
            /* two bytes */
            buf[i++] = 0xC0 | (c >>> 6);
            buf[i++] = 0x80 | (c & 0x3f);
        } else if (c < 0x10000) {
            /* three bytes */
            buf[i++] = 0xE0 | (c >>> 12);
            buf[i++] = 0x80 | (c >>> 6 & 0x3f);
            buf[i++] = 0x80 | (c & 0x3f);
        } else {
            /* four bytes */
            buf[i++] = 0xf0 | (c >>> 18);
            buf[i++] = 0x80 | (c >>> 12 & 0x3f);
            buf[i++] = 0x80 | (c >>> 6 & 0x3f);
            buf[i++] = 0x80 | (c & 0x3f);
        }
    }

    return buf;
};

// Calculate max possible position in utf8 buffer,
// that will not break sequence. If that's not possible
// - (very small limits) return max size as is.
//
// buf[] - utf8 bytes array
// max - length limit (mandatory);
// Returns the largest cut position <= max that does not split a
// multi-byte UTF-8 sequence.
var utf8border = function(buf, max) {
    var pos;

    max = max || buf.length;
    if (max > buf.length) { max = buf.length; }

    // go back from last position, until start of sequence found
    // (continuation bytes all match 10xxxxxx, i.e. (b & 0xC0) === 0x80)
    pos = max-1;
    while (pos >= 0 && (buf[pos] & 0xC0) === 0x80) { pos--; }

    // Edge case - very small and broken sequence,
    // return max, because we should return something anyway.
    if (pos < 0) { return max; }

    // If we came to start of buffer - that means buffer is too small,
    // return max too.
    if (pos === 0) { return max; }

    // If the sequence starting at pos would run past max, cut before it;
    // otherwise the whole range up to max is safe.
    return (pos + _utf8len[buf[pos]] > max) ? pos : max;
};

// convert array to string
// Decodes UTF-8 bytes into a JS (UTF-16) string, emitting U+FFFD for
// invalid sequences.
var buf2string = function (buf) {
    var i, out, c, c_len;
    var len = buf.length;

    // Reserve max possible length (2 words per char)
    // NB: by unknown reasons, Array is significantly faster for
    // String.fromCharCode.apply than Uint16Array.
    var utf16buf = new Array(len*2);

    // NOTE(review): the loop header below appears truncated in this copy —
    // it is missing the "i < len" bound, the leading byte read and the
    // _utf8len lookup before the "> 4" validity check. Compare with
    // upstream pako's buf2string before relying on this code.
    for (out=0, i=0; i 4) { utf16buf[out++] = 0xfffd; i += c_len-1; continue; }

        // apply mask on first byte
        c &= c_len === 2 ? 0x1f : c_len === 3 ? 0x0f : 0x07;
        // join the rest
        while (c_len > 1 && i < len) {
            c = (c << 6) | (buf[i++] & 0x3f);
            c_len--;
        }

        // terminated by end of string?
        if (c_len > 1) { utf16buf[out++] = 0xfffd; continue; }

        if (c < 0x10000) {
            utf16buf[out++] = c;
        } else {
            // Astral codepoint: re-split into a UTF-16 surrogate pair.
            c -= 0x10000;
            utf16buf[out++] = 0xd800 | ((c >> 10) & 0x3ff);
            utf16buf[out++] = 0xdc00 | (c & 0x3ff);
        }
    }

    // shrinkBuf(utf16buf, out)
    // Trim the over-allocated buffer down to the actually written length.
    if (utf16buf.length !== out) {
        if(utf16buf.subarray) {
            utf16buf = utf16buf.subarray(0, out);
        } else {
            utf16buf.length = out;
        }
    }

    // return String.fromCharCode.apply(null, utf16buf);
    return utils.applyFromCharCode(utf16buf);
};


// That's all for the pako functions.


/**
 * Transform a javascript string into an array (typed if possible) of bytes,
 * UTF-8 encoded.
 * @param {String} str the string to encode
 * @return {Array|Uint8Array|Buffer} the UTF-8 encoded string.
 */
exports.utf8encode = function utf8encode(str) {
    // On node, Buffer does the encoding natively; otherwise fall back to
    // the pako-derived pure-JS encoder above.
    if (support.nodebuffer) {
        return nodejsUtils.newBufferFrom(str, "utf-8");
    }

    return string2buf(str);
};


/**
 * Transform a bytes array (or a representation) representing an UTF-8 encoded
 * string into a javascript string.
 * @param {Array|Uint8Array|Buffer} buf the data to decode
 * @return {String} the decoded string.
 */
exports.utf8decode = function utf8decode(buf) {
    // On node, Buffer decodes natively after a cheap type conversion.
    if (support.nodebuffer) {
        return utils.transformTo("nodebuffer", buf).toString("utf-8");
    }

    buf = utils.transformTo(support.uint8array ? "uint8array" : "array", buf);

    return buf2string(buf);
};

/**
 * A worker to decode utf8 encoded binary chunks into string chunks.
 * @constructor
 */
function Utf8DecodeWorker() {
    GenericWorker.call(this, "utf-8 decode");
    // the last bytes if a chunk didn't end with a complete codepoint.
    this.leftOver = null;
}
utils.inherits(Utf8DecodeWorker, GenericWorker);

/**
 * Decode one chunk, carrying any trailing partial UTF-8 sequence over to
 * the next chunk via this.leftOver.
 * @see GenericWorker.processChunk
 */
Utf8DecodeWorker.prototype.processChunk = function (chunk) {

    var data = utils.transformTo(support.uint8array ? "uint8array" : "array", chunk.data);

    // 1st step, re-use what's left of the previous chunk
    // (prepend the carried-over bytes so a split sequence decodes whole)
    if (this.leftOver && this.leftOver.length) {
        if(support.uint8array) {
            var previousData = data;
            data = new Uint8Array(previousData.length + this.leftOver.length);
            data.set(this.leftOver, 0);
            data.set(previousData, this.leftOver.length);
        } else {
            data = this.leftOver.concat(data);
        }
        this.leftOver = null;
    }

    // Find the last safe cut that doesn't split a multi-byte sequence;
    // everything after it is stashed for the next chunk.
    var nextBoundary = utf8border(data);
    var usableData = data;
    if (nextBoundary !== data.length) {
        if (support.uint8array) {
            usableData = data.subarray(0, nextBoundary);
            this.leftOver = data.subarray(nextBoundary, data.length);
        } else {
            usableData = data.slice(0, nextBoundary);
            this.leftOver = data.slice(nextBoundary, data.length);
        }
    }

    this.push({
        data : exports.utf8decode(usableData),
        meta : chunk.meta
    });
};

/**
 * Flush the carried-over bytes at end of stream; a trailing incomplete
 * sequence decodes to replacement characters rather than being dropped.
 * @see GenericWorker.flush
 */
Utf8DecodeWorker.prototype.flush = function () {
    if(this.leftOver && this.leftOver.length) {
        this.push({
            data : exports.utf8decode(this.leftOver),
            meta : {}
        });
        this.leftOver = null;
    }
};
exports.Utf8DecodeWorker = Utf8DecodeWorker;

/**
 * A worker to encode string chunks into utf8 encoded binary chunks.
+ * @constructor + */ +function Utf8EncodeWorker() { + GenericWorker.call(this, "utf-8 encode"); +} +utils.inherits(Utf8EncodeWorker, GenericWorker); + +/** + * @see GenericWorker.processChunk + */ +Utf8EncodeWorker.prototype.processChunk = function (chunk) { + this.push({ + data : exports.utf8encode(chunk.data), + meta : chunk.meta + }); +}; +exports.Utf8EncodeWorker = Utf8EncodeWorker; + +},{"./nodejsUtils":14,"./stream/GenericWorker":28,"./support":30,"./utils":32}],32:[function(require,module,exports){ +"use strict"; + +var support = require("./support"); +var base64 = require("./base64"); +var nodejsUtils = require("./nodejsUtils"); +var external = require("./external"); +require("setimmediate"); + + +/** + * Convert a string that pass as a "binary string": it should represent a byte + * array but may have > 255 char codes. Be sure to take only the first byte + * and returns the byte array. + * @param {String} str the string to transform. + * @return {Array|Uint8Array} the string in a binary format. + */ +function string2binary(str) { + var result = null; + if (support.uint8array) { + result = new Uint8Array(str.length); + } else { + result = new Array(str.length); + } + return stringToArrayLike(str, result); +} + +/** + * Create a new blob with the given content and the given type. + * @param {String|ArrayBuffer} part the content to put in the blob. DO NOT use + * an Uint8Array because the stock browser of android 4 won't accept it (it + * will be silently converted to a string, "[object Uint8Array]"). + * + * Use only ONE part to build the blob to avoid a memory leak in IE11 / Edge: + * when a large amount of Array is used to create the Blob, the amount of + * memory consumed is nearly 100 times the original data amount. + * + * @param {String} type the mime type of the blob. + * @return {Blob} the created blob. 
+ */ +exports.newBlob = function(part, type) { + exports.checkSupport("blob"); + + try { + // Blob constructor + return new Blob([part], { + type: type + }); + } + catch (e) { + + try { + // deprecated, browser only, old way + var Builder = self.BlobBuilder || self.WebKitBlobBuilder || self.MozBlobBuilder || self.MSBlobBuilder; + var builder = new Builder(); + builder.append(part); + return builder.getBlob(type); + } + catch (e) { + + // well, fuck ?! + throw new Error("Bug : can't construct the Blob."); + } + } + + +}; +/** + * The identity function. + * @param {Object} input the input. + * @return {Object} the same input. + */ +function identity(input) { + return input; +} + +/** + * Fill in an array with a string. + * @param {String} str the string to use. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to fill in (will be mutated). + * @return {Array|ArrayBuffer|Uint8Array|Buffer} the updated array. + */ +function stringToArrayLike(str, array) { + for (var i = 0; i < str.length; ++i) { + array[i] = str.charCodeAt(i) & 0xFF; + } + return array; +} + +/** + * An helper for the function arrayLikeToString. + * This contains static information and functions that + * can be optimized by the browser JIT compiler. + */ +var arrayToStringHelper = { + /** + * Transform an array of int into a string, chunk by chunk. + * See the performances notes on arrayLikeToString. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @param {String} type the type of the array. + * @param {Integer} chunk the chunk size. + * @return {String} the resulting string. + * @throws Error if the chunk is too big for the stack. 
+ */ + stringifyByChunk: function(array, type, chunk) { + var result = [], k = 0, len = array.length; + // shortcut + if (len <= chunk) { + return String.fromCharCode.apply(null, array); + } + while (k < len) { + if (type === "array" || type === "nodebuffer") { + result.push(String.fromCharCode.apply(null, array.slice(k, Math.min(k + chunk, len)))); + } + else { + result.push(String.fromCharCode.apply(null, array.subarray(k, Math.min(k + chunk, len)))); + } + k += chunk; + } + return result.join(""); + }, + /** + * Call String.fromCharCode on every item in the array. + * This is the naive implementation, which generate A LOT of intermediate string. + * This should be used when everything else fail. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @return {String} the result. + */ + stringifyByChar: function(array){ + var resultStr = ""; + for(var i = 0; i < array.length; i++) { + resultStr += String.fromCharCode(array[i]); + } + return resultStr; + }, + applyCanBeUsed : { + /** + * true if the browser accepts to use String.fromCharCode on Uint8Array + */ + uint8array : (function () { + try { + return support.uint8array && String.fromCharCode.apply(null, new Uint8Array(1)).length === 1; + } catch (e) { + return false; + } + })(), + /** + * true if the browser accepts to use String.fromCharCode on nodejs Buffer. + */ + nodebuffer : (function () { + try { + return support.nodebuffer && String.fromCharCode.apply(null, nodejsUtils.allocBuffer(1)).length === 1; + } catch (e) { + return false; + } + })() + } +}; + +/** + * Transform an array-like object to a string. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} array the array to transform. + * @return {String} the result. 
+ */ +function arrayLikeToString(array) { + // Performances notes : + // -------------------- + // String.fromCharCode.apply(null, array) is the fastest, see + // see http://jsperf.com/converting-a-uint8array-to-a-string/2 + // but the stack is limited (and we can get huge arrays !). + // + // result += String.fromCharCode(array[i]); generate too many strings ! + // + // This code is inspired by http://jsperf.com/arraybuffer-to-string-apply-performance/2 + // TODO : we now have workers that split the work. Do we still need that ? + var chunk = 65536, + type = exports.getTypeOf(array), + canUseApply = true; + if (type === "uint8array") { + canUseApply = arrayToStringHelper.applyCanBeUsed.uint8array; + } else if (type === "nodebuffer") { + canUseApply = arrayToStringHelper.applyCanBeUsed.nodebuffer; + } + + if (canUseApply) { + while (chunk > 1) { + try { + return arrayToStringHelper.stringifyByChunk(array, type, chunk); + } catch (e) { + chunk = Math.floor(chunk / 2); + } + } + } + + // no apply or chunk error : slow and painful algorithm + // default browser on android 4.* + return arrayToStringHelper.stringifyByChar(array); +} + +exports.applyFromCharCode = arrayLikeToString; + + +/** + * Copy the data from an array-like to an other array-like. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} arrayFrom the origin array. + * @param {Array|ArrayBuffer|Uint8Array|Buffer} arrayTo the destination array which will be mutated. + * @return {Array|ArrayBuffer|Uint8Array|Buffer} the updated destination array. + */ +function arrayLikeToArrayLike(arrayFrom, arrayTo) { + for (var i = 0; i < arrayFrom.length; i++) { + arrayTo[i] = arrayFrom[i]; + } + return arrayTo; +} + +// a matrix containing functions to transform everything into everything. +var transform = {}; + +// string to ? 
+transform["string"] = { + "string": identity, + "array": function(input) { + return stringToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return transform["string"]["uint8array"](input).buffer; + }, + "uint8array": function(input) { + return stringToArrayLike(input, new Uint8Array(input.length)); + }, + "nodebuffer": function(input) { + return stringToArrayLike(input, nodejsUtils.allocBuffer(input.length)); + } +}; + +// array to ? +transform["array"] = { + "string": arrayLikeToString, + "array": identity, + "arraybuffer": function(input) { + return (new Uint8Array(input)).buffer; + }, + "uint8array": function(input) { + return new Uint8Array(input); + }, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(input); + } +}; + +// arraybuffer to ? +transform["arraybuffer"] = { + "string": function(input) { + return arrayLikeToString(new Uint8Array(input)); + }, + "array": function(input) { + return arrayLikeToArrayLike(new Uint8Array(input), new Array(input.byteLength)); + }, + "arraybuffer": identity, + "uint8array": function(input) { + return new Uint8Array(input); + }, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(new Uint8Array(input)); + } +}; + +// uint8array to ? +transform["uint8array"] = { + "string": arrayLikeToString, + "array": function(input) { + return arrayLikeToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return input.buffer; + }, + "uint8array": identity, + "nodebuffer": function(input) { + return nodejsUtils.newBufferFrom(input); + } +}; + +// nodebuffer to ? 
+transform["nodebuffer"] = { + "string": arrayLikeToString, + "array": function(input) { + return arrayLikeToArrayLike(input, new Array(input.length)); + }, + "arraybuffer": function(input) { + return transform["nodebuffer"]["uint8array"](input).buffer; + }, + "uint8array": function(input) { + return arrayLikeToArrayLike(input, new Uint8Array(input.length)); + }, + "nodebuffer": identity +}; + +/** + * Transform an input into any type. + * The supported output type are : string, array, uint8array, arraybuffer, nodebuffer. + * If no output type is specified, the unmodified input will be returned. + * @param {String} outputType the output type. + * @param {String|Array|ArrayBuffer|Uint8Array|Buffer} input the input to convert. + * @throws {Error} an Error if the browser doesn't support the requested output type. + */ +exports.transformTo = function(outputType, input) { + if (!input) { + // undefined, null, etc + // an empty string won't harm. + input = ""; + } + if (!outputType) { + return input; + } + exports.checkSupport(outputType); + var inputType = exports.getTypeOf(input); + var result = transform[inputType][outputType](input); + return result; +}; + +/** + * Resolve all relative path components, "." and "..", in a path. If these relative components + * traverse above the root then the resulting path will only contain the final path component. + * + * All empty components, e.g. "//", are removed. + * @param {string} path A path with / or \ separators + * @returns {string} The path with all relative path components resolved. + */ +exports.resolve = function(path) { + var parts = path.split("/"); + var result = []; + for (var index = 0; index < parts.length; index++) { + var part = parts[index]; + // Allow the first and last component to be empty for trailing slashes. + if (part === "." 
|| (part === "" && index !== 0 && index !== parts.length - 1)) { + continue; + } else if (part === "..") { + result.pop(); + } else { + result.push(part); + } + } + return result.join("/"); +}; + +/** + * Return the type of the input. + * The type will be in a format valid for JSZip.utils.transformTo : string, array, uint8array, arraybuffer. + * @param {Object} input the input to identify. + * @return {String} the (lowercase) type of the input. + */ +exports.getTypeOf = function(input) { + if (typeof input === "string") { + return "string"; + } + if (Object.prototype.toString.call(input) === "[object Array]") { + return "array"; + } + if (support.nodebuffer && nodejsUtils.isBuffer(input)) { + return "nodebuffer"; + } + if (support.uint8array && input instanceof Uint8Array) { + return "uint8array"; + } + if (support.arraybuffer && input instanceof ArrayBuffer) { + return "arraybuffer"; + } +}; + +/** + * Throw an exception if the type is not supported. + * @param {String} type the type to check. + * @throws {Error} an Error if the browser doesn't support the requested type. + */ +exports.checkSupport = function(type) { + var supported = support[type.toLowerCase()]; + if (!supported) { + throw new Error(type + " is not supported by this platform"); + } +}; + +exports.MAX_VALUE_16BITS = 65535; +exports.MAX_VALUE_32BITS = -1; // well, "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" is parsed as -1 + +/** + * Prettify a string read as binary. + * @param {string} str the string to prettify. + * @return {string} a pretty string. + */ +exports.pretty = function(str) { + var res = "", + code, i; + for (i = 0; i < (str || "").length; i++) { + code = str.charCodeAt(i); + res += "\\x" + (code < 16 ? "0" : "") + code.toString(16).toUpperCase(); + } + return res; +}; + +/** + * Defer the call of a function. + * @param {Function} callback the function to call asynchronously. + * @param {Array} args the arguments to give to the callback. 
+ */ +exports.delay = function(callback, args, self) { + setImmediate(function () { + callback.apply(self || null, args || []); + }); +}; + +/** + * Extends a prototype with an other, without calling a constructor with + * side effects. Inspired by nodejs' `utils.inherits` + * @param {Function} ctor the constructor to augment + * @param {Function} superCtor the parent constructor to use + */ +exports.inherits = function (ctor, superCtor) { + var Obj = function() {}; + Obj.prototype = superCtor.prototype; + ctor.prototype = new Obj(); +}; + +/** + * Merge the objects passed as parameters into a new one. + * @private + * @param {...Object} var_args All objects to merge. + * @return {Object} a new object with the data of the others. + */ +exports.extend = function() { + var result = {}, i, attr; + for (i = 0; i < arguments.length; i++) { // arguments is not enumerable in some browsers + for (attr in arguments[i]) { + if (Object.prototype.hasOwnProperty.call(arguments[i], attr) && typeof result[attr] === "undefined") { + result[attr] = arguments[i][attr]; + } + } + } + return result; +}; + +/** + * Transform arbitrary content into a Promise. + * @param {String} name a name for the content being processed. + * @param {Object} inputData the content to process. + * @param {Boolean} isBinary true if the content is not an unicode string + * @param {Boolean} isOptimizedBinaryString true if the string content only has one byte per character. + * @param {Boolean} isBase64 true if the string content is encoded with base64. + * @return {Promise} a promise in a format usable by JSZip. + */ +exports.prepareContent = function(name, inputData, isBinary, isOptimizedBinaryString, isBase64) { + + // if inputData is already a promise, this flatten it. 
+ var promise = external.Promise.resolve(inputData).then(function(data) { + + + var isBlob = support.blob && (data instanceof Blob || ["[object File]", "[object Blob]"].indexOf(Object.prototype.toString.call(data)) !== -1); + + if (isBlob && typeof FileReader !== "undefined") { + return new external.Promise(function (resolve, reject) { + var reader = new FileReader(); + + reader.onload = function(e) { + resolve(e.target.result); + }; + reader.onerror = function(e) { + reject(e.target.error); + }; + reader.readAsArrayBuffer(data); + }); + } else { + return data; + } + }); + + return promise.then(function(data) { + var dataType = exports.getTypeOf(data); + + if (!dataType) { + return external.Promise.reject( + new Error("Can't read the data of '" + name + "'. Is it " + + "in a supported JavaScript type (String, Blob, ArrayBuffer, etc) ?") + ); + } + // special case : it's way easier to work with Uint8Array than with ArrayBuffer + if (dataType === "arraybuffer") { + data = exports.transformTo("uint8array", data); + } else if (dataType === "string") { + if (isBase64) { + data = base64.decode(data); + } + else if (isBinary) { + // optimizedBinaryString === true means that the file has already been filtered with a 0xFF mask + if (isOptimizedBinaryString !== true) { + // this is a string, not in a base64 format. + // Be sure that this is a correct "binary string" + data = string2binary(data); + } + } + } + return data; + }); +}; + +},{"./base64":1,"./external":6,"./nodejsUtils":14,"./support":30,"setimmediate":54}],33:[function(require,module,exports){ +"use strict"; +var readerFor = require("./reader/readerFor"); +var utils = require("./utils"); +var sig = require("./signature"); +var ZipEntry = require("./zipEntry"); +var support = require("./support"); +// class ZipEntries {{{ +/** + * All the entries in the zip file. + * @constructor + * @param {Object} loadOptions Options for loading the stream. 
 */
function ZipEntries(loadOptions) {
    // Entries found in the central directory, in order of appearance.
    this.files = [];
    this.loadOptions = loadOptions;
}
ZipEntries.prototype = {
    /**
     * Check that the reader is on the specified signature.
     * @param {string} expectedSignature the expected signature.
     * @throws {Error} if it is an other signature.
     */
    checkSignature: function(expectedSignature) {
        if (!this.reader.readAndCheckSignature(expectedSignature)) {
            // rewind the 4 bytes just consumed so the error message can
            // show what was actually found
            this.reader.index -= 4;
            var signature = this.reader.readString(4);
            throw new Error("Corrupted zip or bug: unexpected signature " + "(" + utils.pretty(signature) + ", expected " + utils.pretty(expectedSignature) + ")");
        }
    },
    /**
     * Check if the given signature is at the given index.
     * @param {number} askedIndex the index to check.
     * @param {string} expectedSignature the signature to expect.
     * @return {boolean} true if the signature is here, false otherwise.
     */
    isSignature: function(askedIndex, expectedSignature) {
        // peek without disturbing the reader: save and restore the index
        var currentIndex = this.reader.index;
        this.reader.setIndex(askedIndex);
        var signature = this.reader.readString(4);
        var result = signature === expectedSignature;
        this.reader.setIndex(currentIndex);
        return result;
    },
    /**
     * Read the end of the central directory.
     * Fields are read in the exact order and sizes defined by the zip
     * "end of central directory record".
     */
    readBlockEndOfCentral: function() {
        this.diskNumber = this.reader.readInt(2);
        this.diskWithCentralDirStart = this.reader.readInt(2);
        this.centralDirRecordsOnThisDisk = this.reader.readInt(2);
        this.centralDirRecords = this.reader.readInt(2);
        this.centralDirSize = this.reader.readInt(4);
        this.centralDirOffset = this.reader.readInt(4);

        this.zipCommentLength = this.reader.readInt(2);
        // warning : the encoding depends of the system locale
        // On a linux machine with LANG=en_US.utf8, this field is utf8 encoded.
        // On a windows machine, this field is encoded with the localized windows code page.
        var zipComment = this.reader.readData(this.zipCommentLength);
        var decodeParamType = support.uint8array ? "uint8array" : "array";
        // To get consistent behavior with the generation part, we will assume that
        // this is utf8 encoded unless specified otherwise.
        var decodeContent = utils.transformTo(decodeParamType, zipComment);
        this.zipComment = this.loadOptions.decodeFileName(decodeContent);
    },
    /**
     * Read the end of the Zip 64 central directory.
     * Not merged with the method readEndOfCentral :
     * The end of central can coexist with its Zip64 brother,
     * I don't want to read the wrong number of bytes !
     */
    readBlockZip64EndOfCentral: function() {
        this.zip64EndOfCentralSize = this.reader.readInt(8);
        // skip versionMadeBy (2) and versionNeeded (2), unused here
        this.reader.skip(4);
        // this.versionMadeBy = this.reader.readString(2);
        // this.versionNeeded = this.reader.readInt(2);
        this.diskNumber = this.reader.readInt(4);
        this.diskWithCentralDirStart = this.reader.readInt(4);
        this.centralDirRecordsOnThisDisk = this.reader.readInt(8);
        this.centralDirRecords = this.reader.readInt(8);
        this.centralDirSize = this.reader.readInt(8);
        this.centralDirOffset = this.reader.readInt(8);

        this.zip64ExtensibleData = {};
        // record size minus the 44 fixed bytes = extensible data sector size
        var extraDataSize = this.zip64EndOfCentralSize - 44,
            index = 0,
            extraFieldId,
            extraFieldLength,
            extraFieldValue;
        // NOTE(review): `index` is never advanced inside this loop, so a
        // non-empty extensible data sector would loop reading past the
        // block. Upstream JSZip has the same code — confirm there before
        // changing anything.
        while (index < extraDataSize) {
            extraFieldId = this.reader.readInt(2);
            extraFieldLength = this.reader.readInt(4);
            extraFieldValue = this.reader.readData(extraFieldLength);
            this.zip64ExtensibleData[extraFieldId] = {
                id: extraFieldId,
                length: extraFieldLength,
                value: extraFieldValue
            };
        }
    },
    /**
     * Read the end of the Zip 64 central directory locator.
     * @throws {Error} if the archive spans several disks (unsupported).
     */
    readBlockZip64EndOfCentralLocator: function() {
        this.diskWithZip64CentralDirStart = this.reader.readInt(4);
        this.relativeOffsetEndOfZip64CentralDir = this.reader.readInt(8);
        this.disksCount = this.reader.readInt(4);
        if (this.disksCount > 1) {
            throw new Error("Multi-volumes zip are not supported");
        }
    },
    /**
     * Read the local files, based on the offset read in the central part.
    /**
     * Read the central directory.
     * Creates one ZipEntry per central file header found at
     * this.centralDirOffset and pushes it on this.files. A partial read
     * (some but not all expected records) is tolerated; finding NO record
     * when some were announced is treated as corruption.
     * @throws {Error} if records were expected but none could be read.
     */
    readCentralDir: function() {
        var file;

        this.reader.setIndex(this.centralDirOffset);
        // keep creating entries as long as a central file header signature follows
        while (this.reader.readAndCheckSignature(sig.CENTRAL_FILE_HEADER)) {
            file = new ZipEntry({
                zip64: this.zip64
            }, this.loadOptions);
            file.readCentralPart(this.reader);
            this.files.push(file);
        }

        if (this.centralDirRecords !== this.files.length) {
            if (this.centralDirRecords !== 0 && this.files.length === 0) {
                // We expected some records but couldn't find ANY.
                // This is really suspicious, as if something went wrong.
                throw new Error("Corrupted zip or bug: expected " + this.centralDirRecords + " records in central dir, got " + this.files.length);
            } else {
                // We found some records but not all.
                // Something is wrong but we got something for the user: no error here.
                // console.warn("expected", this.centralDirRecords, "records in central dir, got", this.files.length);
            }
        }
    },
" + + "If it is, see https://stuk.github.io/jszip/documentation/howto/read_zip.html"); + } else { + throw new Error("Corrupted zip: can't find end of central directory"); + } + + } + this.reader.setIndex(offset); + var endOfCentralDirOffset = offset; + this.checkSignature(sig.CENTRAL_DIRECTORY_END); + this.readBlockEndOfCentral(); + + + /* extract from the zip spec : + 4) If one of the fields in the end of central directory + record is too small to hold required data, the field + should be set to -1 (0xFFFF or 0xFFFFFFFF) and the + ZIP64 format record should be created. + 5) The end of central directory record and the + Zip64 end of central directory locator record must + reside on the same disk when splitting or spanning + an archive. + */ + if (this.diskNumber === utils.MAX_VALUE_16BITS || this.diskWithCentralDirStart === utils.MAX_VALUE_16BITS || this.centralDirRecordsOnThisDisk === utils.MAX_VALUE_16BITS || this.centralDirRecords === utils.MAX_VALUE_16BITS || this.centralDirSize === utils.MAX_VALUE_32BITS || this.centralDirOffset === utils.MAX_VALUE_32BITS) { + this.zip64 = true; + + /* + Warning : the zip64 extension is supported, but ONLY if the 64bits integer read from + the zip file can fit into a 32bits integer. This cannot be solved : JavaScript represents + all numbers as 64-bit double precision IEEE 754 floating point numbers. + So, we have 53bits for integers and bitwise operations treat everything as 32bits. 
+ see https://developer.mozilla.org/en-US/docs/JavaScript/Reference/Operators/Bitwise_Operators + and http://www.ecma-international.org/publications/files/ECMA-ST/ECMA-262.pdf section 8.5 + */ + + // should look for a zip64 EOCD locator + offset = this.reader.lastIndexOfSignature(sig.ZIP64_CENTRAL_DIRECTORY_LOCATOR); + if (offset < 0) { + throw new Error("Corrupted zip: can't find the ZIP64 end of central directory locator"); + } + this.reader.setIndex(offset); + this.checkSignature(sig.ZIP64_CENTRAL_DIRECTORY_LOCATOR); + this.readBlockZip64EndOfCentralLocator(); + + // now the zip64 EOCD record + if (!this.isSignature(this.relativeOffsetEndOfZip64CentralDir, sig.ZIP64_CENTRAL_DIRECTORY_END)) { + // console.warn("ZIP64 end of central directory not where expected."); + this.relativeOffsetEndOfZip64CentralDir = this.reader.lastIndexOfSignature(sig.ZIP64_CENTRAL_DIRECTORY_END); + if (this.relativeOffsetEndOfZip64CentralDir < 0) { + throw new Error("Corrupted zip: can't find the ZIP64 end of central directory"); + } + } + this.reader.setIndex(this.relativeOffsetEndOfZip64CentralDir); + this.checkSignature(sig.ZIP64_CENTRAL_DIRECTORY_END); + this.readBlockZip64EndOfCentral(); + } + + var expectedEndOfCentralDirOffset = this.centralDirOffset + this.centralDirSize; + if (this.zip64) { + expectedEndOfCentralDirOffset += 20; // end of central dir 64 locator + expectedEndOfCentralDirOffset += 12 /* should not include the leading 12 bytes */ + this.zip64EndOfCentralSize; + } + + var extraBytes = endOfCentralDirOffset - expectedEndOfCentralDirOffset; + + if (extraBytes > 0) { + // console.warn(extraBytes, "extra bytes at beginning or within zipfile"); + if (this.isSignature(endOfCentralDirOffset, sig.CENTRAL_FILE_HEADER)) { + // The offsets seem wrong, but we have something at the specified offset. + // So… we keep it. 
/**
 * Find a compression registered in JSZip.
 * @param {string} compressionMethod the method magic to find.
 * @return {Object|null} the JSZip compression object, null if none found.
 */
var findCompression = function(compressionMethod) {
    // Object.keys yields the own enumerable keys only — the same set a
    // for..in loop guarded by hasOwnProperty would visit.
    var names = Object.keys(compressions);
    for (var i = 0; i < names.length; i++) {
        var candidate = compressions[names[i]];
        if (candidate.magic === compressionMethod) {
            return candidate;
        }
    }
    return null;
};
    /**
     * Read the local part of a zip file and add the info in this object.
     * Only the filename and the compressed payload are taken from the local
     * header; sizes, crc and compression method come from the already-parsed
     * central part.
     * @param {DataReader} reader the reader to use.
     * @throws {Error} if sizes are missing or the compression method is unknown.
     */
    readLocalPart: function(reader) {
        var compression, localExtraFieldsLength;

        // we already know everything from the central dir !
        // If the central dir data are false, we are doomed.
        // On the bright side, the local part is scary : zip64, data descriptors, both, etc.
        // The less data we get here, the more reliable this should be.
        // Let's skip the whole header and dash to the data !
        reader.skip(22);
        // in some zip created on windows, the filename stored in the central dir contains \ instead of /.
        // Strangely, the filename here is OK.
        // I would love to treat these zip files as corrupted (see http://www.info-zip.org/FAQ.html#backslashes
        // or APPNOTE#4.4.17.1, "All slashes MUST be forward slashes '/'") but there are a lot of bad zip generators...
        // Search "unzip mismatching "local" filename continuing with "central" filename version" on
        // the internet.
        //
        // I think I see the logic here : the central directory is used to display
        // content and the local directory is used to extract the files. Mixing / and \
        // may be used to display \ to windows users and use / when extracting the files.
        // Unfortunately, this lead also to some issues : http://seclists.org/fulldisclosure/2009/Sep/394
        this.fileNameLength = reader.readInt(2);
        localExtraFieldsLength = reader.readInt(2); // can't be sure this will be the same as the central dir
        // the fileName is stored as binary data, the handleUTF8 method will take care of the encoding.
        this.fileName = reader.readData(this.fileNameLength);
        reader.skip(localExtraFieldsLength);

        // -1 presumably marks "size not provided by the central directory"
        // (see the error message) — set elsewhere, TODO confirm where.
        if (this.compressedSize === -1 || this.uncompressedSize === -1) {
            throw new Error("Bug or corrupted zip : didn't get enough information from the central directory " + "(compressedSize === -1 || uncompressedSize === -1)");
        }

        compression = findCompression(this.compressionMethod);
        if (compression === null) { // no compression found
            throw new Error("Corrupted zip : compression " + utils.pretty(this.compressionMethod) + " unknown (inner file : " + utils.transformTo("string", this.fileName) + ")");
        }
        // keep the payload in compressed form; decompression is done lazily
        // through the CompressedObject worker pipeline
        this.decompressed = new CompressedObject(this.compressedSize, this.uncompressedSize, this.crc32, compression, reader.readData(this.compressedSize));
    },
    /**
     * Read the central part of a zip file and add the info in this object.
     * Field widths follow the central file header layout; the 4-byte
     * signature has already been consumed by the caller.
     * @param {DataReader} reader the reader to use.
     * @throws {Error} if the entry is encrypted (not supported).
     */
    readCentralPart: function(reader) {
        this.versionMadeBy = reader.readInt(2);
        reader.skip(2); // "version needed to extract", unused
        // this.versionNeeded = reader.readInt(2);
        this.bitFlag = reader.readInt(2);
        this.compressionMethod = reader.readString(2); // kept as a 2-char magic, matched later by findCompression
        this.date = reader.readDate();
        this.crc32 = reader.readInt(4);
        this.compressedSize = reader.readInt(4);
        this.uncompressedSize = reader.readInt(4);
        var fileNameLength = reader.readInt(2);
        this.extraFieldsLength = reader.readInt(2);
        this.fileCommentLength = reader.readInt(2);
        this.diskNumberStart = reader.readInt(2);
        this.internalFileAttributes = reader.readInt(2);
        this.externalFileAttributes = reader.readInt(4);
        this.localHeaderOffset = reader.readInt(4);

        if (this.isEncrypted()) {
            throw new Error("Encrypted zip are not supported");
        }

        // will be read in the local part, see the comments there
        reader.skip(fileNameLength);
        this.readExtraFields(reader);
        this.parseZIP64ExtraField(reader);
        this.fileComment = reader.readData(this.fileCommentLength);
    },
    /**
     * Parse the ZIP64 extra field (id 0x0001) and merge the info in the
     * current ZipEntry: each field that saturated its 32/16 bit maximum in
     * the central directory is replaced by the wider value stored in the
     * extra field. No-op when the extra field is absent.
     * NOTE(review): callers pass a reader argument but it is unused here —
     * the field's own bytes are re-read through a dedicated reader.
     */
    parseZIP64ExtraField: function() {
        if (!this.extraFields[0x0001]) {
            return;
        }

        // should be something, preparing the extra reader
        var extraReader = readerFor(this.extraFields[0x0001].value);

        // I really hope that these 64bits integer can fit in 32 bits integer, because js
        // won't let us have more.
        if (this.uncompressedSize === utils.MAX_VALUE_32BITS) {
            this.uncompressedSize = extraReader.readInt(8);
        }
        if (this.compressedSize === utils.MAX_VALUE_32BITS) {
            this.compressedSize = extraReader.readInt(8);
        }
        if (this.localHeaderOffset === utils.MAX_VALUE_32BITS) {
            this.localHeaderOffset = extraReader.readInt(8);
        }
        if (this.diskNumberStart === utils.MAX_VALUE_32BITS) {
            this.diskNumberStart = extraReader.readInt(4);
        }
    },
    /**
     * Read the extra fields (id/length/value triplets) of the central part
     * and index them by id in this.extraFields.
     * @param {DataReader} reader the reader to use.
     */
    readExtraFields: function(reader) {
        var end = reader.index + this.extraFieldsLength,
            extraFieldId,
            extraFieldLength,
            extraFieldValue;

        if (!this.extraFields) {
            this.extraFields = {};
        }

        // NOTE(review): the strict `<` means a trailing field whose 4-byte
        // header fits exactly (zero-length value ending the block) is
        // skipped — confirm this is intentional before changing it.
        while (reader.index + 4 < end) {
            extraFieldId = reader.readInt(2);
            extraFieldLength = reader.readInt(2);
            extraFieldValue = reader.readData(extraFieldLength);

            // last occurrence wins if an id appears twice
            this.extraFields[extraFieldId] = {
                id: extraFieldId,
                length: extraFieldLength,
                value: extraFieldValue
            };
        }

        // land the reader exactly after the extra-field block, whatever was parsed
        reader.setIndex(end);
    },
    /**
     * Find the unicode comment declared in the extra field (id 0x6375), if any.
     * The field is only honored when its version byte is 1 and its stored
     * crc32 still matches the raw comment (otherwise it is out of date).
     * @return {String} the unicode comment, null otherwise.
     */
    findExtraFieldUnicodeComment: function() {
        var ucommentField = this.extraFields[0x6375];
        if (ucommentField) {
            var extraReader = readerFor(ucommentField.value);

            // wrong version
            if (extraReader.readInt(1) !== 1) {
                return null;
            }

            // the crc of the comment changed, this field is out of date.
            if (crc32fn(this.fileComment) !== extraReader.readInt(4)) {
                return null;
            }

            // 5 = 1 version byte + 4 crc bytes already consumed above
            return utf8.utf8decode(extraReader.readData(ucommentField.length - 5));
        }
        return null;
    }
    /**
     * Create an internal stream for the content of this object.
     * Any error (including a missing type) is not thrown but routed into an
     * "error" worker, so it surfaces through the stream API instead.
     * @param {String} type the type of each chunk
     *   ("string", "text", "binarystring", and the binary types).
     * @return StreamHelper the stream.
     */
    internalStream: function (type) {
        var result = null, outputType = "string";
        try {
            if (!type) {
                throw new Error("No output type specified.");
            }
            outputType = type.toLowerCase();
            // "text" is an alias of "string"; "binarystring" also outputs a
            // string but skips the utf8 decode step decided below.
            var askUnicodeString = outputType === "string" || outputType === "text";
            if (outputType === "binarystring" || outputType === "text") {
                outputType = "string";
            }
            result = this._decompressWorker();

            var isUnicodeString = !this._dataBinary;

            // pipe through a utf8 encode/decode step only when the stored
            // form and the requested form disagree
            if (isUnicodeString && !askUnicodeString) {
                result = result.pipe(new utf8.Utf8EncodeWorker());
            }
            if (!isUnicodeString && askUnicodeString) {
                result = result.pipe(new utf8.Utf8DecodeWorker());
            }
        } catch (e) {
            result = new GenericWorker("error");
            result.error(e);
        }

        return new StreamHelper(result, outputType, "");
    },
    /**
     * Return a worker for the compressed content.
     * Reuses the already-compressed stream when the stored data matches the
     * requested compression method; otherwise decompresses and re-compresses.
     * @private
     * @param {Object} compression the compression object to use.
     * @param {Object} compressionOptions the options to use when compressing.
     * @return Worker the worker.
     */
    _compressWorker: function (compression, compressionOptions) {
        if (
            this._data instanceof CompressedObject &&
            this._data.compression.magic === compression.magic
        ) {
            // already compressed with the right method: stream it as-is
            return this._data.getCompressedWorker();
        } else {
            var result = this._decompressWorker();
            if(!this._dataBinary) {
                // unicode content must be utf8-encoded before compression
                result = result.pipe(new utf8.Utf8EncodeWorker());
            }
            return CompressedObject.createWorkerFrom(result, compression, compressionOptions);
        }
    },
    /**
     * Return a worker for the decompressed content.
     * @private
     * @return Worker the worker: the decompression pipeline for stored
     * CompressedObject data, the data itself when it already is a worker,
     * or a DataWorker wrapping the raw data otherwise.
     */
    _decompressWorker : function () {
        if (this._data instanceof CompressedObject) {
            return this._data.getContentWorker();
        } else if (this._data instanceof GenericWorker) {
            return this._data;
        } else {
            return new DataWorker(this._data);
        }
    }
class LoggingIntervals(BaseConfig):
    """Intervals in which to log each type of output.

    Each field is a period: the corresponding output is emitted once every N
    iterations (presumably merge iterations — confirm against the merge loop
    that consumes this config).
    """

    stat: PositiveInt = Field(
        default=1, description="Logging statistics (e.g., k_groups, merge_pair_cost, mdl_loss)"
    )
    tensor: PositiveInt = Field(
        default=100, description="Logging tensors (e.g., wandb_log_tensor, fraction calculations)"
    )
    plot: PositiveInt = Field(
        default=100, description="Generating plots (e.g., plot_merge_iteration)"
    )
    artifact: PositiveInt = Field(
        default=100, description="Creating artifacts (e.g., merge_history)"
    )
+ """ + + # TODO: Handle both wandb strings and local file paths + model_path: str = Field( + description="WandB path to the decomposed model (format: wandb:entity/project/run_id)" + ) + + batch_size: PositiveInt = Field(..., description="Batch size for processing") + dataset_seed: int = Field(0, description="Seed for dataset generation/loading") + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering", + description="Base directory to save clustering runs", + ) + ensemble_id: str | None = Field( + default=None, + description="Ensemble identifier for WandB grouping", + ) + merge_config: MergeConfig = Field(description="Merge algorithm configuration") + logging_intervals: LoggingIntervals = Field( + default_factory=LoggingIntervals, + description="Logging intervals", + ) + + wandb_project: str | None = Field( + default=None, + description="WandB project name (None to disable WandB logging)", + ) + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + dataset_streaming: bool = Field( + default=False, + description="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) + + @model_validator(mode="before") + def process_experiment_key(cls, values: dict[str, Any]) -> dict[str, Any]: + experiment_key: str | None = values.get("experiment_key") + if experiment_key: + model_path_given: str | None = values.get("model_path") + model_path_from_experiment: str | None = EXPERIMENT_REGISTRY[ + experiment_key + ].canonical_run + assert model_path_from_experiment is not None, ( + f"Experiment '{experiment_key}' has no canonical_run defined in the EXPERIMENT_REGISTRY" + ) + if model_path_given and model_path_given != model_path_from_experiment: + raise ValueError( + f"Both experiment_key '{experiment_key}' and model_path '{model_path_given}' given in config data, but they disagree: {model_path_from_experiment=}" + ) + + values["model_path"] = model_path_from_experiment + del values["experiment_key"] + + return values + + @field_validator("model_path") + def validate_model_path(cls, v: str) -> str: + """Validate that model_path is a proper WandB path.""" + if not v.startswith("wandb:"): + raise ValueError(f"model_path must start with 'wandb:', got: {v}") + return v + + @property + def wandb_decomp_model(self) -> str: + """Extract the WandB run ID of the source decomposition.""" + parts = self.model_path.replace("wandb:", "").split("/") + if len(parts) >= 3: + return parts[-1] if parts[-1] != "runs" else parts[-2] + raise ValueError(f"Invalid wandb path format: {self.model_path}") + + def model_dump_with_properties(self) -> dict[str, Any]: + """Serialize config including computed properties for WandB logging.""" + base_dump: dict[str, Any] = self.model_dump(mode="json") + + # Add computed properties + base_dump.update( + { + "wandb_decomp_model": self.wandb_decomp_model, + } + ) + + return base_dump + + def stable_hash_b64(self) -> str: + """Generate a stable, deterministic base64-encoded hash of this config. + + Uses SHA256 hash of the JSON representation with sorted keys for determinism. 
def compute_mdl_cost(
    acts: Float[Tensor, " k_groups"],
    merges: GroupMerge,
    alpha: float = 1.0,
) -> float:
    r"""MDL cost of a grouping given per-group activations.

    $$
    MDL = \sum_{i \in \N_k} s_i ( \log(k) + \alpha r(P_i) )
    $$

    where $s_i$ is the activation of group $i$, $r(P_i)$ the number of
    components in group $i$, and $k$ the total number of groups.
    """
    n_groups: int = acts.shape[0]
    assert n_groups == merges.k_groups, "Merges must match activation vector shape"

    # per-group cost: bits to name a group plus the rank penalty, weighted
    # by how strongly the group activates
    index_bits: float = math.log2(n_groups)
    ranks: Float[Tensor, " k_groups"] = merges.components_per_group.to(device=acts.device)
    per_group_cost: Float[Tensor, " k_groups"] = acts * (index_bits + alpha * ranks)
    return per_group_cost.sum().item()
version from nathu 2025-08-11 16:48 + + $$ + (s_\Sigma - s_i - s_j) log((c-1)/c) + + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) ) + $$ + where: + - $s_\Sigma$ average activation of all components + - $s_i$ activation of component $i$, $s_j$ activation of component $j$ + - $s_{i,j}$ activation of the merged component $i,j$ + - $r(P_i)$ rank of component $i$, $r(P_j)$ rank of component $j$ + - $r(P_{i,j})$ rank of the merged component $i,j$ + + """ + k_groups: int = coact.shape[0] + assert coact.shape[1] == k_groups, "Coactivation matrix must be square" + assert merges.k_groups == k_groups, "Merges must match coactivation matrix shape" + + device: torch.device = coact.device + ranks: Float[Tensor, " k_groups"] = merges.components_per_group.to(device=device).float() + s_diag: Float[Tensor, " k_groups"] = torch.diag(coact).to(device=device) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * ranks.view(1, -1) + # term_si_rpj: Float[Tensor, "k_groups k_groups"] = s_diag.view(-1, 1) * (ranks.view(1, -1) + 1/alpha) + term_si_rpi: Float[Tensor, " k_groups"] = s_diag * ranks + # dbg_auto(term_si_rpi) + rank_sum: ClusterCoactivationShaped = ranks.view(-1, 1) + ranks.view(1, -1) + # TODO: use dynamic rank computation + # return alpha * ( + # term_si_rpj # |s_i| r(P_j) + # + term_si_rpj.T # |s_j| r(P_i) + # - coact * ( # s_i s_j + # rank_sum # r(P_i) + r(P_j) + # + (rank_cost(merges.k_groups) / alpha) # c / alpha + # ) + # ) + + coact_OR: ClusterCoactivationShaped = s_diag.view(-1, 1) + s_diag.view(1, -1) - coact + + # reduce penalty for sending dictionary by 1 + # (s_\Sigma - s_i - s_j) log((c-1)/c) + # delta of cost for sending index, in expectation + # + s_{i,j} log(c-1) - s_i log(c) - s_j log(c) + # delta of cost for sending ranks, in expectation + # + alpha ( s_{i,j} r(P_{i,j}) - s_i r(P_i) - s_j r(P_j) + + s_other: ClusterCoactivationShaped = ( + s_diag.sum() - s_diag.view(-1, 1) - 
def recompute_coacts_merge_pair(
    coact: ClusterCoactivationShaped,
    merges: GroupMerge,
    merge_pair: MergePair,
    activation_mask: Bool[Tensor, "samples k_groups"],
) -> tuple[
    GroupMerge,
    Float[Tensor, "k_groups-1 k_groups-1"],
    Bool[Tensor, "samples k_groups-1"],
]:
    """Apply one group merge and incrementally update the coactivation state.

    Merges the two groups in ``merge_pair`` into the slot ``min(merge_pair)``,
    drops the row/column of ``max(merge_pair)``, and recomputes only the
    coactivations involving the merged group; all other entries are reused.

    Args:
        coact: Square ``[k_groups, k_groups]`` coactivation matrix.
        merges: Current grouping; ``merges.merge_groups`` builds the merged one.
        merge_pair: Indices of the two groups to merge.
        activation_mask: Boolean ``[samples, k_groups]`` per-sample activations.

    Returns:
        Tuple of (merged ``GroupMerge``, ``(k-1) x (k-1)`` coactivation matrix,
        ``[samples, k_groups-1]`` activation mask).
    """
    # check shape
    k_groups: int = coact.shape[0]
    assert coact.shape[1] == k_groups, "Coactivation matrix must be square"

    # activations of the new merged group
    # (`+` on bool tensors is elementwise logical OR)
    activation_mask_grp: Bool[Tensor, " samples"] = (
        activation_mask[:, merge_pair[0]] + activation_mask[:, merge_pair[1]]
    )

    # coactivations with the new merged group: counts of samples where the
    # merged group and each existing group are both active
    coact_with_merge: Float[Tensor, " k_groups"] = (
        activation_mask_grp.float() @ activation_mask.float()
    )
    new_group_idx: int = min(merge_pair)
    remove_idx: int = max(merge_pair)
    new_group_self_coact: float = activation_mask_grp.float().sum().item()

    # assemble the merge pair
    merge_new: GroupMerge = merges.merge_groups(
        merge_pair[0],
        merge_pair[1],
    )
    # TODO: we don't use this index for anything, and could reconstruct it from the merge pair if needed. get rid of it
    # `merge_groups` will set `old_to_new_idx` to be an actual dict for `merge_new`
    old_to_new_idx: dict[int | None, int | None] = merge_new.old_to_new_idx  # pyright: ignore[reportAssignmentType]
    assert old_to_new_idx[None] == new_group_idx, (
        "New group index should be the minimum of the merge pair"
    )
    assert old_to_new_idx[new_group_idx] is None
    assert old_to_new_idx[remove_idx] is None
    # TODO: check that the rest are in order? probably not necessary

    # reindex coactivations
    coact_temp: ClusterCoactivationShaped = coact.clone()
    # add in the similarities with the new group
    coact_temp[new_group_idx, :] = coact_with_merge
    coact_temp[:, new_group_idx] = coact_with_merge
    # delete the old group (boolean mask keeps every row/column but remove_idx)
    mask: Bool[Tensor, " k_groups"] = torch.ones(
        coact_temp.shape[0], dtype=torch.bool, device=coact_temp.device
    )
    mask[remove_idx] = False
    coact_new: Float[Tensor, "k_groups-1 k_groups-1"] = coact_temp[mask, :][:, mask]
    # add in the self-coactivation of the new group
    # (the diagonal entry was overwritten by coact_with_merge above)
    coact_new[new_group_idx, new_group_idx] = new_group_self_coact

    # reindex mask
    activation_mask_new: Bool[Tensor, "samples ..."] = activation_mask.clone()
    # add in the new group
    activation_mask_new[:, new_group_idx] = activation_mask_grp
    # remove the old group
    activation_mask_new = activation_mask_new[:, mask]

    return (
        merge_new,
        coact_new,
        activation_mask_new,
    )
\ No newline at end of file diff --git a/spd/clustering/configs/crc/example.yaml b/spd/clustering/configs/crc/example.yaml new file mode 100644 index 000000000..9345307d2 --- /dev/null +++ b/spd/clustering/configs/crc/example.yaml @@ -0,0 +1,23 @@ +model_path: wandb:goodfire/spd/runs/zxbu57pt # WandB path to the decomposed model +batch_size: 8 # Batch size for processing -- number of samples for each run in the ensemble +dataset_seed: 0 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# output_dir: .data/clustering/clustering_runs # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) +# ensemble_id: 1234567890 # Note, overridden if run in the pipeline (spd/clustering/scripts/run_pipeline.py) + +merge_config: + activation_threshold: 0.01 # set to null to use scalar activations for cost calculation + alpha: 1.0 # rank penalty term + iters: 10 # iterations to run. setting this to exactly the number of components can be buggy when doing ensembles, so set it to a bit less? + merge_pair_sampling_method: "range" # Method for sampling merge pairs: 'range' or 'mcmc' + merge_pair_sampling_kwargs: + threshold: 0.05 # For range sampler: fraction of the range of costs to sample from + filter_dead_threshold: 0.001 # Threshold for filtering dead components + module_name_filter: null # Can be a string prefix like "model.layers.0." 
if you want to do only some modules + +wandb_project: spd-cluster +wandb_entity: goodfire +logging_intervals: + stat: 1 # for k_groups, merge_pair_cost, mdl_loss + tensor: 100 # for wandb_log_tensor and fraction_* calculations + plot: 100 # for calling the plotting callback + artifact: 100 # for calling the artifact callback \ No newline at end of file diff --git a/spd/clustering/configs/crc/resid_mlp1.json b/spd/clustering/configs/crc/resid_mlp1.json new file mode 100644 index 000000000..1e13ce23e --- /dev/null +++ b/spd/clustering/configs/crc/resid_mlp1.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 5, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0, + "module_name_filter": null + }, + "experiment_key": "resid_mlp1", + "batch_size": 128, + "wandb_project": "spd-cluster", + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 5 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/resid_mlp2.json b/spd/clustering/configs/crc/resid_mlp2.json new file mode 100644 index 000000000..edc4849e2 --- /dev/null +++ b/spd/clustering/configs/crc/resid_mlp2.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.01, + "alpha": 1, + "iters": 100, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.01, + "module_name_filter": null + }, + "experiment_key": "resid_mlp2", + "batch_size": 1024, + "wandb_project": "spd-cluster", + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 5, + "artifact": 50 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/simplestories_dev.json b/spd/clustering/configs/crc/simplestories_dev.json new file mode 100644 index 000000000..e1647b6e4 --- /dev/null +++ b/spd/clustering/configs/crc/simplestories_dev.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + 
"activation_threshold": 0.1, + "alpha": 1.0, + "iters": 100, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.001}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", + "batch_size": 32, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 200, + "plot": 2000, + "artifact": 2000 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/test-resid_mlp1.json b/spd/clustering/configs/crc/test-resid_mlp1.json new file mode 100644 index 000000000..4b3a26ff8 --- /dev/null +++ b/spd/clustering/configs/crc/test-resid_mlp1.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.5, + "alpha": 1, + "iters": 16, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.1, + "module_name_filter": null + }, + "experiment_key": "resid_mlp1", + "batch_size": 128, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 5, + "plot": 10, + "artifact": 10 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/crc/test-simplestories.json b/spd/clustering/configs/crc/test-simplestories.json new file mode 100644 index 000000000..911f71529 --- /dev/null +++ b/spd/clustering/configs/crc/test-simplestories.json @@ -0,0 +1,20 @@ +{ + "merge_config": { + "activation_threshold": 0.9, + "alpha": 1.0, + "iters": 5, + "merge_pair_sampling_method": "range", + "merge_pair_sampling_kwargs": {"threshold": 0.05}, + "filter_dead_threshold": 0.9, + "module_name_filter": "model.layers.0" + }, + "model_path": "wandb:goodfire/spd/runs/lxs77xye", + "batch_size": 1, + "wandb_project": null, + "logging_intervals": { + "stat": 1, + "tensor": 2, + "plot": 3, + "artifact": 4 + } +} \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-dev-simplestories.yaml b/spd/clustering/configs/pipeline-dev-simplestories.yaml new file mode 100644 
index 000000000..1868b5887 --- /dev/null +++ b/spd/clustering/configs/pipeline-dev-simplestories.yaml @@ -0,0 +1,9 @@ +n_runs: 2 +distances_methods: ["matching_dist"] +# base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: "spd-cluster" # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false +clustering_run_config_path: "spd/clustering/configs/crc/simplestories_dev.json" \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-resid_mlp1.yaml b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml new file mode 100644 index 000000000..37833c82c --- /dev/null +++ b/spd/clustering/configs/pipeline-test-resid_mlp1.yaml @@ -0,0 +1,9 @@ +clustering_run_config_path: "spd/clustering/configs/crc/test-resid_mlp1.json" +n_runs: 3 +distances_methods: ["matching_dist"] +base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: null # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false \ No newline at end of file diff --git a/spd/clustering/configs/pipeline-test-simplestories.yaml b/spd/clustering/configs/pipeline-test-simplestories.yaml new file mode 100644 index 000000000..9872062d2 --- /dev/null +++ b/spd/clustering/configs/pipeline-test-simplestories.yaml @@ -0,0 +1,9 @@ +clustering_run_config_path: "spd/clustering/configs/crc/test-simplestories.json" +n_runs: 2 +distances_methods: ["matching_dist"] +base_output_dir: "tests/.temp/clustering" +slurm_job_name_prefix: null +slurm_partition: null +wandb_project: null # wandb fails in CI +wandb_entity: "goodfire" +create_git_snapshot: false \ No newline at end of file diff --git a/spd/clustering/configs/pipeline_config.yaml b/spd/clustering/configs/pipeline_config.yaml new file mode 100644 index 000000000..3a533885d --- /dev/null +++ b/spd/clustering/configs/pipeline_config.yaml @@ -0,0 +1,9 @@ +clustering_run_config_path: "spd/clustering/configs/crc/example.yaml" 
+n_runs: 2 +distances_methods: ["perm_invariant_hamming"] +base_output_dir: "/mnt/polished-lake/spd/clustering" +slurm_job_name_prefix: "spd" +slurm_partition: "h100-reserved" +wandb_project: "spd-cluster" +wandb_entity: "goodfire" +create_git_snapshot: true \ No newline at end of file diff --git a/spd/clustering/consts.py b/spd/clustering/consts.py new file mode 100644 index 000000000..8a9647dc8 --- /dev/null +++ b/spd/clustering/consts.py @@ -0,0 +1,48 @@ +"""Constants and shared abstractions for clustering pipeline.""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Literal, NewType + +import numpy as np +from jaxtyping import Bool, Float, Int +from torch import Tensor + +# Merge arrays and distances (numpy-based for storage/analysis) +MergesAtIterArray = Int[np.ndarray, "n_ens n_components"] +MergesArray = Int[np.ndarray, "n_ens n_iters n_components"] +DistancesMethod = Literal["perm_invariant_hamming", "matching_dist", "matching_dist_vec"] +DistancesArray = Float[np.ndarray, "n_iters n_ens n_ens"] + +# Component and label types (NewType for stronger type safety) +ComponentLabel = NewType("ComponentLabel", str) # Format: "module_name:component_index" +ComponentLabels = NewType("ComponentLabels", list[str]) +BatchId = NewType("BatchId", str) + +# Path types +WandBPath = NewType("WandBPath", str) # Format: "wandb:entity/project/run_id" + +# Merge types +MergePair = NewType("MergePair", tuple[int, int]) + +# Tensor type aliases (torch-based for computation - TypeAlias for jaxtyping compatibility) +ActivationsTensor = Float[Tensor, "samples n_components"] +BoolActivationsTensor = Bool[Tensor, "samples n_components"] +ClusterCoactivationShaped = Float[Tensor, "k_groups k_groups"] +GroupIdxsTensor = Int[Tensor, " n_components"] +BatchTensor = Int[Tensor, "batch_size seq_len"] + + +class SaveableObject(ABC): + """Abstract base class for objects that can be saved to and loaded from disk.""" + + @abstractmethod + def save(self, path: 
Path) -> None: + """Save the object to disk at the given path.""" + ... + + @classmethod + @abstractmethod + def read(cls, path: Path) -> "SaveableObject": + """Load the object from disk at the given path.""" + ... diff --git a/spd/clustering/dataset.py b/spd/clustering/dataset.py new file mode 100644 index 000000000..ea9b9f904 --- /dev/null +++ b/spd/clustering/dataset.py @@ -0,0 +1,137 @@ +"""Dataset loading utilities for clustering runs. + +Each clustering run loads its own dataset batch, seeded by the run index. +""" + +from typing import Any + +from spd.clustering.consts import BatchTensor +from spd.data import DatasetConfig, create_data_loader +from spd.experiments.lm.configs import LMTaskConfig +from spd.experiments.resid_mlp.configs import ResidMLPTaskConfig +from spd.experiments.resid_mlp.models import ResidMLP +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.spd_types import TaskName + + +def load_dataset( + model_path: str, + task_name: TaskName, + batch_size: int, + seed: int, + **kwargs: Any, +) -> BatchTensor: + """Load a single batch for clustering. + + Each run gets its own dataset batch, seeded by index in ensemble. 
+ + Args: + model_path: Path to decomposed model + task_name: Task type + batch_size: Batch size + seed: Random seed for dataset + + Returns: + Single batch of data + """ + match task_name: + case "lm": + return _load_lm_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case "resid_mlp": + return _load_resid_mlp_batch( + model_path=model_path, + batch_size=batch_size, + seed=seed, + **kwargs, + ) + case _: + raise ValueError(f"Unsupported task: {task_name}") + + +def _load_lm_batch( + model_path: str, batch_size: int, seed: int, config_kwargs: dict[str, Any] | None = None +) -> BatchTensor: + """Load a batch for language model task.""" + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + + assert isinstance(cfg.task_config, LMTaskConfig), ( + f"Expected task_config to be of type LMTaskConfig, but got {type(cfg.task_config) = }" + ) + + try: + pretrained_model_name: str = cfg.pretrained_model_name # pyright: ignore[reportAssignmentType] + assert pretrained_model_name is not None + except Exception as e: + raise AttributeError("Could not find 'pretrained_model_name' in the SPD Run config") from e + + config_kwargs_: dict[str, Any] = { + **dict( + is_tokenized=False, + streaming=False, + ), + **(config_kwargs or {}), + } + + dataset_config = DatasetConfig( + name=cfg.task_config.dataset_name, + hf_tokenizer_path=pretrained_model_name, + split=cfg.task_config.train_data_split, + n_ctx=cfg.task_config.max_seq_len, + seed=seed, # Use run-specific seed + column_name=cfg.task_config.column_name, + **config_kwargs_, + ) + + dataloader, _ = create_data_loader( + dataset_config=dataset_config, + batch_size=batch_size, + buffer_size=cfg.task_config.buffer_size, + global_seed=seed, # Use run-specific seed + ddp_rank=0, + ddp_world_size=1, + ) + + # Get first batch + batch = next(iter(dataloader)) + return batch["input_ids"] + + +def _load_resid_mlp_batch(model_path: str, batch_size: int, seed: int) -> BatchTensor: + """Load 
a batch for ResidMLP task.""" + from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset + from spd.utils.data_utils import DatasetGeneratedDataLoader + + spd_run = SPDRunInfo.from_path(model_path) + cfg = spd_run.config + component_model = ComponentModel.from_pretrained(spd_run.checkpoint_path) + + assert isinstance(cfg.task_config, ResidMLPTaskConfig), ( + f"Expected task_config to be of type ResidMLPTaskConfig, but got {type(cfg.task_config) = }" + ) + assert isinstance(component_model.target_model, ResidMLP), ( + f"Expected target_model to be of type ResidMLP, but got {type(component_model.target_model) = }" + ) + + # Create dataset with run-specific seed + dataset = ResidMLPDataset( + n_features=component_model.target_model.config.n_features, + feature_probability=cfg.task_config.feature_probability, + device="cpu", + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=seed, # Use run-specific seed + label_coeffs=None, + data_generation_type=cfg.task_config.data_generation_type, + ) + + # Generate batch + dataloader = DatasetGeneratedDataLoader(dataset, batch_size=batch_size, shuffle=False) + batch, _ = next(iter(dataloader)) + return batch diff --git a/spd/clustering/ensemble_registry.py b/spd/clustering/ensemble_registry.py new file mode 100644 index 000000000..c54fe408b --- /dev/null +++ b/spd/clustering/ensemble_registry.py @@ -0,0 +1,87 @@ +"""Ensemble registry for tracking which clustering runs belong to which pipeline ensemble. + +Uses SQLite to maintain a mapping of (pipeline_run_id, idx, clustering_run_id). 
+""" + +import sqlite3 +from contextlib import contextmanager + +from spd.settings import SPD_CACHE_DIR + +# SQLite database path +_ENSEMBLE_REGISTRY_DB = SPD_CACHE_DIR / "clustering_ensemble_registry.db" + + +@contextmanager +def _get_connection(): + """Context manager for SQLite connection, ensures table exists.""" + _ENSEMBLE_REGISTRY_DB.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(_ENSEMBLE_REGISTRY_DB) + + try: + # Create table if not exists + conn.execute(""" + CREATE TABLE IF NOT EXISTS ensemble_runs ( + pipeline_run_id TEXT NOT NULL, + idx INTEGER NOT NULL, + clustering_run_id TEXT NOT NULL, + PRIMARY KEY (pipeline_run_id, idx) + ) + """) + conn.execute(""" + CREATE INDEX IF NOT EXISTS idx_pipeline_run_id + ON ensemble_runs (pipeline_run_id) + """) + conn.commit() + + yield conn + finally: + conn.close() + + +def register_clustering_run(pipeline_run_id: str, clustering_run_id: str) -> int: + """Register a clustering run as part of a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + idx: Index of this run in the ensemble. If -1, auto-assigns the next available index. 
+ clustering_run_id: The individual clustering run ID + + Returns: + The index assigned to this run (either the provided idx or the auto-assigned one) + """ + with _get_connection() as conn: + # Use BEGIN IMMEDIATE for thread-safe auto-increment + conn.execute("BEGIN IMMEDIATE") + + # Auto-assign next available index, we rely on atomicity of the transaction here + cursor = conn.execute( + "SELECT COALESCE(MAX(idx), -1) + 1 FROM ensemble_runs WHERE pipeline_run_id = ?", + (pipeline_run_id,), + ) + assigned_idx: int = cursor.fetchone()[0] + + conn.execute( + "INSERT INTO ensemble_runs (pipeline_run_id, idx, clustering_run_id) VALUES (?, ?, ?)", + (pipeline_run_id, assigned_idx, clustering_run_id), + ) + conn.commit() + + return assigned_idx + + +def get_clustering_runs(pipeline_run_id: str) -> list[tuple[int, str]]: + """Get all clustering runs for a pipeline ensemble. + + Args: + pipeline_run_id: The ensemble/pipeline run ID + + Returns: + List of (idx, clustering_run_id) tuples, sorted by idx + """ + with _get_connection() as conn: + cursor = conn.execute( + "SELECT idx, clustering_run_id FROM ensemble_runs WHERE pipeline_run_id = ? 
ORDER BY idx", + (pipeline_run_id,), + ) + return cursor.fetchall() diff --git a/spd/clustering/math/__init__.py b/spd/clustering/math/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/clustering/math/matching_dist.py b/spd/clustering/math/matching_dist.py new file mode 100644 index 000000000..1991e9ba0 --- /dev/null +++ b/spd/clustering/math/matching_dist.py @@ -0,0 +1,47 @@ +import numpy as np +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +_DEVICE: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def matching_dist( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + s_ensemble, _n_components = X.shape + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + + dists: Float[Tensor, "s s"] = torch.full((s_ensemble, s_ensemble), torch.nan) + + for i in range(s_ensemble): + for j in range(i + 1, s_ensemble): + dist_mat = matches[i].float() - matches[j].float() + dists[i, j] = torch.tril(dist_mat, diagonal=-1).abs().sum() + + return dists + + +def matching_dist_vec( + X: Int[Tensor, "s n"], +) -> Float[Tensor, "s s"]: + matches: Bool[Tensor, "s n n"] = X[:, :, None] == X[:, None, :] + diffs: Bool[Tensor, "s s n n"] = matches[:, None, :, :] ^ matches[None, :, :, :] + + dists_int: torch.Tensor = diffs.sum(dim=(-1, -2)) + dists: Float[Tensor, "s s"] = dists_int.to(torch.float32) + return dists + + +def matching_dist_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist(torch.tensor(X, device=device)).cpu().numpy() + + +def matching_dist_vec_np( + X: Int[np.ndarray, "s n"], + device: torch.device = _DEVICE, +) -> Float[np.ndarray, "s s"]: + return matching_dist_vec(torch.tensor(X, device=device)).cpu().numpy() diff --git a/spd/clustering/math/merge_distances.py b/spd/clustering/math/merge_distances.py new file mode 100644 index 000000000..d3644cd68 --- /dev/null +++ 
b/spd/clustering/math/merge_distances.py @@ -0,0 +1,59 @@ +from collections.abc import Callable + +import numpy as np +from jaxtyping import Float, Int +from muutils.parallel import run_maybe_parallel + +from spd.clustering.consts import ( + DistancesArray, + DistancesMethod, + MergesArray, + MergesAtIterArray, +) +from spd.clustering.math.matching_dist import matching_dist_np, matching_dist_vec_np +from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix + +DISTANCES_METHODS: dict[DistancesMethod, Callable[[MergesAtIterArray], DistancesArray]] = { + "perm_invariant_hamming": perm_invariant_hamming_matrix, + "matching_dist": matching_dist_np, +} + +# pyright: reportUnnecessaryComparison=false, reportUnreachable=false + + +def compute_distances( + normalized_merge_array: MergesArray, + method: DistancesMethod = "perm_invariant_hamming", +) -> DistancesArray: + n_iters: int = normalized_merge_array.shape[1] + merges_array_list: list[Int[np.ndarray, "n_ens n_components"]] + distances_list: list[Float[np.ndarray, "n_ens n_ens"]] + match method: + case "perm_invariant_hamming": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + + distances_list = run_maybe_parallel( + func=perm_invariant_hamming_matrix, + iterable=merges_array_list, + parallel=True, + ) + + return np.stack(distances_list, axis=0) + case "matching_dist": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = run_maybe_parallel( + func=matching_dist_np, + iterable=merges_array_list, + parallel=True, + ) + return np.stack(distances_list, axis=0) + case "matching_dist_vec": + merges_array_list = [normalized_merge_array[:, i, :] for i in range(n_iters)] + distances_list = run_maybe_parallel( + func=matching_dist_vec_np, + iterable=merges_array_list, + parallel=True, + ) + return np.stack(distances_list, axis=0) + case _: + raise ValueError(f"Unknown distance method: {method}") diff --git 
a/spd/clustering/math/merge_matrix.py b/spd/clustering/math/merge_matrix.py new file mode 100644 index 000000000..118f575e2 --- /dev/null +++ b/spd/clustering/math/merge_matrix.py @@ -0,0 +1,283 @@ +from dataclasses import dataclass + +import torch +from jaxtyping import Bool, Int +from muutils.tensor_info import array_summary +from torch import Tensor + +from spd.clustering.consts import GroupIdxsTensor + +# pyright: reportUnnecessaryTypeIgnoreComment=false + + +@dataclass(kw_only=True, slots=True) +class GroupMerge: + """Canonical component-to-group assignment. + + `group_idxs` is a length-`n_components` integer tensor; entry `c` + gives the group index (0 to `k_groups-1`) that contains component `c`. + """ + + group_idxs: GroupIdxsTensor + k_groups: int + old_to_new_idx: dict[int | None, int | None] | None = None + + def summary(self) -> dict[str, int | str | None]: + return dict( + group_idxs=array_summary(self.group_idxs, as_list=False), # pyright: ignore[reportCallIssue] + k_groups=self.k_groups, + old_to_new_idx=f"len={len(self.old_to_new_idx)}" + if self.old_to_new_idx is not None + else None, + ) + + @property + def _n_components(self) -> int: + return int(self.group_idxs.shape[0]) + + @property + def components_per_group(self) -> Int[Tensor, " k_groups"]: + return torch.bincount(self.group_idxs, minlength=self.k_groups) + + def components_in_group_mask(self, group_idx: int) -> Bool[Tensor, " n_components"]: + """Returns a boolean mask for components in the specified group.""" + if group_idx < 0 or group_idx >= self.k_groups: + raise ValueError("group index out of range") + return self.group_idxs == group_idx + + def components_in_group(self, group_idx: int) -> list[int]: + """Returns a list of component indices in the specified group.""" + indices: Int[Tensor, " n_matches"] = ( + (self.group_idxs == group_idx).nonzero(as_tuple=False).squeeze(-1) + ) + return indices.tolist() + + def validate(self, *, require_nonempty: bool = True) -> None: + v_min: int = 
int(self.group_idxs.min().item()) + v_max: int = int(self.group_idxs.max().item()) + if v_min < 0 or v_max >= self.k_groups: + raise ValueError("group indices out of range") + + if require_nonempty: + has_empty_groups: bool = bool(self.components_per_group.eq(0).any().item()) + if has_empty_groups: + raise ValueError("one or more groups are empty") + + def to_matrix( + self, device: torch.device | None = None + ) -> Bool[Tensor, "k_groups n_components"]: + if device is None: + device = self.group_idxs.device + mat: Bool[Tensor, "k_groups n_components"] = torch.zeros( + (self.k_groups, self._n_components), dtype=torch.bool, device=device + ) + idxs: Int[Tensor, " n_components"] = torch.arange( + self._n_components, device=device, dtype=torch.int + ) + mat[self.group_idxs.to(dtype=torch.int), idxs] = True + return mat + + @classmethod + def from_matrix(cls, mat: Bool[Tensor, "k_groups n_components"]) -> "GroupMerge": + if mat.dtype is not torch.bool: + raise TypeError("mat must have dtype bool") + if not mat.sum(dim=0).eq(1).all(): + raise ValueError("each column must contain exactly one True") + group_idxs: GroupIdxsTensor = mat.argmax(dim=0).to(torch.int64) + inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=int(mat.shape[0])) + inst.validate(require_nonempty=False) + return inst + + @classmethod + def random( + cls, + n_components: int, + k_groups: int, + *, + ensure_groups_nonempty: bool = False, + device: torch.device | str = "cpu", + ) -> "GroupMerge": + if ensure_groups_nonempty and n_components < k_groups: + raise ValueError("n_components must be >= k_groups when ensure_groups_nonempty is True") + + group_idxs: GroupIdxsTensor + + if ensure_groups_nonempty: + base: Int[Tensor, " k_groups"] = torch.arange(k_groups, device=device) + if n_components > k_groups: + extra: Int[Tensor, " n_extra"] = torch.randint( + 0, k_groups, (n_components - k_groups,), device=device + ) + group_idxs = torch.cat((base, extra)) + group_idxs = 
group_idxs[torch.randperm(n_components, device=device)] + else: + group_idxs = base + else: + group_idxs = torch.randint(0, k_groups, (n_components,), device=device) + inst: GroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups) + inst.validate(require_nonempty=ensure_groups_nonempty) + return inst + + @classmethod + def identity(cls, n_components: int) -> "GroupMerge": + """Creates a GroupMerge where each component is its own group.""" + return cls( + group_idxs=torch.arange(n_components, dtype=torch.int64), + k_groups=n_components, + ) + + def merge_groups(self, group_a: int, group_b: int) -> "GroupMerge": + """Merges two groups into one, returning a new GroupMerge.""" + if group_a < 0 or group_b < 0 or group_a >= self.k_groups or group_b >= self.k_groups: + raise ValueError("group indices out of range") + if group_a == group_b: + raise ValueError("Cannot merge a group with itself") + + # make sure group_a is the smaller index + if group_a > group_b: + group_a, group_b = group_b, group_a + + # make a copy + new_idxs: GroupIdxsTensor = self.group_idxs.clone() + # wherever its currently b, change it to a + new_idxs[new_idxs == group_b] = group_a + # wherever i currently above b, change it to i-1 + new_idxs[new_idxs > group_b] -= 1 + # create a new GroupMerge instance + merged: GroupMerge = GroupMerge(group_idxs=new_idxs, k_groups=self.k_groups - 1) + + # create a mapping from old to new group indices + # `None` as a key is for the new group that contains both a and b + # values of a and b are mapped to `None` since they are merged + old_to_new_idx: dict[int | None, int | None] = dict() + for i in range(self.k_groups): + if i in {group_a, group_b}: + old_to_new_idx[i] = None + elif i <= group_b: + old_to_new_idx[i] = i + else: + old_to_new_idx[i] = i - 1 + old_to_new_idx[None] = group_a # the new group index for the merged group + + # HACK: store the mapping in the instance for later use + merged.old_to_new_idx = old_to_new_idx # type: ignore[assignment] + + # 
validate the new instance + # merged.validate(require_nonempty=True) + return merged + + def all_downstream_merged(self) -> "BatchedGroupMerge": + downstream: list[GroupMerge] = [] + idxs: list[tuple[int, int]] = [] + for i in range(self.k_groups): + for j in range(i + 1, self.k_groups): + downstream.append(self.merge_groups(i, j)) + idxs.append((i, j)) + + return BatchedGroupMerge.from_list(merge_matrices=downstream) + + +@dataclass(slots=True) +class BatchedGroupMerge: + """Batch of merge matrices. + + `group_idxs` has shape `(batch, n_components)`; each row holds the + group index for every component in that matrix. + """ + + group_idxs: Int[Tensor, "batch n_components"] + k_groups: Int[Tensor, " batch"] + + def summary(self) -> dict[str, int | str | None]: + return dict( + group_idxs=array_summary(self.group_idxs, as_list=False), # pyright: ignore[reportCallIssue] + k_groups=array_summary(self.k_groups, as_list=False), # pyright: ignore[reportCallIssue] + # TODO: re-add metadata (which pairs merged at each step) + # meta=f"len={len(self.meta)}" if self.meta is not None else None, + ) + + @classmethod + def init_empty(cls, batch_size: int, n_components: int) -> "BatchedGroupMerge": + """Initialize an empty BatchedGroupMerge with the given batch size and number of components.""" + return cls( + group_idxs=torch.full((batch_size, n_components), -1, dtype=torch.int16), + k_groups=torch.zeros(batch_size, dtype=torch.int16), + ) + + @property + def _batch_size(self) -> int: + return int(self.group_idxs.shape[0]) + + @property + def _n_components(self) -> int: + return int(self.group_idxs.shape[1]) + + @property + def k_groups_unique(self) -> int: + """Returns the number of groups across all matrices, throws exception if they differ.""" + k_groups_set: set[int] = set(self.k_groups.tolist()) + if len(k_groups_set) != 1: + raise ValueError("All matrices must have the same number of groups") + return k_groups_set.pop() + + def to_matrix( + self, device: torch.device | 
None = None + ) -> Bool[Tensor, "batch k_groups n_components"]: + if device is None: + device = self.group_idxs.device + k_groups_u: int = self.k_groups_unique + mat = torch.nn.functional.one_hot(self.group_idxs, num_classes=k_groups_u) + return mat.permute(0, 2, 1).to(device=device, dtype=torch.bool) + + @classmethod + def from_matrix(cls, mat: Bool[Tensor, "batch k_groups n_components"]) -> "BatchedGroupMerge": + if mat.dtype is not torch.bool: + raise TypeError("mat must have dtype bool") + if not mat.sum(dim=1).eq(1).all(): + raise ValueError("each column must have exactly one True per matrix") + group_idxs = mat.argmax(dim=1).to(torch.int64) + batch_size: int = int(mat.shape[0]) + inst = cls( + group_idxs=group_idxs, + k_groups=torch.full((batch_size,), int(mat.shape[1]), dtype=torch.int64), + ) + # inst.validate(require_nonempty=False) + return inst + + @classmethod + def from_list( + cls, + merge_matrices: list[GroupMerge], + ) -> "BatchedGroupMerge": + group_idxs: Int[Tensor, "batch n_components"] = torch.stack( + [mm.group_idxs for mm in merge_matrices], dim=0 + ) + k_groups: Int[Tensor, " batch"] = torch.tensor( + [mm.k_groups for mm in merge_matrices], dtype=torch.int64 + ) + inst: BatchedGroupMerge = cls(group_idxs=group_idxs, k_groups=k_groups) + # inst.validate(require_nonempty=False) + return inst + + def __getitem__(self, idx: int) -> GroupMerge: + if not (0 <= idx < self._batch_size): + raise IndexError("index out of range") + group_idxs: GroupIdxsTensor = self.group_idxs[idx] + k_groups: int = int(self.k_groups[idx].item()) + return GroupMerge(group_idxs=group_idxs, k_groups=k_groups) + + def __setitem__(self, idx: int, value: GroupMerge) -> None: + if not (0 <= idx < self._batch_size): + raise IndexError("index out of range") + if value._n_components != self._n_components: + raise ValueError("value must have the same number of components as the batch") + self.group_idxs[idx] = value.group_idxs + self.k_groups[idx] = value.k_groups + + def 
__iter__(self): + """Iterate over the GroupMerge instances in the batch.""" + for i in range(self._batch_size): + yield self[i] + + def __len__(self) -> int: + return self._batch_size diff --git a/spd/clustering/math/merge_pair_samplers.py b/spd/clustering/math/merge_pair_samplers.py new file mode 100644 index 000000000..24c050d36 --- /dev/null +++ b/spd/clustering/math/merge_pair_samplers.py @@ -0,0 +1,121 @@ +import random +from typing import Any, Literal, Protocol + +import torch +from jaxtyping import Bool, Float, Int +from torch import Tensor + +from spd.clustering.consts import ClusterCoactivationShaped, MergePair + +MergePairSamplerKey = Literal["range", "mcmc"] + + +class MergePairSamplerConfigurable(Protocol): + def __call__( + self, + costs: ClusterCoactivationShaped, + **kwargs: Any, + ) -> MergePair: ... + + +class MergePairSampler(Protocol): + def __call__( + self, + costs: ClusterCoactivationShaped, + ) -> MergePair: ... + + +def range_sampler( + costs: ClusterCoactivationShaped, + threshold: float = 0.05, + **kwargs: Any, +) -> MergePair: + """Sample a merge pair using threshold-based range selection. + + Considers all pairs with costs below a threshold defined as a fraction + of the range of non-diagonal costs, then randomly selects one. 
+ + Args: + costs: Cost matrix for all possible merges + k_groups: Number of current groups + threshold: Fraction of cost range to consider (0=min only, 1=all pairs) + + Returns: + Tuple of (group_i, group_j) indices to merge + """ + assert not kwargs + k_groups: int = costs.shape[0] + assert costs.shape[1] == k_groups, "Cost matrix must be square" + + # Find the range of non-diagonal costs + non_diag_costs: Float[Tensor, " k_groups_squared_minus_k"] = costs[ + ~torch.eye(k_groups, dtype=torch.bool, device=costs.device) + ] + min_cost: float = float(non_diag_costs.min().item()) + max_cost: float = float(non_diag_costs.max().item()) + + # Calculate threshold cost + max_considered_cost: float = (max_cost - min_cost) * threshold + min_cost + + # Find all pairs below threshold + considered_idxs: Int[Tensor, "n_considered 2"] = torch.stack( + torch.where(costs <= max_considered_cost), dim=1 + ) + # Remove diagonal entries (i == j) + considered_idxs = considered_idxs[considered_idxs[:, 0] != considered_idxs[:, 1]] + + # Randomly select one of the considered pairs + selected_idx: int = random.randint(0, considered_idxs.shape[0] - 1) + pair_tuple: tuple[int, int] = tuple(considered_idxs[selected_idx].tolist()) # type: ignore[assignment] + return MergePair(pair_tuple) + + +def mcmc_sampler( + costs: ClusterCoactivationShaped, + temperature: float = 1.0, + **kwargs: Any, +) -> MergePair: + """Sample a merge pair using MCMC with probability proportional to exp(-cost/temperature). 
+ + Args: + costs: Cost matrix for all possible merges + k_groups: Number of current groups + temperature: Temperature parameter for softmax (higher = more uniform sampling) + + Returns: + Tuple of (group_i, group_j) indices to merge + """ + assert not kwargs + k_groups: int = costs.shape[0] + assert costs.shape[1] == k_groups, "Cost matrix must be square" + + # Create mask for valid pairs (non-diagonal) + valid_mask: Bool[Tensor, "k_groups k_groups"] = ~torch.eye( + k_groups, dtype=torch.bool, device=costs.device + ) + + # Compute probabilities: exp(-cost/temperature) + # Use stable softmax computation to avoid overflow + costs_masked: ClusterCoactivationShaped = costs.clone() + costs_masked[~valid_mask] = float("inf") # Set diagonal to inf so exp gives 0 + + # Subtract min for numerical stability + min_cost: float = float(costs_masked[valid_mask].min()) + probs: ClusterCoactivationShaped = ( + torch.exp((min_cost - costs_masked) / temperature) * valid_mask + ) # Zero out diagonal + probs_flatten: Float[Tensor, " k_groups_squared"] = probs.flatten() + probs_flatten = probs_flatten / probs_flatten.sum() + + # Sample from multinomial distribution + idx: int = int(torch.multinomial(probs_flatten, 1).item()) + row: int = idx // k_groups + col: int = idx % k_groups + + return MergePair((row, col)) + + +MERGE_PAIR_SAMPLERS: dict[MergePairSamplerKey, MergePairSamplerConfigurable] = { + "range": range_sampler, + "mcmc": mcmc_sampler, +} diff --git a/spd/clustering/math/perm_invariant_hamming.py b/spd/clustering/math/perm_invariant_hamming.py new file mode 100644 index 000000000..e70d3c7c0 --- /dev/null +++ b/spd/clustering/math/perm_invariant_hamming.py @@ -0,0 +1,70 @@ +import warnings + +import numpy as np +from jaxtyping import Float, Int +from scipy.optimize import linear_sum_assignment + + +def perm_invariant_hamming_matrix( + X: Int[np.ndarray, "n_ens n_components"], +) -> Float[np.ndarray, "n_ens n_ens"]: + """Compute all pairwise permutation-invariant Hamming 
distances. + + The strictly lower-triangular entries are filled with distances; + the diagonal and upper triangle are left as `np.nan`. + + # Parameters: + - `X : Int[np.ndarray, "n_ens n_components"]` + Matrix where each of the `n_ens` rows is a label vector of length `n_components`. + + # Returns: + - `Float[np.ndarray, "n_ens n_ens"]` + Distance matrix `D` with `D[i, j]` defined only for `i > j`; + all other positions are `np.nan`. + + # Usage: + ```python + >>> X = np.array([[0, 0, 1], + ... [1, 1, 0], + ... [0, 1, 0]]) + >>> D = perm_invariant_hamming_matrix(X) + >>> D + array([[nan, nan, nan], + [ 0., nan, nan], + [ 2., 2., nan]]) + ``` + """ + n_ens: int + n_components: int + n_ens, n_components = X.shape + D: Float[np.ndarray, "n_ens n_ens"] = np.full((n_ens, n_ens), np.nan, dtype=float) + + # Pre-compute max label in each row once. + row_max: Int[np.ndarray, " n_ens"] = X.max(axis=1) + + for i in range(1, n_ens): + a: Int[np.ndarray, " n_components"] = X[i] + for j in range(i): + b: Int[np.ndarray, " n_components"] = X[j] + + k_lbls: int = int(max(row_max[i], row_max[j]) + 1) + + # Handle case where all labels are -1 (no valid clustering) + if k_lbls <= 0: + warnings.warn( + f"All labels are -1 at rows {i} and {j}. 
Setting distance to 0.", + UserWarning, + stacklevel=2, + ) + D[i, j] = 0.0 + continue + + C: Int[np.ndarray, "k_lbls k_lbls"] = np.zeros((k_lbls, k_lbls), dtype=int) + np.add.at(C, (a, b), 1) + + row_ind, col_ind = linear_sum_assignment(-C) + matches: int = int(C[row_ind, col_ind].sum()) + + D[i, j] = n_components - matches # int is fine; array is float because of NaN + + return D diff --git a/spd/clustering/math/semilog.py b/spd/clustering/math/semilog.py new file mode 100644 index 000000000..a17ba63b5 --- /dev/null +++ b/spd/clustering/math/semilog.py @@ -0,0 +1,13 @@ +import math + + +def semilog( + value: float, + epsilon: float = 1e-3, +) -> float: + if abs(value) < epsilon: + return value + else: + sign: int = 1 if value >= 0 else -1 + # log10 here is safe, since we know the value is not close to zero + return sign * epsilon * math.log1p(abs(value) / epsilon) diff --git a/spd/clustering/merge.py b/spd/clustering/merge.py new file mode 100644 index 000000000..dba55c878 --- /dev/null +++ b/spd/clustering/merge.py @@ -0,0 +1,188 @@ +""" +Merge iteration with logging support. + +This wraps the pure merge_iteration_pure() function and adds WandB/plotting callbacks. 
+""" + +import warnings +from typing import Protocol + +import torch +from jaxtyping import Bool, Float +from torch import Tensor +from tqdm import tqdm + +from spd.clustering.compute_costs import ( + compute_mdl_cost, + compute_merge_costs, + recompute_coacts_merge_pair, +) +from spd.clustering.consts import ( + ActivationsTensor, + BoolActivationsTensor, + ClusterCoactivationShaped, + ComponentLabels, + MergePair, +) +from spd.clustering.math.merge_matrix import GroupMerge +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory + + +class LogCallback(Protocol): + def __call__( + self, + current_coact: ClusterCoactivationShaped, + component_labels: ComponentLabels, + current_merge: GroupMerge, + costs: ClusterCoactivationShaped, + merge_history: MergeHistory, + iter_idx: int, + k_groups: int, + merge_pair_cost: float, + mdl_loss: float, + mdl_loss_norm: float, + diag_acts: Float[Tensor, " k_groups"], + ) -> None: ... + + +def merge_iteration( + merge_config: MergeConfig, + activations: ActivationsTensor, + component_labels: ComponentLabels, + log_callback: LogCallback | None = None, +) -> MergeHistory: + """ + Merge iteration with optional logging/plotting callbacks. + + This wraps the pure computation with logging capabilities while maintaining + the same core algorithm logic. 
+ """ + + # compute coactivations + # -------------------------------------------------- + activation_mask_orig: BoolActivationsTensor | ActivationsTensor | None = ( + activations > merge_config.activation_threshold + if merge_config.activation_threshold is not None + else activations + ) + coact: Float[Tensor, "c c"] = activation_mask_orig.float().T @ activation_mask_orig.float() + + # check shapes + c_components: int = coact.shape[0] + assert coact.shape[1] == c_components, "Coactivation matrix must be square" + + # determine number of iterations based on config and number of components + num_iters: int = merge_config.get_num_iters(c_components) + + # initialize vars + # -------------------------------------------------- + # start with an identity merge + current_merge: GroupMerge = GroupMerge.identity(n_components=c_components) + + # initialize variables for the merge process + k_groups: int = c_components + current_coact: ClusterCoactivationShaped = coact.clone() + current_act_mask: Bool[Tensor, "samples k_groups"] = activation_mask_orig.clone() + + # variables we keep track of + merge_history: MergeHistory = MergeHistory.from_config( + merge_config=merge_config, + labels=component_labels, + ) + + # merge iteration + # ================================================== + pbar: tqdm[int] = tqdm( + range(num_iters), + unit="iter", + total=num_iters, + ) + for iter_idx in pbar: + # compute costs, figure out what to merge + # -------------------------------------------------- + # HACK: this is messy + costs: ClusterCoactivationShaped = compute_merge_costs( + coact=current_coact / current_act_mask.shape[0], + merges=current_merge, + alpha=merge_config.alpha, + ) + + merge_pair: MergePair = merge_config.merge_pair_sample(costs) + + # merge the pair + # -------------------------------------------------- + # we do this *before* logging, so we can see how the sampled pair cost compares + # to the costs of all the other possible pairs + current_merge, current_coact, 
current_act_mask = recompute_coacts_merge_pair( + coact=current_coact, + merges=current_merge, + merge_pair=merge_pair, + activation_mask=current_act_mask, + ) + + # metrics and logging + # -------------------------------------------------- + # Store in history + merge_history.add_iteration( + idx=iter_idx, + selected_pair=merge_pair, + current_merge=current_merge, + ) + + # Compute metrics for logging + # the MDL loss computed here is the *cost of the current merge*, a single scalar value + # rather than the *delta in cost from merging a specific pair* (which is what `costs` matrix contains) + diag_acts: Float[Tensor, " k_groups"] = torch.diag(current_coact) + mdl_loss: float = compute_mdl_cost( + acts=diag_acts, + merges=current_merge, + alpha=merge_config.alpha, + ) + mdl_loss_norm: float = mdl_loss / current_act_mask.shape[0] + # this is the cost for the selected pair + merge_pair_cost: float = float(costs[merge_pair].item()) + + # Update progress bar + pbar.set_description(f"k={k_groups}, mdl={mdl_loss_norm:.4f}, pair={merge_pair_cost:.4f}") + + if log_callback is not None: + log_callback( + iter_idx=iter_idx, + current_coact=current_coact, + component_labels=component_labels, + current_merge=current_merge, + costs=costs, + merge_history=merge_history, + k_groups=k_groups, + merge_pair_cost=merge_pair_cost, + mdl_loss=mdl_loss, + mdl_loss_norm=mdl_loss_norm, + diag_acts=diag_acts, + ) + + # iterate and sanity checks + # -------------------------------------------------- + k_groups -= 1 + assert current_coact.shape[0] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_coact.shape[1] == k_groups, ( + "Coactivation matrix shape should match number of groups" + ) + assert current_act_mask.shape[1] == k_groups, ( + "Activation mask shape should match number of groups" + ) + + # early stopping failsafe + # -------------------------------------------------- + if k_groups <= 3: + warnings.warn( + f"Stopping early at 
iteration {iter_idx} as only {k_groups} groups left", + stacklevel=2, + ) + break + + # finish up + # ================================================== + return merge_history diff --git a/spd/clustering/merge_config.py b/spd/clustering/merge_config.py new file mode 100644 index 000000000..f471879b2 --- /dev/null +++ b/spd/clustering/merge_config.py @@ -0,0 +1,114 @@ +import functools +import hashlib +from typing import Any, Literal + +from pydantic import ( + Field, + PositiveInt, +) + +from spd.base_config import BaseConfig +from spd.clustering.consts import ClusterCoactivationShaped, MergePair +from spd.clustering.math.merge_pair_samplers import ( + MERGE_PAIR_SAMPLERS, + MergePairSampler, + MergePairSamplerKey, +) +from spd.clustering.util import ModuleFilterFunc, ModuleFilterSource +from spd.spd_types import Probability + +MergeConfigKey = Literal[ + "activation_threshold", + "alpha", + "iters", + "merge_pair_sampling_method", + "merge_pair_sampling_kwargs", + "filter_dead_threshold", +] + + +def _to_module_filter( + filter_modules: ModuleFilterSource, +) -> ModuleFilterFunc: + """Convert the filter_modules argument to a callable.""" + if filter_modules is None: + return lambda _: True + elif isinstance(filter_modules, str): + return lambda module_name: module_name.startswith(filter_modules) + elif isinstance(filter_modules, set): + return lambda module_name: module_name in filter_modules + elif callable(filter_modules): + return filter_modules + else: + raise TypeError(f"filter_modules must be str, set, or callable, got {type(filter_modules)}") # pyright: ignore[reportUnreachable] + + +class MergeConfig(BaseConfig): + activation_threshold: Probability | None = Field( + default=0.01, + description="Threshold for considering a component active in a group. If None, use raw scalar causal importances", + ) + alpha: float = Field( + default=1.0, + description="rank weight factor. 
Higher values mean a higher penalty on 'sending' the component weights", + ) + iters: PositiveInt | None = Field( + default=100, + description="max number of iterations to run the merge algorithm for. If `None`, set to number of components (after filtering) minus one.", + ) + merge_pair_sampling_method: MergePairSamplerKey = Field( + default="range", + description="Method for sampling merge pairs. Options: 'range', 'mcmc'.", + ) + merge_pair_sampling_kwargs: dict[str, Any] = Field( + default_factory=lambda: {"threshold": 0.05}, + description="Keyword arguments for the merge pair sampling method.", + ) + filter_dead_threshold: float = Field( + default=0.001, + description="Threshold for filtering out dead components. If a component's activation is below this threshold, it is considered dead and not included in the merge.", + ) + module_name_filter: ModuleFilterSource = Field( + default=None, + description="Filter for module names. Can be a string prefix, a set of names, or a callable that returns True for modules to include.", + ) + + @property + def merge_pair_sample_func(self) -> MergePairSampler: + return functools.partial( + MERGE_PAIR_SAMPLERS[self.merge_pair_sampling_method], + **self.merge_pair_sampling_kwargs, + ) + + def merge_pair_sample( + self, + costs: ClusterCoactivationShaped, + ) -> MergePair: + """do merge sampling based on the configured method and kwargs + + has signature `MergePairSampler = Callable[[ClusterCoactivationShaped], MergePair]` + """ + return self.merge_pair_sample_func(costs=costs) + + @property + def filter_modules(self) -> ModuleFilterFunc: + """Get the module filter function based on the provided source.""" + return _to_module_filter(self.module_name_filter) + + def get_num_iters(self, n_components: int) -> PositiveInt: + """Get the number of iterations to run the merge algorithm for. 
+ + Args: + n_components: Number of components (after filtering) + + Returns: + Number of iterations to run + """ + if self.iters is None: + return n_components - 1 + else: + return self.iters + + @property + def stable_hash(self) -> str: + return hashlib.md5(self.model_dump_json().encode()).hexdigest()[:6] diff --git a/spd/clustering/merge_history.py b/spd/clustering/merge_history.py new file mode 100644 index 000000000..bbff78893 --- /dev/null +++ b/spd/clustering/merge_history.py @@ -0,0 +1,464 @@ +import io +import json +import zipfile +from dataclasses import dataclass +from pathlib import Path +from typing import Any, override + +import numpy as np +import torch +from jaxtyping import Float, Int +from muutils.dbg import dbg_tensor + +from spd.clustering.consts import ( + ComponentLabels, + DistancesArray, + DistancesMethod, + MergePair, + MergesArray, + SaveableObject, +) +from spd.clustering.math.merge_distances import compute_distances +from spd.clustering.math.merge_matrix import BatchedGroupMerge, GroupMerge +from spd.clustering.merge_config import MergeConfig + + +@dataclass(frozen=True) +class IterationInfo: + """Information about a single merge iteration.""" + + idx: int + selected_pair: list[int] + merges: GroupMerge + + +def _zip_save_arr(zf: zipfile.ZipFile, name: str, arr: np.ndarray) -> None: + """Save a numpy array to a zip file.""" + buf: io.BytesIO = io.BytesIO() + np.save(buf, arr) + zf.writestr(name, buf.getvalue()) + + +def _zip_save_arr_dict(zf: zipfile.ZipFile, data: dict[str, np.ndarray]) -> None: + """Save a dictionary of numpy arrays to a zip file, {key}.npy used as path""" + key: str + arr: np.ndarray + for key, arr in data.items(): + _zip_save_arr(zf, f"{key}.npy", arr) + + +@dataclass(kw_only=True) +class MergeHistory(SaveableObject): + """Track merge iteration history""" + + merges: BatchedGroupMerge + selected_pairs: Int[np.ndarray, " n_iters 2"] + labels: ComponentLabels + merge_config: MergeConfig + n_iters_current: int + + meta: 
dict[str, Any] | None = None + + @property + def c_components(self) -> int: + return len(self.labels) + + @classmethod + def from_config( + cls, + merge_config: MergeConfig, + labels: ComponentLabels, + ) -> "MergeHistory": + n_components: int = len(labels) + n_iters_target: int = merge_config.get_num_iters(n_components) + return MergeHistory( + labels=labels, + n_iters_current=0, + selected_pairs=np.full((n_iters_target, 2), -1, dtype=np.int16), + merges=BatchedGroupMerge.init_empty( + batch_size=n_iters_target, n_components=n_components + ), + merge_config=merge_config, + ) + + def summary(self) -> dict[str, str | int | None | dict[str, int | str | None]]: + return dict( + c_components=self.c_components, + n_iters_current=self.n_iters_current, + total_iters=len(self.merges.k_groups), + len_labels=len(self.labels), + # wandb_url=self.wandb_url, + merge_config=self.merge_config.model_dump(mode="json"), + merges_summary=self.merges.summary(), + ) + + @override + def __str__(self) -> str: + out: list[str] = [f" {key} = {value}" for key, value in self.summary().items()] + return "MergeHistory(\n" + "\n".join(out) + "\n)" + + @override + def __repr__(self) -> str: + return self.__str__() + + def add_iteration( + self, + idx: int, + selected_pair: MergePair, + current_merge: GroupMerge, + ) -> None: + """Add data for one iteration.""" + self.selected_pairs[idx] = np.array(selected_pair, dtype=np.int16) + self.merges[idx] = current_merge + + assert self.n_iters_current == idx + self.n_iters_current += 1 + + def __getitem__(self, idx: int) -> IterationInfo: + """Get data for a specific iteration.""" + if idx < 0 or idx >= self.n_iters_current: + raise IndexError( + f"Index {idx} out of range for history with {self.n_iters_current} iterations" + ) + + return IterationInfo( + idx=idx, + selected_pair=self.selected_pairs[idx].tolist(), + merges=self.merges[idx], + ) + + def __len__(self) -> int: + """Get the number of iterations in the history.""" + return 
self.n_iters_current + + def latest(self) -> IterationInfo: + """Get the latest values.""" + if self.n_iters_current == 0: + raise ValueError("No history available") + latest_idx: int = self.n_iters_current - 1 + return self[latest_idx] + + def get_unique_clusters(self, iteration: int) -> list[int]: + """Get unique cluster IDs at a given iteration. + + Args: + iteration: Iteration index (negative indexes from end) + + Returns: + List of unique cluster IDs + """ + if iteration < 0: + iteration = self.n_iters_current + iteration + assert 0 <= iteration < self.n_iters_current, ( + f"Invalid iteration: {iteration = }, {self.n_iters_current = }" + ) + merge: GroupMerge = self.merges[iteration] + return torch.unique(merge.group_idxs).tolist() + + def get_cluster_component_labels(self, iteration: int, cluster_id: int) -> ComponentLabels: + """Get component labels for a specific cluster at a given iteration. + + Args: + iteration: Iteration index (negative indexes from end) + cluster_id: Cluster ID to query + + Returns: + List of component labels in the cluster + """ + if iteration < 0: + iteration = self.n_iters_current + iteration + assert 0 <= iteration < self.n_iters_current, ( + f"Invalid iteration: {iteration = }, {self.n_iters_current = }" + ) + merge: GroupMerge = self.merges[iteration] + component_indices: list[int] = merge.components_in_group(cluster_id) + return ComponentLabels([self.labels[idx] for idx in component_indices]) + + def get_cluster_components_info(self, iteration: int, cluster_id: int) -> list[dict[str, Any]]: + """Get detailed component information for a cluster. 
+ + Args: + iteration: Iteration index (negative indexes from end) + cluster_id: Cluster ID to query + + Returns: + List of dicts with keys: module, index, label + """ + component_labels: list[str] = self.get_cluster_component_labels(iteration, cluster_id) + result: list[dict[str, Any]] = [] + for label in component_labels: + module: str + idx_str: str + module, idx_str = label.rsplit(":", 1) + result.append({"module": module, "index": int(idx_str), "label": label}) + return result + + # Convenience properties for sweep analysis + @property + def total_iterations(self) -> int: + """Total number of iterations performed.""" + return self.n_iters_current + + @property + def final_k_groups(self) -> int: + """Final number of groups after merging.""" + if self.n_iters_current == 0: + return self.c_components + return int(self.merges.k_groups[self.n_iters_current - 1].item()) + + @property + def initial_k_groups(self) -> int: + """Initial number of groups before merging.""" + if self.n_iters_current == 0: + return self.c_components + return int(self.merges.k_groups[0].item()) + + @override + def save(self, path: Path) -> None: + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "w") as zf: + # save arrays + _zip_save_arr_dict( + zf=zf, + data={ + "merge.group_idxs": self.merges.group_idxs.cpu().numpy(), + "merge.k_groups": self.merges.k_groups.cpu().numpy(), + "selected_pairs": self.selected_pairs, + }, + ) + # Save labels + zf.writestr("labels.txt", "\n".join(self.labels)) + # Save metadata + zf.writestr( + "metadata.json", + json.dumps( + dict( + merge_config=self.merge_config.model_dump(mode="json"), + c_components=self.c_components, + n_iters_current=self.n_iters_current, + labels=self.labels, + ) + ), + ) + + @override + @classmethod + def read(cls, path: Path) -> "MergeHistory": + zf: zipfile.ZipFile + with zipfile.ZipFile(path, "r") as zf: + group_idxs: np.ndarray = np.load(io.BytesIO(zf.read("merge.group_idxs.npy"))) + k_groups: np.ndarray = 
np.load(io.BytesIO(zf.read("merge.k_groups.npy"))) + selected_pairs: np.ndarray = np.load(io.BytesIO(zf.read("selected_pairs.npy"))) + merges: BatchedGroupMerge = BatchedGroupMerge( + group_idxs=torch.from_numpy(group_idxs), + k_groups=torch.from_numpy(k_groups), + ) + labels_raw: list[str] = zf.read("labels.txt").decode("utf-8").splitlines() + labels: ComponentLabels = ComponentLabels(labels_raw) + metadata: dict[str, Any] = json.loads(zf.read("metadata.json").decode("utf-8")) + merge_config: MergeConfig = MergeConfig.model_validate(metadata["merge_config"]) + + metadata["origin_path"] = path + + return cls( + merges=merges, + selected_pairs=selected_pairs, + labels=labels, + merge_config=merge_config, + n_iters_current=metadata["n_iters_current"], + meta=metadata, + ) + + +@dataclass +class MergeHistoryEnsemble: + data: list[MergeHistory] + + def __iter__(self): + return iter(self.data) + + def __getitem__(self, idx: int) -> MergeHistory: + return self.data[idx] + + def _validate_configs_match(self) -> None: + """Ensure all histories have the same merge config.""" + if not self.data: + return + first_config: MergeConfig = self.data[0].merge_config + for history in self.data[1:]: + if history.merge_config != first_config: + raise ValueError("All histories must have the same merge config") + + @property + def config(self) -> MergeConfig: + """Get the merge config used in the ensemble.""" + self._validate_configs_match() + return self.data[0].merge_config + + @property + def n_iters_min(self) -> int: + """Minimum number of iterations across all histories in the ensemble.""" + return min(len(history.merges.k_groups) for history in self.data) + + @property + def n_iters_max(self) -> int: + """Maximum number of iterations across all histories in the ensemble.""" + return max(len(history.merges.k_groups) for history in self.data) + + @property + def n_iters_range(self) -> tuple[int, int]: + """Range of iterations (min, max) across all histories in the ensemble.""" + 
iter_counts = [len(history.merges.k_groups) for history in self.data] + return (min(iter_counts), max(iter_counts)) + + @property + def n_ensemble(self) -> int: + """Number of ensemble members.""" + return len(self.data) + + @property + def c_components(self) -> int: + """Number of components in each history.""" + c_components: int = self.data[0].c_components + assert all(history.c_components == c_components for history in self.data), ( + "All histories must have the same number of components" + ) + return c_components + + @property + def shape(self) -> tuple[int, int, int]: + """Shape of the ensemble data.""" + return (self.n_ensemble, self.n_iters_min, self.c_components) + + @property + def merges_array(self) -> MergesArray: + n_ens: int = self.n_ensemble + n_iters: int = self.n_iters_min + c_components: int = self.c_components + + output: MergesArray = np.full( + (n_ens, n_iters, c_components), + fill_value=-1, + dtype=np.int16, + # if you have more than 32k components, change this to np.int32 + # if you have more than 2.1b components, rethink your life choices + ) + for i_ens, history in enumerate(self.data): + for i_iter, merge in enumerate(history.merges): + output[i_ens, i_iter] = merge.group_idxs + + return output + + def normalized(self) -> tuple[MergesArray, dict[str, Any]]: + """Normalize the component labels across all histories. + + if different histories see different batches, then they might have different dead + components, and are hence not directly comparable. 
So, we find the union of all + component labels across all histories, and then any component missing from a history + is put into it's own group in that history + """ + + unique_labels_set: set[str] = set() + for history in self.data: + unique_labels_set.update(history.labels) + + unique_labels_list: list[str] = sorted(unique_labels_set) + unique_labels: ComponentLabels = ComponentLabels(unique_labels_list) + c_components: int = len(unique_labels) + component_label_idxs: dict[str, int] = { + label: idx for idx, label in enumerate(unique_labels) + } + + try: + merges_array: MergesArray = np.full( + (self.n_ensemble, self.n_iters_min, c_components), + fill_value=-1, + dtype=np.int16, + ) + except Exception as e: + err_msg = ( + f"failed to create merge array, probably due to issues with getting shape.\n" + f"{self = }\n" + f"{self.data = }\n" + ) + raise RuntimeError(err_msg) from e + + overlap_stats: Float[np.ndarray, " n_ens"] = np.full( + self.n_ensemble, + fill_value=float("nan"), + dtype=np.float32, + ) + i_ens: int + history: MergeHistory + for i_ens, history in enumerate(self.data): + hist_c_labels: list[str] = history.labels + hist_n_components: int = len(hist_c_labels) + overlap_stats[i_ens] = hist_n_components / c_components + # map from old component indices to new component indices + i_comp_old: int + comp_label: str + for i_comp_old, comp_label in enumerate(hist_c_labels): + i_comp_new: int = component_label_idxs[comp_label] + merges_array[i_ens, :, i_comp_new] = history.merges.group_idxs[ + : self.n_iters_min, i_comp_old + ] + + # assert np.max(merges_array[i_ens]) == hist_n_components - 1, ( + # f"Max component index in history {i_ens} should be {hist_n_components - 1}, " + # f"but got {np.max(merges_array[i_ens])}" + # ) + + # put each missing label into its own group + hist_missing_labels: set[str] = unique_labels_set - set(hist_c_labels) + assert len(hist_missing_labels) == c_components - hist_n_components + idx_missing: int + missing_label: str + 
for idx_missing, missing_label in enumerate(hist_missing_labels): + i_comp_new_relabel: int = component_label_idxs[missing_label] + merges_array[i_ens, :, i_comp_new_relabel] = np.full( + self.n_iters_min, + fill_value=idx_missing + hist_n_components, + dtype=np.int16, + ) + + # TODO: Consider logging overlap_stats to WandB if run is available + # For now, keep using dbg_tensor for overlap_stats analysis + dbg_tensor(overlap_stats) + + # TODO: double check this + # Convert any Path objects to strings for JSON serialization + history_metadatas: list[dict[str, Any] | None] = [] + for history in self.data: + if history.meta is not None: + meta_copy = history.meta.copy() + # Convert Path objects to strings + for key, value in meta_copy.items(): + if isinstance(value, Path): + meta_copy[key] = str(value) + history_metadatas.append(meta_copy) + else: + history_metadatas.append(None) + + dbg_tensor(merges_array) + + return ( + # TODO: dataclass this + merges_array, + dict( + component_labels=unique_labels, + n_ensemble=self.n_ensemble, + n_iters_min=self.n_iters_min, + n_iters_max=self.n_iters_max, + n_iters_range=self.n_iters_range, + c_components=c_components, + config=self.config.model_dump(mode="json"), + history_metadatas=history_metadatas, + ), + ) + + def get_distances(self, method: DistancesMethod = "perm_invariant_hamming") -> DistancesArray: + merges_array: MergesArray = self.merges_array + return compute_distances( + normalized_merge_array=merges_array, + method=method, + ) diff --git a/spd/clustering/plotting/__init__.py b/spd/clustering/plotting/__init__.py new file mode 100644 index 000000000..b048d1d24 --- /dev/null +++ b/spd/clustering/plotting/__init__.py @@ -0,0 +1 @@ +"""Plotting utilities for clustering module.""" diff --git a/spd/clustering/plotting/activations.py b/spd/clustering/plotting/activations.py new file mode 100644 index 000000000..e7f02ad3b --- /dev/null +++ b/spd/clustering/plotting/activations.py @@ -0,0 +1,389 @@ +"""Plotting functions 
for activation visualizations.""" + +from collections.abc import Sequence +from pathlib import Path + +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np +import torch +import wandb +import wandb.sdk.wandb_run +from jaxtyping import Float, Int +from muutils.dbg import dbg_tensor +from torch import Tensor + +from spd.clustering.activations import ProcessedActivations, compute_coactivatons +from spd.clustering.consts import ActivationsTensor, ClusterCoactivationShaped, ComponentLabels + + +def plot_activations( + processed_activations: ProcessedActivations, + save_dir: Path | None, + n_samples_max: int, + figure_prefix: str = "activations", + figsize_raw: tuple[int, int] = (12, 4), + figsize_concat: tuple[int, int] = (12, 2), + figsize_coact: tuple[int, int] = (8, 6), + hist_scales: tuple[str, str] = ("lin", "log"), + hist_bins: int = 100, + do_sorted_samples: bool = False, + wandb_run: wandb.sdk.wandb_run.Run | None = None, +) -> None: + """Plot activation visualizations including raw, concatenated, sorted, and coactivations. 
+ + Args: + activations: Dictionary of raw activations by module + act_concat: Concatenated activations tensor + coact: Coactivation matrix + labels: Component labels + save_dir: The directory to save the plots to (None to skip saving to disk) + figure_prefix: Prefix for PDF filenames + figsize_raw: Figure size for raw activations + figsize_concat: Figure size for concatenated activations + figsize_coact: Figure size for coactivations + hist_scales: Tuple of (x_scale, y_scale) where each is "lin" or "log" + hist_bins: Number of bins for histograms + """ + if save_dir is not None: + save_dir.mkdir(parents=True, exist_ok=True) + + act_dict: dict[str, ActivationsTensor] = processed_activations.activations_raw + act_concat: ActivationsTensor = processed_activations.activations + dbg_tensor(act_concat) + coact: ClusterCoactivationShaped = compute_coactivatons(act_concat) + dbg_tensor(coact) + labels: ComponentLabels = ComponentLabels(processed_activations.labels) + n_samples: int = act_concat.shape[0] + + # trim the activations if n_samples_max is specified + # clone here so we don't modify the original tensor + act_concat = act_concat[:n_samples_max].clone() + # we don't use the stuff in this dict again, so we can modify it in-place + for key in act_dict: + act_dict[key] = act_dict[key][:n_samples_max] + + # Update n_samples to reflect the truncated size + n_samples = act_concat.shape[0] + + # Raw activations + axs_act: Sequence[plt.Axes] + _fig1: plt.Figure + _fig1, axs_act = plt.subplots(len(act_dict), 1, figsize=figsize_raw) + if len(act_dict) == 1: + assert isinstance(axs_act, plt.Axes) + axs_act = [axs_act] + for i, (key, act) in enumerate(act_dict.items()): + act_raw_data: np.ndarray = act.T.cpu().numpy() + axs_act[i].matshow( + act_raw_data, aspect="auto", vmin=act_raw_data.min(), vmax=act_raw_data.max() + ) + axs_act[i].set_ylabel(f"components\n{key}") + axs_act[i].set_title(f"Raw Activations: {key} (shape: {act_raw_data.shape})") + + if save_dir is not None: + 
fig1_fname = save_dir / f"{figure_prefix}_raw.pdf" + _fig1.savefig(fig1_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/raw": wandb.Image(_fig1)}, step=0) + + # Close figure to free memory + plt.close(_fig1) + + # Concatenated activations + fig2: plt.Figure + ax2: plt.Axes + fig2, ax2 = plt.subplots(figsize=figsize_concat) + act_data: np.ndarray = act_concat.T.cpu().numpy() + im2 = ax2.matshow(act_data, aspect="auto", vmin=act_data.min(), vmax=act_data.max()) + ax2.set_title("Concatenated Activations") + + # Add component labeling on y-axis + add_component_labeling(ax2, labels, axis="y") + + plt.colorbar(im2) + + if save_dir is not None: + fig2_fname: Path = save_dir / f"{figure_prefix}_concatenated.pdf" + fig2.savefig(fig2_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated": wandb.Image(fig2)}, step=0) + + # Close figure to free memory + plt.close(fig2) + + # Concatenated activations, sorted samples + if do_sorted_samples: + # TODO: move sample sorting logic to its own function, see + # https://github.com/goodfire-ai/spd/pull/172/files#r2387275601 + fig3: plt.Figure + ax3: plt.Axes + fig3, ax3 = plt.subplots(figsize=figsize_concat) + + # Compute gram matrix (sample similarity) and sort samples using greedy ordering + gram_matrix: Float[Tensor, "samples samples"] = act_concat @ act_concat.T + + # Normalize gram matrix to get cosine similarity + norms: Float[Tensor, "samples 1"] = torch.norm(act_concat, dim=1, keepdim=True) + norms = torch.where(norms > 1e-8, norms, torch.ones_like(norms)) + similarity_matrix: Float[Tensor, "samples samples"] = gram_matrix / (norms @ norms.T) + + # Greedy ordering: start with sample most similar to all others + avg_similarity: Float[Tensor, " samples"] = similarity_matrix.mean(dim=1) + start_idx: int = int(torch.argmax(avg_similarity).item()) + + # Build 
ordering greedily + ordered_indices: list[int] = [start_idx] + remaining: set[int] = set(range(n_samples)) + remaining.remove(start_idx) + + # Greedily add the nearest unvisited sample + current_idx: int = start_idx + while remaining: + # Find the unvisited sample most similar to current + best_similarity: float = -1 + best_idx: int = -1 + for idx in remaining: + sim: float = similarity_matrix[current_idx, idx].item() + if sim > best_similarity: + best_similarity = sim + best_idx = idx + + ordered_indices.append(best_idx) + remaining.remove(best_idx) + current_idx = best_idx + + sorted_indices: Int[Tensor, " samples"] = torch.tensor( + ordered_indices, dtype=torch.long, device=act_concat.device + ) + act_concat_sorted: ActivationsTensor = act_concat[sorted_indices] + + # Handle log10 properly - add small epsilon to avoid log(0) + act_sorted_data: np.ndarray = act_concat_sorted.T.cpu().numpy() + act_sorted_log: np.ndarray = np.log10(act_sorted_data + 1e-10) + im3 = ax3.matshow( + act_sorted_log, aspect="auto", vmin=act_sorted_log.min(), vmax=act_sorted_log.max() + ) + ax3.set_title("Concatenated Activations $\\log_{10}$, Sorted Samples") + + # Add component labeling on y-axis + add_component_labeling(ax3, labels, axis="y") + + plt.colorbar(im3) + + if save_dir is not None: + fig3_fname: Path = save_dir / f"{figure_prefix}_concatenated_sorted.pdf" + fig3.savefig(fig3_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/concatenated_sorted": wandb.Image(fig3)}, step=0) + + # Close figure to free memory + plt.close(fig3) + + # Coactivations + fig4: plt.Figure + ax4: plt.Axes + fig4, ax4 = plt.subplots(figsize=figsize_coact) + coact_data: np.ndarray = coact.cpu().numpy() + im4 = ax4.matshow(coact_data, aspect="auto", vmin=coact_data.min(), vmax=coact_data.max()) + ax4.set_title("Coactivations") + + # Add component labeling on both axes + add_component_labeling(ax4, labels, axis="x") + 
add_component_labeling(ax4, labels, axis="y") + + plt.colorbar(im4) + + if save_dir is not None: + fig4_fname: Path = save_dir / f"{figure_prefix}_coactivations.pdf" + fig4.savefig(fig4_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/coactivations": wandb.Image(fig4)}, step=0) + + # Close figure to free memory + plt.close(fig4) + + # log coactivations + fig4_log: plt.Figure + ax4_log: plt.Axes + fig4_log, ax4_log = plt.subplots(figsize=figsize_coact) + # assert np.all(coact_data >= 0) # TODO: why are coacts negative? :/ + coact_log_data: np.ndarray = np.log10(coact_data + 1e-6 + coact_data.min()) + im4_log = ax4_log.matshow( + coact_log_data, aspect="auto", vmin=coact_log_data.min(), vmax=coact_log_data.max() + ) + ax4_log.set_title("Coactivations $\\log_{10}$") + # Add component labeling on both axes + add_component_labeling(ax4_log, labels, axis="x") + add_component_labeling(ax4_log, labels, axis="y") + plt.colorbar(im4_log) + if save_dir is not None: + fig4_log_fname: Path = save_dir / f"{figure_prefix}_coactivations_log.pdf" + fig4_log.savefig(fig4_log_fname, bbox_inches="tight", dpi=300) + + # Log to WandB if available + if wandb_run is not None: + wandb_run.log({"plots/activations/coactivations_log": wandb.Image(fig4_log)}, step=0) + + # Close figure to free memory + plt.close(fig4_log) + + # Activation histograms + fig5: plt.Figure + ax5a: plt.Axes + ax5b: plt.Axes + ax5c: plt.Axes + fig5, (ax5a, ax5b, ax5c) = plt.subplots(1, 3, figsize=(15, 4)) + + x_scale: str + y_scale: str + x_scale, y_scale = hist_scales + + # Histogram 1: All activations + all_activations: Float[Tensor, " samples*n_components"] = act_concat.flatten() + all_vals: np.ndarray = all_activations.cpu().numpy() + hist_counts: np.ndarray + bin_edges: np.ndarray + hist_counts, bin_edges = np.histogram(all_vals, bins=hist_bins) + bin_centers: np.ndarray = (bin_edges[:-1] + bin_edges[1:]) / 2 + 
ax5a.plot(bin_centers, hist_counts, color="blue", linewidth=2) + ax5a.set_title("All Activations") + ax5a.set_xlabel("Activation Value") + ax5a.set_ylabel("Count") + if x_scale == "log": + ax5a.set_xscale("log") + if y_scale == "log": + ax5a.set_yscale("log") + ax5a.grid(True, alpha=0.3) + + # Histogram 2: Activations per component + n_components: int = act_concat.shape[1] + + # Common bin edges for all component histograms + all_min: float = float(all_vals.min()) + all_max: float = float(all_vals.max()) + common_bins: np.ndarray = np.linspace(all_min, all_max, hist_bins) + common_centers: np.ndarray = (common_bins[:-1] + common_bins[1:]) / 2 + + # Get unique label prefixes and assign colors + label_prefixes: list[str] = [label.split(":")[0] for label in labels] + unique_prefixes: list[str] = list(dict.fromkeys(label_prefixes)) # Preserve order + colors: Sequence[tuple[int, int, int]] = mpl.colormaps["tab10"]( + np.linspace(0, 1, len(unique_prefixes)) + ) # pyright: ignore[reportAssignmentType] + prefix_colors: dict[str, tuple[int, int, int]] = { + prefix: colors[i] for i, prefix in enumerate(unique_prefixes) + } + + for comp_idx in range(n_components): + component_activations: Float[Tensor, " n_samples"] = act_concat[:, comp_idx] + comp_vals: np.ndarray = component_activations.cpu().numpy() + hist_counts, _ = np.histogram(comp_vals, bins=common_bins, density=True) + + # Get color based on label prefix + prefix: str = label_prefixes[comp_idx] + color: tuple[int, int, int] = prefix_colors[prefix] + + ax5b.plot(common_centers, hist_counts, color=color, alpha=0.1, linewidth=1) + + ax5b.set_title(f"Per Component ({n_components} components)") + ax5b.set_xlabel("Activation Value") + ax5b.set_ylabel("Density") + if x_scale == "log": + ax5b.set_xscale("log") + if y_scale == "log": + ax5b.set_yscale("log") + ax5b.grid(True, alpha=0.3) + + # Histogram 3: Activations per sample + for sample_idx in range(n_samples): + sample_activations: Float[Tensor, " n_components"] = 
def add_component_labeling(
    ax: plt.Axes, component_labels: ComponentLabels, axis: str = "x"
) -> None:
    """Label an axis with module boundaries (major ticks) plus fine minor ticks.

    Args:
        ax: Matplotlib axis to modify
        component_labels: List of component labels in format "module:index"
        axis: Which axis to label ('x' or 'y')
    """
    if not component_labels:
        return

    # Collect one entry per contiguous run of components from the same module:
    # the run's module name and the index where the run starts.
    run_names: list[str] = []
    run_starts: list[int] = []
    for idx, full_label in enumerate(component_labels):
        prefix: str = full_label.split(":")[0]
        if not run_names or prefix != run_names[-1]:
            run_names.append(prefix)
            run_starts.append(idx)

    n_total: int = len(component_labels)
    # Minor ticks: every 10 components; major ticks: start of each module run.
    minor_ticks: list[int] = list(range(0, n_total, 10))
    major_ticks: list[int] = run_starts
    major_labels: list[str] = run_names

    if axis == "x":
        ax.set_xticks(minor_ticks, minor=True)
        ax.set_xticks(major_ticks)
        ax.set_xticklabels(major_labels)
        ax.set_xlim(-0.5, n_total - 0.5)
        # Short minor ticks, long major ticks at module boundaries
        ax.tick_params(axis="x", which="minor", length=2, width=0.5)
        ax.tick_params(axis="x", which="major", length=6, width=1.5)
        for boundary in major_ticks:
            ax.axvline(boundary - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
    else:
        ax.set_yticks(minor_ticks, minor=True)
        ax.set_yticks(major_ticks)
        ax.set_yticklabels(major_labels)
        ax.set_ylim(-0.5, n_total - 0.5)
        # Short minor ticks, long major ticks at module boundaries
        ax.tick_params(axis="y", which="minor", length=2, width=0.5)
        ax.tick_params(axis="y", which="major", length=6, width=1.5)
        for boundary in major_ticks:
            ax.axhline(boundary - 0.5, color="black", linestyle="--", linewidth=0.5, alpha=0.5)
def plot_merge_matrix(
    merge_matrix: Bool[Tensor, "k_groups n_components"],
    show: bool = True,
    figsize: tuple[int, int] = (10, 3),
    show_row_sums: bool | None = None,
    ax: "plt.Axes | None" = None,
    component_labels: ComponentLabels | None = None,
) -> None:
    """Render a boolean group-membership matrix as an image.

    Args:
        merge_matrix: Boolean [k_groups, n_components] membership matrix.
        show: Whether to call ``plt.show()`` after drawing.
        figsize: Figure size used only when no ``ax`` is supplied.
        show_row_sums: Show per-group component counts in a side panel.
            Defaults to ``k_groups <= 20``; always disabled when ``ax`` is given.
        ax: Existing axes to draw onto (a new figure is created if None).
        component_labels: Optional "module:index" labels used to mark module
            boundaries on the x-axis.
    """
    # NOTE: removed the function-local `import matplotlib.pyplot as plt` that
    # shadowed the module-level import, and the dead assert after forcing
    # show_row_sums = False.
    k_groups, _n_components = merge_matrix.shape
    group_sizes: Int[Tensor, " k_groups"] = merge_matrix.sum(dim=1)

    if show_row_sums is None:
        show_row_sums = k_groups <= 20

    ax_lbl: plt.Axes | None = None
    if ax is not None:
        # Drawing onto caller-provided axes: there is no room for a side panel.
        show_row_sums = False
        ax_mat = ax
    elif show_row_sums:
        _fig, (ax_mat, ax_lbl) = plt.subplots(
            1, 2, figsize=figsize, gridspec_kw={"width_ratios": [10, 1]}
        )
    else:
        _fig, ax_mat = plt.subplots(figsize=figsize)

    ax_mat.matshow(merge_matrix.cpu(), aspect="auto", cmap="Blues", interpolation="nearest")
    ax_mat.set_xlabel("Components")
    ax_mat.set_ylabel("Groups")
    ax_mat.set_title("Merge Matrix")

    # Add component labeling if component labels are provided
    if component_labels is not None:
        # Import the function here to avoid circular imports
        from spd.clustering.plotting.activations import add_component_labeling

        add_component_labeling(ax_mat, component_labels, axis="x")

    if show_row_sums:
        assert ax_lbl is not None
        ax_lbl.set_xlim(0, 1)
        ax_lbl.set_ylim(-0.5, k_groups - 0.5)
        ax_lbl.invert_yaxis()
        ax_lbl.set_title("Row Sums")
        ax_lbl.axis("off")
        for row, size in enumerate(group_sizes):
            ax_lbl.text(0.5, row, str(size.item()), va="center", ha="center", fontsize=12)

    plt.tight_layout()
    if show:
        plt.show()
def plot_merge_iteration(
    current_merge: GroupMerge,
    current_coact: ClusterCoactivationShaped,
    costs: ClusterCoactivationShaped,
    # pair_cost: float,
    iteration: int,
    component_labels: ComponentLabels | None = None,
    plot_config: dict[str, Any] | None = None,
    nan_diag: bool = True,
    show: bool = False,
) -> plt.Figure:
    """Plot merge iteration results with merge tree, coactivations, and costs.

    Args:
        current_merge: Current merge state
        current_coact: Current coactivation matrix
        costs: Current cost matrix
        iteration: Current iteration number
        component_labels: Component labels for axis labeling
        plot_config: Plot configuration settings (merged over DEFAULT_PLOT_CONFIG)
        nan_diag: Whether to set diagonal to NaN for visualization
        show: Whether to display the plot (default: False)

    Returns:
        The matplotlib figure object

    Note:
        Caller is responsible for closing the returned figure with plt.close(fig)
        to prevent memory leaks.
    """
    cfg: dict[str, Any] = {**DEFAULT_PLOT_CONFIG, **(plot_config or {})}

    axs: list[plt.Axes]
    fig, axs = plt.subplots(
        1, 3, figsize=cfg["figsize"], sharey=True, gridspec_kw={"width_ratios": [2, 1, 1]}
    )

    # Left panel: merge matrix
    plot_merge_matrix(
        current_merge.to_matrix(),
        ax=axs[0],
        show=False,
        component_labels=component_labels,
    )
    axs[0].set_title("Merge")

    k_groups: int = current_coact.shape[0]
    tick_positions: list[int] = list(range(0, k_groups, cfg["tick_spacing"]))

    def _matrix_panel(panel_ax: plt.Axes, matrix: ClusterCoactivationShaped, name: str) -> None:
        # Record the value range before blanking the diagonal so the title
        # reflects the real data, not the NaN-filled copy.
        lo: float = matrix.min().item()
        hi: float = matrix.max().item()
        if nan_diag:
            matrix = matrix.clone()
            matrix.fill_diagonal_(np.nan)
        panel_ax.matshow(matrix.cpu().numpy(), aspect="equal")
        lo_str: str = format_scientific_latex(lo)
        hi_str: str = format_scientific_latex(hi)
        panel_ax.set_title(f"{name}\n[{lo_str}, {hi_str}]")
        panel_ax.set_yticks(tick_positions)
        panel_ax.set_xticks(tick_positions)
        panel_ax.set_xticklabels([])  # keep tick marks, drop the labels

    # Middle/right panels share identical styling
    _matrix_panel(axs[1], current_coact, "Coactivations")
    _matrix_panel(axs[2], costs, "Costs")

    fig.suptitle(f"Iteration {iteration}")
    plt.tight_layout()

    if cfg["save_pdf"]:
        fig.savefig(
            f"{cfg['figure_prefix']}_iter_{iteration:03d}.pdf",
            bbox_inches="tight",
            dpi=300,
        )

    if show:
        plt.show()

    return fig
bbox_inches="tight", + dpi=300, + ) + + if show: + plt.show() + + return fig + + +def plot_dists_distribution( + distances: DistancesArray, + mode: Literal["points", "dist"] = "points", + label: str | None = None, + ax: plt.Axes | None = None, + kwargs_fig: dict[str, Any] | None = None, + kwargs_plot: dict[str, Any] | None = None, + use_symlog: bool = True, + linthresh: float = 1.0, +) -> plt.Axes: + n_iters: int = distances.shape[0] + n_ens: int = distances.shape[1] + assert distances.shape[2] == n_ens, "Distances must be square" + + # Ensure ax and kwargs_fig are not both provided + if ax is not None and kwargs_fig is not None: + raise ValueError("Cannot provide both ax and kwargs_fig") + + dists_flat: Float[np.ndarray, " n_iters n_ens*n_ens"] = distances.reshape( + distances.shape[0], -1 + ) + + # Create figure if ax not provided + if ax is None: + _fig, ax_ = plt.subplots( # pyright: ignore[reportCallIssue] + 1, + 1, + **dict( + figsize=(8, 5), # pyright: ignore[reportArgumentType] + **(kwargs_fig or {}), + ), + ) + else: + ax_ = ax + + if mode == "points": + # Original points mode + n_samples: int = dists_flat.shape[1] + for i in range(n_iters): + ax_.plot( + np.full((n_samples), i), + dists_flat[i], + **dict( # pyright: ignore[reportArgumentType] + marker="o", + linestyle="", + color="blue", + alpha=min(1, 10 / (n_ens * n_ens)), + markersize=5, + markeredgewidth=0, + **(kwargs_plot or {}), + ), + ) + elif mode == "dist": + # Distribution statistics mode + # Generate a random color for this plot + color: Float[np.ndarray, " 3"] = np.random.rand(3) + + # Calculate statistics for each iteration + mins: list[float] = [] + maxs: list[float] = [] + means: list[float] = [] + medians: list[float] = [] + q1s: list[float] = [] + q3s: list[float] = [] + + for i in range(n_iters): + # Filter out NaN values (diagonal and upper triangle) + valid_dists: Float[np.ndarray, " n_valid"] = dists_flat[i][~np.isnan(dists_flat[i])] + if len(valid_dists) > 0: + 
mins.append(np.min(valid_dists)) + maxs.append(np.max(valid_dists)) + means.append(float(np.mean(valid_dists))) + medians.append(float(np.median(valid_dists))) + q1s.append(float(np.percentile(valid_dists, 25))) + q3s.append(float(np.percentile(valid_dists, 75))) + else: + # Handle case with no valid distances + mins.append(np.nan) + maxs.append(np.nan) + means.append(np.nan) + medians.append(np.nan) + q1s.append(np.nan) + q3s.append(np.nan) + + iterations: Int[np.ndarray, " n_iters"] = np.arange(n_iters) + + # Plot statistics + ax_.plot(iterations, mins, "-", color=color, alpha=0.5) + ax_.plot(iterations, maxs, "-", color=color, alpha=0.5) + ax_.plot(iterations, means, "-", color=color, linewidth=2, label=label) + ax_.plot(iterations, medians, "--", color=color, linewidth=2) + ax_.plot(iterations, q1s, ":", color=color, alpha=0.7) + ax_.plot(iterations, q3s, ":", color=color, alpha=0.7) + + # Shade between quartiles + ax_.fill_between(iterations, q1s, q3s, color=color, alpha=0.2) + + ax_.set_xlabel("Iteration #") + ax_.set_ylabel("distance") + ax_.set_title("Distribution of pairwise distances between group merges in an ensemble") + + if use_symlog: + from matplotlib.ticker import FuncFormatter + + ax_.set_yscale("symlog", linthresh=linthresh, linscale=0.2) + + # Custom formatter for y-axis ticks + def custom_format(y: float, _pos: int) -> str: + if abs(y) < linthresh: + # Show exact values in the linear range + return f"{y:.1f}" + elif abs(y) == 1: + return "1" + elif abs(y) == 10: + return "10" + else: + # Use scientific notation for larger values + exponent = int(np.log10(abs(y))) + return f"$10^{{{exponent}}}$" + + ax_.yaxis.set_major_formatter(FuncFormatter(custom_format)) + + # Add a visual indicator for the linear region (0 to linthresh) + ax_.axhspan(0, linthresh, alpha=0.05, color="gray", zorder=-10) + # Add subtle lines at linthresh boundaries + ax_.axhline(linthresh, color="gray", linestyle="--", linewidth=0.5, alpha=0.3) + if linthresh > 0: + 
ax_.axhline(0, color="gray", linestyle="-", linewidth=0.5, alpha=0.3) + + return ax_ + + +def plot_merge_history_cluster_sizes( + history: MergeHistory, + figsize: tuple[int, int] = (10, 5), + fmt: str = "png", + file_prefix: str | None = None, +) -> plt.Figure: + """Plot cluster sizes over iterations. + + Note: + Caller is responsible for closing the returned figure with plt.close(fig) + to prevent memory leaks. + """ + k_groups_t: Int[Tensor, " n_iters"] = history.merges.k_groups + valid_mask: Bool[Tensor, " n_iters"] = k_groups_t.ne(-1) + has_data: bool = bool(valid_mask.any().item()) + if not has_data: + raise ValueError("No populated iterations in history.k_groups") + + group_idxs_all: Int[Tensor, " n_iters n_components"] = history.merges.group_idxs[valid_mask] + k_groups_all: Int[Tensor, " n_iters"] = k_groups_t[valid_mask] + max_k: int = int(k_groups_all.max().item()) + + counts_list: list[Int[Tensor, " max_k"]] = [ + torch.bincount(row[row.ge(0)], minlength=max_k) # per-iteration cluster sizes + for row in group_idxs_all + ] + counts: Int[Tensor, " n_iters max_k"] = torch.stack(counts_list, dim=0) + + mask_pos: Bool[Tensor, " n_iters max_k"] = counts.gt(0) + it_idx_t, grp_idx_t = torch.nonzero(mask_pos, as_tuple=True) + xs_t: Float[Tensor, " n_points"] = it_idx_t.to(torch.float32) + sizes_t: Float[Tensor, " n_points"] = counts[it_idx_t, grp_idx_t].to(torch.float32) + + fig, ax = plt.subplots(figsize=figsize) + ax.plot( + xs_t.cpu().numpy(), sizes_t.cpu().numpy(), "bo", markersize=3, alpha=0.15, markeredgewidth=0 + ) + ax.set_xlabel("Iteration") + ax.set_ylabel("Cluster size") + ax.set_yscale("log") + ax.set_title("Distribution of cluster sizes over time") + + if file_prefix is not None: + fig.savefig(f"{file_prefix}_cluster_sizes.{fmt}", bbox_inches="tight", dpi=300) + + return fig diff --git a/spd/clustering/scripts/calc_distances.py b/spd/clustering/scripts/calc_distances.py new file mode 100644 index 000000000..993335671 --- /dev/null +++ 
"""Calculate distances between clustering runs in an ensemble.

Output structure:
    SPD_CACHE_DIR/ensemble/{pipeline_run_id}/
    ├── pipeline_config.yaml           # Created by run_pipeline.py
    ├── ensemble_meta.json             # Ensemble metadata
    ├── ensemble_merge_array.npz       # Normalized merge array
    ├── distances_{method}.npz         # Distance array for each method
    └── plots/
        └── distances_{method}.png     # Distance distribution plot
"""

import argparse
import json
import multiprocessing

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from muutils.dbg import dbg_tensor

from spd.clustering.consts import DistancesArray, DistancesMethod
from spd.clustering.ensemble_registry import get_clustering_runs
from spd.clustering.math.merge_distances import compute_distances
from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble
from spd.clustering.plotting.merge import plot_dists_distribution
from spd.clustering.scripts.run_clustering import ClusteringRunStorage
from spd.log import logger
from spd.settings import SPD_CACHE_DIR
from spd.utils.run_utils import ExecutionStamp

# Set spawn method for CUDA compatibility with multiprocessing
# Must be done before any CUDA operations
if torch.cuda.is_available():
    try:  # noqa: SIM105
        multiprocessing.set_start_method("spawn")
    except RuntimeError:
        # Already set, ignore
        pass


def main(pipeline_run_id: str, distances_method: DistancesMethod) -> None:
    """Calculate distances between clustering runs in an ensemble.

    Loads the merge history of every clustering run registered under
    ``pipeline_run_id``, computes the normalized ensemble merge array and the
    pairwise distance matrix, and writes results plus a distribution plot to
    ``SPD_CACHE_DIR/ensemble/{pipeline_run_id}/``.

    Args:
        pipeline_run_id: Pipeline run ID to query from registry
        distances_method: Method for calculating distances

    Raises:
        ValueError: If no clustering runs are registered for the pipeline.
        FileNotFoundError: If a registered run has no saved merge history.
    """
    logger.info(f"Calculating distances for pipeline run: {pipeline_run_id}")

    # Query registry for clustering runs
    clustering_runs = get_clustering_runs(pipeline_run_id)
    if not clustering_runs:
        raise ValueError(f"No clustering runs found for pipeline {pipeline_run_id}")

    logger.info(f"Found {len(clustering_runs)} clustering runs")

    # Load histories from individual clustering run directories
    histories: list[MergeHistory] = []
    for idx, clustering_run_id in clustering_runs:
        history_path = ClusteringRunStorage(
            ExecutionStamp(
                run_id=clustering_run_id,
                snapshot_branch="",
                commit_hash="",
                run_type="cluster",
            )
        ).history_path

        # i.e. SPD_CACHE_DIR / "cluster" / clustering_run_id / "history.zip"
        if not history_path.exists():
            raise FileNotFoundError(
                f"History not found for run {clustering_run_id}: {history_path}"
            )
        histories.append(MergeHistory.read(history_path))
        logger.info(f"Loaded history for run {idx}: {clustering_run_id}")

    # Compute normalized ensemble
    ensemble: MergeHistoryEnsemble = MergeHistoryEnsemble(data=histories)
    merge_array, merge_meta = ensemble.normalized()

    # Pipeline output directory: ensure it exists before writing into it
    # (run_pipeline.py normally creates it, but don't depend on that)
    pipeline_dir = SPD_CACHE_DIR / "ensemble" / pipeline_run_id
    pipeline_dir.mkdir(parents=True, exist_ok=True)

    # Save ensemble metadata and merge array
    ensemble_meta_path = pipeline_dir / "ensemble_meta.json"
    ensemble_meta_path.write_text(json.dumps(merge_meta, indent=2))
    logger.info(f"Saved ensemble metadata to {ensemble_meta_path}")

    ensemble_array_path = pipeline_dir / "ensemble_merge_array.npz"
    np.savez_compressed(ensemble_array_path, merge_array=merge_array)
    logger.info(f"Saved ensemble merge array to {ensemble_array_path}")

    # Compute distances
    logger.info(f"Computing distances using method: {distances_method}")
    distances: DistancesArray = compute_distances(
        normalized_merge_array=merge_array,
        method=distances_method,
    )

    dbg_tensor(distances)

    distances_path = pipeline_dir / f"distances_{distances_method}.npz"
    np.savez_compressed(distances_path, distances=distances)
    logger.info(f"Distances computed and saved: shape={distances.shape}, path={distances_path}")

    # Create and save the distances distribution plot. Address the axes and
    # its figure explicitly rather than relying on pyplot's implicit
    # "current figure" state.
    ax: Axes = plot_dists_distribution(
        distances=distances, mode="points", label=f"{distances_method} distances"
    )
    ax.set_title(f"Distance Distribution ({distances_method})")

    # Only add legend if there are labeled artists
    handles, _labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    plots_dir = pipeline_dir / "plots"
    plots_dir.mkdir(parents=True, exist_ok=True)
    fig_path = plots_dir / f"distances_{distances_method}.png"
    fig = ax.get_figure()
    assert fig is not None
    fig.savefig(fig_path)
    plt.close(fig)
    logger.info(f"Saved distances distribution plot to {fig_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate distances between clustering runs")
    parser.add_argument(
        "--pipeline-run-id",
        type=str,
        required=True,
        help="Pipeline run ID to query from registry",
    )
    parser.add_argument(
        "--distances-method",
        choices=DistancesMethod.__args__,
        default="perm_invariant_hamming",
        help="Method for calculating distances",
    )
    args = parser.parse_args()
    main(
        pipeline_run_id=args.pipeline_run_id,
        distances_method=args.distances_method,
    )
class ClusteringRunStorage(StorageBase):
    """Filesystem layout for a single clustering run.

    All paths are relative to ExecutionStamp.out_dir.
    """

    # File names within the run directory
    _CONFIG = "clustering_run_config.json"
    # we are saving a zip file with things in it besides npy files -- hence, `.zip` and not `.npz`
    _HISTORY = "history.zip"

    def __init__(self, execution_stamp: ExecutionStamp) -> None:
        """Resolve config/history paths under the run's base directory."""
        super().__init__(execution_stamp)
        base: Path = self.base_dir
        self.config_path: Path = base / self._CONFIG
        self.history_path: Path = base / self._HISTORY
"merge_pair_cost_semilog[1e-3]": semilog(merge_pair_cost, epsilon=1e-3), + "mdl_loss": float(mdl_loss), + "mdl_loss_norm": float(mdl_loss_norm), + }, + step=iter_idx, + ) + + if iter_idx % run_config.logging_intervals.tensor == 0: + group_sizes: Int[Tensor, " k_groups"] = current_merge.components_per_group + + tensor_data: dict[str, Tensor] = { + "coactivation": current_coact, + "costs": costs, + "group_sizes": group_sizes, + "group_activations": diag_acts, + "group_activations_over_sizes": ( + diag_acts / group_sizes.to(device=diag_acts.device).float() + ), + } + + fraction_singleton_groups: float = (group_sizes == 1).float().mean().item() + if fraction_singleton_groups > 0: + tensor_data["group_sizes.log1p"] = torch.log1p(group_sizes.float()) + + fraction_zero_coacts: float = (current_coact == 0).float().mean().item() + if fraction_zero_coacts > 0: + tensor_data["coactivation.log1p"] = torch.log1p(current_coact.float()) + + wandb_log_tensor(run, tensor_data, name="iters", step=iter_idx) + + run.log( + { + "fraction_singleton_groups": float(fraction_singleton_groups), + "num_nonsingleton_groups": int((group_sizes > 1).sum().item()), + "fraction_zero_coacts": float(fraction_zero_coacts), + }, + step=iter_idx, + ) + + if iter_idx > 0 and iter_idx % run_config.logging_intervals.artifact == 0: + with tempfile.NamedTemporaryFile() as tmp_file: + file: Path = Path(tmp_file.name) + merge_history.save(file) + artifact: wandb.Artifact = wandb.Artifact( + name=f"merge_hist_iter.iter_{iter_idx}", + type="merge_hist_iter", + description=f"Group indices at iteration {iter_idx}", + metadata={ + "iteration": iter_idx, + "config": merge_history.merge_config.model_dump(mode="json"), + }, + ) + artifact.add_file(str(file)) + run.log_artifact(artifact) + + if iter_idx % run_config.logging_intervals.plot == 0: + fig: Figure = plot_merge_iteration( + current_merge=current_merge, + current_coact=current_coact, + costs=costs, + iteration=iter_idx, + component_labels=component_labels, + 
def main(run_config: ClusteringRunConfig) -> Path:
    """A single clustering run.

    Args:
        run_config: Runtime parameters for this clustering run

    Returns:
        Path to saved merge history file
    """
    # Create ExecutionStamp and storage. No git snapshot here: when run as part
    # of an ensemble, the pipeline is responsible for creating the snapshot.
    execution_stamp = ExecutionStamp.create(
        run_type="cluster",
        create_snapshot=False,
    )
    storage = ClusteringRunStorage(execution_stamp)
    clustering_run_id = execution_stamp.run_id
    logger.info(f"Clustering run ID: {clustering_run_id}")

    # Register with the ensemble if this run belongs to a pipeline
    assigned_idx: int | None
    if run_config.ensemble_id:
        assigned_idx = register_clustering_run(
            pipeline_run_id=run_config.ensemble_id,
            clustering_run_id=clustering_run_id,
        )
        logger.info(
            f"Registered with pipeline {run_config.ensemble_id} at index {assigned_idx} in {_ENSEMBLE_REGISTRY_DB}"
        )
        # IMPORTANT: each ensemble member gets a distinct dataset seed
        # derived from its assigned index
        run_config = replace_pydantic_model(
            run_config,
            {"dataset_seed": run_config.dataset_seed + assigned_idx},
        )
    else:
        assigned_idx = None

    # Persist the (possibly re-seeded) config
    run_config.to_file(storage.config_path)
    logger.info(f"Config saved to {storage.config_path}")

    logger.info("Starting clustering run")
    logger.info(f"Output directory: {storage.base_dir}")
    device = get_device()

    spd_run = SPDRunInfo.from_path(run_config.model_path)
    task_name: TaskName = spd_run.config.task_config.task_name

    # 1. Load the evaluation batch
    logger.info(f"Loading dataset (seed={run_config.dataset_seed})")
    dataset_kwargs: dict[str, Any] = dict()
    if run_config.dataset_streaming:
        logger.info("Using streaming dataset loading")
        dataset_kwargs["config_kwargs"] = dict(streaming=True)
        assert task_name == "lm", (
            f"Streaming dataset loading only supported for 'lm' task, got '{task_name = }'. Remove dataset_streaming=True from config or use a different task."
        )
    batch: BatchTensor = load_dataset(
        model_path=run_config.model_path,
        task_name=task_name,
        batch_size=run_config.batch_size,
        seed=run_config.dataset_seed,
        **dataset_kwargs,
    ).to(device)

    # 2. Setup WandB for this run (optional)
    wandb_run: Run | None = None
    if run_config.wandb_project is not None:
        wandb_run = wandb.init(
            id=clustering_run_id,
            entity=run_config.wandb_entity,
            project=run_config.wandb_project,
            group=run_config.ensemble_id,
            config=run_config.model_dump(mode="json"),
            tags=[
                "clustering",
                f"task:{task_name}",
                f"model:{run_config.wandb_decomp_model}",
                f"ensemble_id:{run_config.ensemble_id}",
                f"assigned_idx:{assigned_idx}",
            ],
        )

    # 3. Load the decomposed model
    logger.info("Loading model")
    model = ComponentModel.from_run_info(spd_run).to(device)

    # 4. Compute per-module component activations
    logger.info("Computing activations")
    acts_by_module: (
        dict[str, Float[Tensor, "batch seq C"]] | dict[str, Float[Tensor, "batch C"]]
    ) = component_activations(
        model=model,
        batch=batch,
        device=device,
    )

    # 5. Filter and concatenate activations
    logger.info("Processing activations")
    processed: ProcessedActivations = process_activations(
        activations=acts_by_module,
        filter_dead_threshold=run_config.merge_config.filter_dead_threshold,
        seq_mode="concat" if task_name == "lm" else None,
        filter_modules=run_config.merge_config.filter_modules,
    )

    # 6. Log activation plots/tensors (WandB only)
    if wandb_run is not None:
        logger.info("Plotting activations")
        plot_activations(
            processed_activations=processed,
            save_dir=None,  # Don't save to disk, only WandB
            n_samples_max=256,
            wandb_run=wandb_run,
        )
        wandb_log_tensor(
            wandb_run,
            processed.activations,
            "activations",
            0,
            single=True,
        )

    # Keep only what the merge loop needs, then release the heavy objects
    activations: ActivationsTensor = processed.activations
    component_labels: ComponentLabels = ComponentLabels(processed.labels.copy())
    del processed
    del acts_by_module
    del model
    del batch
    gc.collect()

    # 7. Run the merge iteration
    logger.info("Starting merging")
    log_callback: LogCallback | None = (
        partial(_log_callback, run=wandb_run, run_config=run_config)
        if wandb_run is not None
        else None
    )

    history: MergeHistory = merge_iteration(
        merge_config=run_config.merge_config,
        activations=activations,
        component_labels=component_labels,
        log_callback=log_callback,
    )

    # 8. Save merge history
    history.save(storage.history_path)
    logger.info(f"History saved to {storage.history_path}")

    # 9. Final WandB logging
    if wandb_run is not None:
        _log_merge_history_plots(wandb_run, history)
        _save_merge_history_artifact(wandb_run, storage.history_path, history)
        wandb_run.finish()
        logger.info("WandB run finished")

    return storage.history_path
If provided with --idx-in-ensemble, registers run.", + ) + parser.add_argument( + "--idx-in-ensemble", + type=int, + default=None, + help="Index of this run in the ensemble", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset)", + ) + + args: argparse.Namespace = parser.parse_args() + + # Load base config + run_config = ClusteringRunConfig.from_file(args.config) + + # Override config values from CLI + overrides: dict[str, Any] = { + "dataset_streaming": args.dataset_streaming, + } + + # Handle ensemble-related overrides + if args.pipeline_run_id is not None: + overrides["ensemble_id"] = args.pipeline_run_id + + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + + run_config = replace_pydantic_model(run_config, overrides) + + # Run clustering + main(run_config) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/scripts/run_pipeline.py b/spd/clustering/scripts/run_pipeline.py new file mode 100644 index 000000000..179bc8bca --- /dev/null +++ b/spd/clustering/scripts/run_pipeline.py @@ -0,0 +1,480 @@ +"""Submit clustering runs to SLURM as separate jobs in a SLURM array. + +This script submits independent clustering runs as a SLURM job array, +where each run gets its own dataset (seeded), WandB run, and merge history output. + +Also submits a job to calculate distances between the clustering runs, which will run after +the clustering runs (the SLURM job depends on the previous array job). 
+ +Output structure (only pipeline_config.json is saved to directly in this script. The files under + are saved by run_clustering.py which is called in SLURM jobs deployed by this script.): + / # from execution stamp + |── pipeline_config.json # Saved in this script + |── clustering_run_config.json # make copy of the file pointed to by pipeline config + ├── ensemble_meta.json # (Saved by calc_distances.py) Ensemble metadata + ├── ensemble_merge_array.npz # (Saved by calc_distances.py) Normalized merge array + ├── distances_.npz # (Saved by calc_distances.py) Distance array for each method + └── distances_.png # (Saved by calc_distances.py) Distance distribution plot +""" + +import argparse +import os +import shlex +import tempfile +from pathlib import Path +from typing import Any + +import wandb_workspaces.workspaces as ws +from pydantic import Field, PositiveInt, field_validator, model_validator + +from spd.base_config import BaseConfig +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.consts import DistancesMethod +from spd.clustering.storage import StorageBase +from spd.log import logger +from spd.settings import SPD_CACHE_DIR +from spd.utils.command_utils import run_script_array_local +from spd.utils.general_utils import replace_pydantic_model +from spd.utils.run_utils import _NO_ARG_PARSSED_SENTINEL, ExecutionStamp, read_noneable_str +from spd.utils.slurm_utils import ( + create_slurm_array_script, + create_slurm_script, + submit_slurm_script, +) + +os.environ["WANDB_QUIET"] = "true" + + +class ClusteringPipelineStorage(StorageBase): + """Storage paths for clustering pipeline (ensemble). + + All paths are relative to ExecutionStamp.out_dir. 
+ """ + + # Relative path constants + _PIPELINE_CONFIG = "pipeline_config.yaml" + _RUN_IDS = "run_ids.json" + _ENSEMBLE_META = "ensemble_meta.json" + _ENSEMBLE_MERGE_ARRAY = "ensemble_merge_array.npz" + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + super().__init__(execution_stamp) + self.pipeline_config_path: Path = self.base_dir / self._PIPELINE_CONFIG + self.run_ids_path: Path = self.base_dir / self._RUN_IDS + self.ensemble_meta_path: Path = self.base_dir / self._ENSEMBLE_META + self.ensemble_merge_array_path: Path = self.base_dir / self._ENSEMBLE_MERGE_ARRAY + + def distances_path(self, method: DistancesMethod) -> Path: + return self.base_dir / f"distances_{method}.npz" + + +class ClusteringPipelineConfig(BaseConfig): + """Configuration for submitting an ensemble of clustering runs to SLURM.""" + + clustering_run_config_path: Path = Field( + description="Path to ClusteringRunConfig file.", + ) + n_runs: PositiveInt = Field(description="Number of clustering runs in the ensemble") + distances_methods: list[DistancesMethod] = Field( + description="List of method(s) to use for calculating distances" + ) + base_output_dir: Path = Field( + default=SPD_CACHE_DIR / "clustering_pipeline", + description="Base directory for outputs of clustering ensemble pipeline runs.", + ) + slurm_job_name_prefix: str | None = Field( + default=None, description="Prefix for SLURM job names" + ) + slurm_partition: str | None = Field(default=None, description="SLURM partition to use") + wandb_project: str | None = Field( + default=None, + description="Weights & Biases project name (set to None to disable WandB logging)", + ) + wandb_entity: str = Field(default="goodfire", description="WandB entity (team/user) name") + create_git_snapshot: bool = Field( + default=False, description="Create a git snapshot for the run" + ) + + @model_validator(mode="after") + def validate_crc(self) -> "ClusteringPipelineConfig": + """Validate that exactly one of clustering_run_config_path 
points to a valid `ClusteringRunConfig`.""" + assert self.clustering_run_config_path.exists(), ( + f"clustering_run_config_path does not exist: {self.clustering_run_config_path}" + ) + # Try to load ClusteringRunConfig + assert ClusteringRunConfig.from_file(self.clustering_run_config_path) + + return self + + @field_validator("distances_methods") + @classmethod + def validate_distances_methods(cls, v: list[DistancesMethod]) -> list[DistancesMethod]: + """Validate that distances_methods is non-empty and contains valid methods.""" + assert all(method in DistancesMethod.__args__ for method in v), ( + f"Invalid distances_methods: {v}" + ) + + return v + + +def create_clustering_workspace_view(ensemble_id: str, project: str, entity: str) -> str: + """Create WandB workspace view for clustering runs. + + TODO: Use a template workspace which actually shows some panels + TODO: since the run_id here is the same as the wandb id, can we take advantage of that? + + Args: + ensemble_id: Unique identifier for this ensemble + project: WandB project name + entity: WandB entity (team/user) name + + Returns: + URL to workspace view + """ + workspace = ws.Workspace(entity=entity, project=project) + workspace.name = f"Clustering - {ensemble_id}" + + workspace.runset_settings.filters = [ + ws.Tags("tags").isin([f"ensemble_id:{ensemble_id}"]), + ] + + try: + workspace.save_as_new_view() + return workspace.url + except Exception as e: + logger.warning( + f"Failed to create WandB workspace view: {workspace=}, {workspace.name=}, {ensemble_id=}, {project=}, {entity=}, {e}" + ) + raise e + + +def generate_clustering_commands( + pipeline_config: ClusteringPipelineConfig, + pipeline_run_id: str, + dataset_streaming: bool = False, +) -> list[str]: + """Generate commands for each clustering run. 
+ + Args: + pipeline_config: Pipeline configuration + pipeline_run_id: Pipeline run ID (each run will create its own ExecutionStamp) + dataset_streaming: Whether to use dataset streaming + + Returns: + List of shell-safe command strings + """ + commands: list[str] = [] + + for idx in range(pipeline_config.n_runs): + cmd_parts = [ + "python", + "spd/clustering/scripts/run_clustering.py", + "--config", + pipeline_config.clustering_run_config_path.as_posix(), + "--pipeline-run-id", + pipeline_run_id, + "--idx-in-ensemble", + str(idx), + "--wandb-project", + str(pipeline_config.wandb_project), + "--wandb-entity", + pipeline_config.wandb_entity, + ] + if dataset_streaming: + cmd_parts.append("--dataset-streaming") + + commands.append(shlex.join(cmd_parts)) + + return commands + + +def generate_calc_distances_commands( + pipeline_run_id: str, distances_methods: list[DistancesMethod] +) -> list[str]: + """Generate commands for calculating distances. + + Args: + pipeline_run_id: Pipeline run ID (will query registry for clustering runs) + distances_methods: List of methods for calculating distances + + Returns: + List of shell-safe command strings, one per method + """ + commands: list[str] = [] + for method in distances_methods: + commands.append( + shlex.join( + [ + "python", + "spd/clustering/scripts/calc_distances.py", + "--pipeline-run-id", + pipeline_run_id, + "--distances-method", + method, + ] + ) + ) + return commands + + +def main( + pipeline_config: ClusteringPipelineConfig, + local: bool = False, + local_clustering_parallel: bool = False, + local_calc_distances_parallel: bool = False, + dataset_streaming: bool = False, + track_resources_calc_distances: bool = False, +) -> None: + """Submit clustering runs to SLURM. + + Args: + pipeline_config_path: Path to ClusteringPipelineConfig file + n_runs: Number of clustering runs in the ensemble. Will override value in the config file. 
+ """ + # setup + # ========================================================================================== + + logger.set_format("console", "terse") + + if local_clustering_parallel or local_calc_distances_parallel or track_resources_calc_distances: + assert local, ( + "local_clustering_parallel, local_calc_distances_parallel, track_resources_calc_distances " + "can only be set when running locally\n" + f"{local_clustering_parallel=}, {local_calc_distances_parallel=}, {track_resources_calc_distances=}, {local=}" + ) + + # Create ExecutionStamp for pipeline + execution_stamp: ExecutionStamp = ExecutionStamp.create( + run_type="ensemble", + create_snapshot=pipeline_config.create_git_snapshot, + ) + pipeline_run_id: str = execution_stamp.run_id + logger.info(f"Pipeline run ID: {pipeline_run_id}") + + # Initialize storage + storage = ClusteringPipelineStorage(execution_stamp) + logger.info(f"Pipeline output directory: {storage.base_dir}") + + # Save pipeline config + pipeline_config.to_file(storage.pipeline_config_path) + logger.info(f"Pipeline config saved to {storage.pipeline_config_path}") + + # Create WandB workspace if requested + if pipeline_config.wandb_project is not None: + workspace_url = create_clustering_workspace_view( + ensemble_id=pipeline_run_id, + project=pipeline_config.wandb_project, + entity=pipeline_config.wandb_entity, + ) + logger.info(f"WandB workspace: {workspace_url}") + + # Generate commands for clustering runs + clustering_commands = generate_clustering_commands( + pipeline_config=pipeline_config, + pipeline_run_id=pipeline_run_id, + dataset_streaming=dataset_streaming, + ) + + # Generate commands for calculating distances + calc_distances_commands = generate_calc_distances_commands( + pipeline_run_id=pipeline_run_id, + distances_methods=pipeline_config.distances_methods, + ) + + # Submit to SLURM + if local: + # submit clustering array job + run_script_array_local( + commands=clustering_commands, + parallel=local_clustering_parallel, + 
) + + # submit calc_distances jobs in parallel + logger.info("Calculating distances...") + run_script_array_local( + commands=calc_distances_commands, + parallel=local_calc_distances_parallel, + track_resources=track_resources_calc_distances, + ) + + logger.section("complete!") + + # Build distances plot paths dict + distances_plots = { + f"distances via {method}": str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + + logger.values( + { + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + **distances_plots, + } + ) + + else: + assert pipeline_config.slurm_job_name_prefix is not None, ( + "must specify slurm_job_name_prefix if not running locally" + ) + assert pipeline_config.slurm_partition is not None, ( + "must specify slurm_partition if not running locally" + ) + with tempfile.TemporaryDirectory() as temp_dir: + # Submit clustering array job + clustering_script_path = Path(temp_dir) / f"clustering_{pipeline_run_id}.sh" + + create_slurm_array_script( + script_path=clustering_script_path, + job_name=f"{pipeline_config.slurm_job_name_prefix}_cluster", + commands=clustering_commands, + snapshot_branch=execution_stamp.snapshot_branch, + max_concurrent_tasks=pipeline_config.n_runs, # Run all concurrently + n_gpus_per_job=1, # Always 1 GPU per run + partition=pipeline_config.slurm_partition, + ) + array_job_id = submit_slurm_script(clustering_script_path) + + # Submit calc_distances jobs (one per method) with dependency on array job + calc_distances_job_ids: list[str] = [] + calc_distances_logs: list[str] = [] + + for _i, (method, cmd) in enumerate( + zip(pipeline_config.distances_methods, calc_distances_commands, strict=True) + ): + calc_distances_script_path = ( + Path(temp_dir) / f"calc_distances_{method}_{pipeline_run_id}.sh" + ) + + create_slurm_script( + script_path=calc_distances_script_path, + 
job_name=f"{pipeline_config.slurm_job_name_prefix}_dist_{method}", + command=cmd, + snapshot_branch=execution_stamp.snapshot_branch, + n_gpus=1, # Always 1 GPU for distances calculation + partition=pipeline_config.slurm_partition, + dependency_job_id=array_job_id, + ) + job_id = submit_slurm_script(calc_distances_script_path) + calc_distances_job_ids.append(job_id) + calc_distances_logs.append(f"~/slurm_logs/slurm-{job_id}.out") + + logger.section("Jobs submitted successfully!") + + # Build distances plot paths dict + distances_plots = { + method: str(storage.plots_dir / f"distances_{method}.png") + for method in pipeline_config.distances_methods + } + + logger.values( + { + "Clustering Array Job ID": array_job_id, + "Calc Distances Job IDs": ", ".join(calc_distances_job_ids), + "Total clustering runs": len(clustering_commands), + "Pipeline run ID": pipeline_run_id, + "Pipeline output dir": str(storage.base_dir), + "Clustering logs": f"~/slurm_logs/slurm-{array_job_id}_*.out", + "Calc Distances logs": ", ".join(calc_distances_logs), + } + ) + logger.info("Distances plots will be saved to:") + for method, path in distances_plots.items(): + logger.info(f" {method}: {path}") + + +def cli(): + """CLI for spd-cluster command.""" + parser = argparse.ArgumentParser( + prog="spd-cluster", + description="Submit clustering runs to SLURM. 
Arguments specified here will override the " + "corresponding value in the config file.", + ) + + parser.add_argument( + "--config", + required=True, + type=Path, + help="Path to pipeline config file", + ) + parser.add_argument( + "--n-runs", + type=int, + help="Number of clustering runs in the ensemble (overrides value in config file)", + ) + parser.add_argument( + "--wandb-project", + type=read_noneable_str, + default=_NO_ARG_PARSSED_SENTINEL, + help="WandB project name (if not provided, WandB logging is disabled)", + ) + parser.add_argument( + "--wandb-entity", + type=str, + default=None, + help="WandB entity name (user or team)", + ) + parser.add_argument( + "--distances-methods", + type=str, + default=None, + help="Comma-separated list of distance methods (e.g., 'perm_invariant_hamming,matching_dist')", + ) + parser.add_argument( + "--local", + action=argparse.BooleanOptionalAction, + default=False, + help="Run locally instead of submitting to SLURM (required if slurm_job_name_prefix and slurm_partition are None in config)", + ) + parser.add_argument( + "--local-clustering-parallel", + action="store_true", + help="If running locally, whether to run clustering runs in parallel", + ) + parser.add_argument( + "--local-calc-distances-parallel", + action="store_true", + help="If running locally, whether to run distance calculations in parallel", + ) + parser.add_argument( + "--track-resources-calc-distances", + action="store_true", + help="If running locally, whether to track resource usage during distance calculations", + ) + parser.add_argument( + "--dataset-streaming", + action="store_true", + help="Whether to use streaming dataset loading (if supported by the dataset). 
see https://github.com/goodfire-ai/spd/pull/199", + ) + + args = parser.parse_args() + + pipeline_config = ClusteringPipelineConfig.from_file(args.config) + overrides: dict[str, Any] = {} + + if args.n_runs is not None: + overrides["n_runs"] = args.n_runs + if args.wandb_project is not _NO_ARG_PARSSED_SENTINEL: + overrides["wandb_project"] = args.wandb_project + if args.wandb_entity is not None: + overrides["wandb_entity"] = args.wandb_entity + if args.distances_methods is not None: + # Parse comma-separated list of distance methods + methods = [method.strip() for method in args.distances_methods.split(",")] + overrides["distances_methods"] = methods + + pipeline_config = replace_pydantic_model(pipeline_config, overrides) + + main( + pipeline_config=pipeline_config, + local=args.local, + dataset_streaming=args.dataset_streaming, + local_clustering_parallel=args.local_clustering_parallel, + local_calc_distances_parallel=args.local_calc_distances_parallel, + track_resources_calc_distances=args.track_resources_calc_distances, + ) + + +if __name__ == "__main__": + cli() diff --git a/spd/clustering/storage.py b/spd/clustering/storage.py new file mode 100644 index 000000000..dc3d8765a --- /dev/null +++ b/spd/clustering/storage.py @@ -0,0 +1,19 @@ +"""Minimal storage base class for clustering - just path management.""" + +from pathlib import Path + +from spd.utils.run_utils import ExecutionStamp + + +class StorageBase: + """Base class for storage - provides ExecutionStamp and base directory. + + Subclasses define path constants (relative to base_dir) and set absolute paths in __init__. + Caller handles all actual saving and WandB uploading. 
+ """ + + def __init__(self, execution_stamp: ExecutionStamp) -> None: + """Initialize storage with execution stamp.""" + self.execution_stamp: ExecutionStamp = execution_stamp + self.base_dir: Path = execution_stamp.out_dir + self.plots_dir: Path = self.base_dir / "plots" diff --git a/spd/clustering/util.py b/spd/clustering/util.py new file mode 100644 index 000000000..bd11e2fd4 --- /dev/null +++ b/spd/clustering/util.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + + +def format_scientific_latex(value: float) -> str: + """Format a number in LaTeX scientific notation style.""" + if value == 0: + return r"$0$" + + import math + + exponent: int = int(math.floor(math.log10(abs(value)))) + mantissa: float = value / (10**exponent) + + return f"${mantissa:.2f} \\times 10^{{{exponent}}}$" + + +ModuleFilterSource = str | Callable[[str], bool] | set[str] | None +ModuleFilterFunc = Callable[[str], bool] diff --git a/spd/clustering/wandb_tensor_info.py b/spd/clustering/wandb_tensor_info.py new file mode 100644 index 000000000..14463ab88 --- /dev/null +++ b/spd/clustering/wandb_tensor_info.py @@ -0,0 +1,169 @@ +"""Minimal WandB tensor logging utilities using muutils.""" + +import warnings +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import wandb +import wandb.sdk.wandb_run +from muutils.dbg import dbg_tensor +from muutils.tensor_info import array_info +from torch import Tensor + + +def wandb_log_tensor( + run: wandb.sdk.wandb_run.Run, + data: Tensor | dict[str, Tensor], + name: str, + step: int, + single: bool = False, +) -> None: + """Log tensor(s) with stats to WandB as metrics and histograms. 
+ + Args: + run: Current WandB run (None if WandB disabled) + data: Either a Tensor or dict[str, Tensor] + name: Name for logging + step: WandB step + single: True if this tensor is only logged once (component activations) + """ + try: + if isinstance(data, dict): + # Handle dict of tensors + for key, tensor in data.items(): + full_name: str = f"{name}.{key}" + _log_one(run, tensor, full_name, step, single=single) + else: + # Handle single tensor + _log_one(run, data, name, step, single=single) + except Exception as e: + warnings.warn(f"Failed to log tensor {name}: {e}") # noqa: B028 + dbg_tensor(data) + raise e + + +def _create_histogram( + info: dict[str, Any], tensor: Tensor, name: str, logy: bool = True +) -> plt.Figure: + """Create matplotlib histogram with stats markers.""" + # sanity check + if info["status"] != "ok" or info["size"] == 0: + fig: plt.Figure + ax: plt.Axes + fig, ax = plt.subplots(figsize=(8, 6)) + ax.text(0.5, 0.5, f"{info['status']}", ha="center", va="center") + ax.set_title(f"{name} - {info['status']}") + return fig + + # make basic hist + values: np.ndarray = tensor.flatten().detach().cpu().numpy() + if info["has_nans"]: + values = values[~np.isnan(values)] + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.hist(values, bins=50, alpha=0.7, edgecolor="black", linewidth=0.5) + + # Add stat lines + mean_val: float = info["mean"] or float("nan") + median_val: float = info["median"] or float("nan") + std_val: float = info["std"] or float("nan") + + if info["mean"] is not None: + ax.axvline( + mean_val, + color="red", + linestyle="-", + linewidth=2, + label="$\\mu$", + ) + ax.axvline( + median_val, + color="blue", + linestyle="-", + linewidth=2, + label="$\\tilde{x}$", + ) + if std_val: + ax.axvline( + mean_val + std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu+\\sigma$", + ) + ax.axvline( + mean_val - std_val, + color="orange", + linestyle="--", + linewidth=1.5, + alpha=0.8, + label="$\\mu-\\sigma$", + ) + 
+ # Build informative title with tensor stats + shape_str: str = str(tuple(info["shape"])) if "shape" in info else "unknown" + dtype_str: str = str(info.get("dtype", "unknown")).replace("torch.", "") + + title_line1: str = f"{name}" + title_line2: str = f"shape={shape_str}, dtype={dtype_str}" + title_line3: str = ( + f"range=[{info['min']:.3g}, {info['max']:.3g}], " + f"$\\mu$={mean_val:.3g}, $\\tilde{{x}}$={median_val:.3g}, $\\sigma$={std_val:.3g}" + ) + + # Combine into multi-line title + full_title: str = f"{title_line1}\n{title_line2}\n{title_line3}" + ax.set_title(full_title, fontsize=10) + ax.set_xlabel("Value") + ax.set_ylabel("Count") + ax.legend() + ax.grid(True, alpha=0.3) + if logy: + ax.set_yscale("log") + + plt.tight_layout() + return fig + + +def _log_one( + run: wandb.sdk.wandb_run.Run, + tensor_: Tensor, + name: str, + step: int, + single: bool = False, + # use_log_counts: bool = True, +) -> None: + """Log a single tensor.""" + info: dict[str, Any] = array_info(tensor_) + + if single: + # For single-use logging, log a single histogram as a figure + hist_fig: plt.Figure = _create_histogram(info=info, tensor=tensor_, name=name) + histogram_key: str = f"single_hists/{name}" + run.log({histogram_key: wandb.Image(hist_fig)}, step=step) + plt.close(hist_fig) # Close figure to free memory + else: + # Log numeric stats as metrics (viewable like loss) using dict comprehension + stats_to_log: dict[str, float | wandb.Histogram] = { + f"tensor_metrics/{name}/{key}": info[key] + for key in ["mean", "std", "median", "min", "max"] + if key in info and info[key] is not None + } + + # For regular logging, use wandb.Histogram directly + hist_key: str = f"tensor_histograms/{name}" + stats_to_log[hist_key] = wandb.Histogram(tensor_.flatten().cpu().numpy()) # pyright: ignore[reportArgumentType] + + # Add nan_percent if present + nan_percent: float | None = info["nan_percent"] + # TODO: this is a hack for when the tensor is empty + if nan_percent is None: + 
dbg_tensor(tensor_) + nan_percent = float("nan") + if nan_percent > 0: + stats_to_log[f"tensor_metrics/{name}/nan_percent"] = nan_percent + + if stats_to_log: + run.log(stats_to_log, step=step) diff --git a/spd/experiments/ih/ih_decomposition.py b/spd/experiments/ih/ih_decomposition.py index 399b08783..f93836ca1 100644 --- a/spd/experiments/ih/ih_decomposition.py +++ b/spd/experiments/ih/ih_decomposition.py @@ -43,7 +43,7 @@ def main( logger.info(f"Using device: {device}") if config.wandb_project: - tags = ["ih"] + tags = ["induction_head"] if evals_id: tags.append(evals_id) if sweep_id: diff --git a/spd/identity_insertion.py b/spd/identity_insertion.py index 2859b7700..dcad69e81 100644 --- a/spd/identity_insertion.py +++ b/spd/identity_insertion.py @@ -42,12 +42,8 @@ def insert_identity_operations_(target_model: nn.Module, identity_patterns: list identity_patterns: Patterns matching modules to prepend identity ops to """ - if is_main_process(): - logger.info(f"Inserting identity operations before {len(identity_patterns)} modules") - identity_module_paths = get_target_module_paths(target_model, identity_patterns) - # Add identity layers and hooks for module_path in identity_module_paths: module = target_model.get_submodule(module_path) @@ -61,7 +57,7 @@ def insert_identity_operations_(target_model: nn.Module, identity_patterns: list case _: raise ValueError(f"Module {module} not supported. type: {type(module)}") - module.pre_identity = Identity(d_in) # type: ignore + module.pre_identity = Identity(d_in) module.register_forward_pre_hook(pre_id_hook, with_kwargs=True) if is_main_process(): diff --git a/spd/models/component_model.py b/spd/models/component_model.py index e95e150f0..a5b00a2ee 100644 --- a/spd/models/component_model.py +++ b/spd/models/component_model.py @@ -324,6 +324,7 @@ def __call__( def __call__(self, *args: Any, **kwargs: Any) -> Tensor | OutputWithCache: return super().__call__(*args, **kwargs) + # TODO: why doesnt this have overrides??? 
@override def forward( self, diff --git a/spd/scripts/run.py b/spd/scripts/run.py index 921ab1146..1bd51088f 100644 --- a/spd/scripts/run.py +++ b/spd/scripts/run.py @@ -11,9 +11,7 @@ import copy import json import shlex -import subprocess import tempfile -from datetime import datetime from hashlib import sha256 from pathlib import Path from typing import Any, Final @@ -24,17 +22,17 @@ from spd.log import LogFormat, logger from spd.registry import EXPERIMENT_REGISTRY, get_max_expected_runtime from spd.settings import REPO_ROOT -from spd.utils.git_utils import create_git_snapshot, repo_current_branch -from spd.utils.run_utils import apply_nested_updates, generate_grid_combinations, generate_run_name -from spd.utils.slurm_utils import create_slurm_array_script, submit_slurm_array +from spd.utils.command_utils import run_script_array_local +from spd.utils.run_utils import ( + ExecutionStamp, + apply_nested_updates, + generate_grid_combinations, + generate_run_name, +) +from spd.utils.slurm_utils import create_slurm_array_script, submit_slurm_script from spd.utils.wandb_utils import wandb_setup -def generate_run_id() -> str: - """Generate a unique run ID based on timestamp.""" - return f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}" - - def resolve_sweep_params_path(sweep_params_file: str) -> Path: """Resolve the full path to the sweep parameters file.""" if "/" not in sweep_params_file: @@ -114,12 +112,6 @@ def _choose_master_port(run_id_local: str, idx: int) -> int: return base + (h % span) -def _build_mpi_prefix(run_id: str, idx: int, dp: int) -> str: - """Build an MPI prefix for a command.""" - port: int = _choose_master_port(run_id, idx) - return f"MASTER_PORT={port} mpirun -x MASTER_PORT -np {dp} " - - def generate_commands( experiments_list: list[str], run_id: str, @@ -147,7 +139,6 @@ def generate_commands( for experiment in experiments_list: exp_config = EXPERIMENT_REGISTRY[experiment] - # Load base config base_config = 
Config.from_file(exp_config.config_path) if sweep_params_path is None: @@ -158,14 +149,24 @@ def generate_commands( config_json = f"json:{json.dumps(config_with_overrides.model_dump(mode='json'))}" - mpi_prefix = _build_mpi_prefix(run_id, cmd_idx, dp) if dp > 1 else "" - - command = ( - f"{mpi_prefix}python {exp_config.decomp_script} --config_json '{config_json}' " - f"--sweep_id {run_id} --evals_id {experiment}" - ) + cmd_parts = [ + "python", + str(exp_config.decomp_script), + "--config_json", + config_json, + "--sweep_id", + run_id, + "--evals_id", + experiment, + ] + + if dp > 1: + port = _choose_master_port(run_id, cmd_idx) + cmd = f"MASTER_PORT={port} {shlex.join(['mpirun', '-x', 'MASTER_PORT', '-np', str(dp)] + cmd_parts)}" + else: + cmd = shlex.join(cmd_parts) - commands.append(command) + commands.append(cmd) task_breakdown[experiment] = "1 task" cmd_idx += 1 @@ -186,15 +187,26 @@ def generate_commands( config_json = f"json:{json.dumps(config_with_overrides.model_dump(mode='json'))}" sweep_params_json = f"json:{json.dumps(sweep_params)}" - mpi_prefix = _build_mpi_prefix(run_id, cmd_idx, dp) if dp > 1 else "" - command = ( - f"{mpi_prefix}python {exp_config.decomp_script} --config_json '{config_json}' " - f"--sweep_id {run_id} " - f"--evals_id {experiment} " - f"--sweep_params_json '{sweep_params_json}'" - ) - - commands.append(command) + cmd_parts = [ + "python", + str(exp_config.decomp_script), + "--config_json", + config_json, + "--sweep_id", + run_id, + "--evals_id", + experiment, + "--sweep_params_json", + sweep_params_json, + ] + + if dp > 1: + port = _choose_master_port(run_id, cmd_idx) + cmd = f'MASTER_PORT={port} mpirun -x "MASTER_PORT" -np {dp} {shlex.join(cmd_parts)}' + else: + cmd = shlex.join(cmd_parts) + + commands.append(cmd) cmd_idx += 1 # Print first combination as example @@ -208,35 +220,6 @@ def generate_commands( return commands -def run_commands_locally(commands: list[str]) -> None: - """Execute commands locally in sequence. 
- - Args: - commands: List of shell commands to execute - """ - - logger.section(f"LOCAL EXECUTION: Running {len(commands)} tasks") - - for i, command in enumerate(commands, 1): - # Parse command into arguments - args = shlex.split(command) - - # Extract experiment name from script path for cleaner output - script_name = args[1].split("/")[-1] - logger.section(f"[{i}/{len(commands)}] Executing: {script_name}...") - - result = subprocess.run(args) - - if result.returncode != 0: - logger.warning( - f"[{i}/{len(commands)}] ⚠️ Warning: Command failed with exit code {result.returncode}" - ) - else: - logger.info(f"[{i}/{len(commands)}] ✓ Completed successfully") - - logger.section("LOCAL EXECUTION COMPLETE") - - def get_experiments( experiments: str | None = None, ) -> list[str]: @@ -338,7 +321,11 @@ def main( logger.set_format("console", log_format) # Determine run id - run_id: str = generate_run_id() + execution_stamp: ExecutionStamp = ExecutionStamp.create( + run_type="spd", + create_snapshot=create_snapshot, + ) + run_id: str = execution_stamp.run_id logger.info(f"Run ID: {run_id}") # Determine the sweep parameters file @@ -365,18 +352,6 @@ def main( # ========================================================================================== if not local or use_wandb: - # set up snapshot branch and commit hash - snapshot_branch: str - commit_hash: str - - if create_snapshot: - snapshot_branch, commit_hash = create_git_snapshot(branch_name_prefix="run") - logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") - else: - snapshot_branch = repo_current_branch() - commit_hash = "none" - logger.info(f"Using current branch: {snapshot_branch}") - # set up wandb if use_wandb: wandb_setup( @@ -386,8 +361,8 @@ def main( create_report=create_report, # if `create_report == False`, the rest of the arguments don't matter report_title=report_title, - snapshot_branch=snapshot_branch, - commit_hash=commit_hash, + 
snapshot_branch=execution_stamp.snapshot_branch, + commit_hash=execution_stamp.commit_hash, include_run_comparer=sweep_params_file is not None, ) else: @@ -410,7 +385,7 @@ def main( ) if local: - run_commands_locally(commands) + run_script_array_local(commands) else: # Submit to SLURM with tempfile.TemporaryDirectory() as temp_dir: @@ -427,14 +402,13 @@ def main( script_path=array_script, job_name=job_name, commands=commands, - # again -- local is false, so snapshot_branch will exist - snapshot_branch=snapshot_branch, # pyright: ignore[reportPossiblyUnboundVariable] + snapshot_branch=execution_stamp.snapshot_branch, max_concurrent_tasks=n_agents, n_gpus_per_job=n_gpus_per_job, partition=partition, ) - array_job_id = submit_slurm_array(array_script) + array_job_id = submit_slurm_script(array_script) logger.section("Job submitted successfully!") logger.values( diff --git a/spd/spd_types.py b/spd/spd_types.py index 07249f7f7..d348f5e8f 100644 --- a/spd/spd_types.py +++ b/spd/spd_types.py @@ -1,7 +1,8 @@ from pathlib import Path from typing import Annotated, Literal -from pydantic import BeforeValidator, Field, PlainSerializer +from annotated_types import Ge, Le +from pydantic import BeforeValidator, PlainSerializer from spd.settings import REPO_ROOT @@ -45,6 +46,6 @@ def validate_path(v: str | Path) -> str | Path: ] -Probability = Annotated[float, Field(strict=True, ge=0, le=1)] +Probability = Annotated[float, Ge(0), Le(1)] -TaskName = Literal["tms", "resid_mlp", "lm", "ih"] +TaskName = Literal["tms", "resid_mlp", "lm", "induction_head"] diff --git a/spd/utils/command_utils.py b/spd/utils/command_utils.py new file mode 100644 index 000000000..b6b74f3b3 --- /dev/null +++ b/spd/utils/command_utils.py @@ -0,0 +1,113 @@ +"""Minimal utilities for running shell-safe commands locally.""" + +import subprocess +import tempfile +from pathlib import Path + +from spd.log import logger + + +def run_script_array_local( + commands: list[str], parallel: bool = False, track_resources: 
bool = False +) -> dict[str, dict[str, float]] | None: + """Run multiple shell-safe command strings locally. + + Args: + commands: List of shell-safe command strings (built with shlex.join()) + parallel: If True, run all commands in parallel. If False, run sequentially. + track_resources: If True, track and return resource usage for each command using /usr/bin/time. + + Returns: + If track_resources is True, returns dict mapping commands to resource metrics dict. + Resource metrics include: K (avg memory KB), M (max memory KB), P (CPU %), + S (system CPU sec), U (user CPU sec), e (wall time sec). + Otherwise returns None. + """ + n_commands = len(commands) + resources: dict[str, dict[str, float]] = {} + resource_files: list[Path] = [] + + # Wrap commands with /usr/bin/time if resource tracking is requested + if track_resources: + wrapped_commands: list[str] = [] + for cmd in commands: + resource_file = Path(tempfile.mktemp(suffix=".resources")) # pyright: ignore[reportDeprecated] + resource_files.append(resource_file) + # Use /usr/bin/time to track comprehensive resource usage + # K=avg total mem, M=max resident, P=CPU%, S=system time, U=user time, e=wall time + wrapped_cmd = ( + f'/usr/bin/time -f "K:%K M:%M P:%P S:%S U:%U e:%e" -o {resource_file} {cmd}' + ) + wrapped_commands.append(wrapped_cmd) + commands_to_run = wrapped_commands + else: + commands_to_run = commands + + try: + if not parallel: + logger.section(f"LOCAL EXECUTION: Running {n_commands} tasks serially") + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Running: {commands[i - 1]}") + subprocess.run(cmd, shell=True, check=True) + logger.section("LOCAL EXECUTION COMPLETE") + else: + logger.section(f"LOCAL EXECUTION: Starting {n_commands} tasks in parallel") + procs: list[subprocess.Popen[bytes]] = [] + + for i, cmd in enumerate(commands_to_run, 1): + logger.info(f"[{i}/{n_commands}] Starting: {commands[i - 1]}") + proc = subprocess.Popen(cmd, shell=True) + 
procs.append(proc) + + logger.section("WAITING FOR ALL TASKS TO COMPLETE") + for proc, cmd in zip(procs, commands, strict=True): # noqa: B007 + proc.wait() + if proc.returncode != 0: + logger.error(f"Process {proc.pid} failed with exit code {proc.returncode}") + logger.section("LOCAL EXECUTION COMPLETE") + + # Read resource usage results + if track_resources: + for cmd, resource_file in zip(commands, resource_files, strict=True): + if resource_file.exists(): + # Parse format: "K:123 M:456 P:78% S:1.23 U:4.56 e:7.89" + output = resource_file.read_text().strip() + metrics: dict[str, float] = {} + + for part in output.split(): + if ":" in part: + key, value = part.split(":", 1) + # Remove % sign from CPU percentage + value = value.rstrip("%") + try: + metrics[key] = float(value) + except ValueError: + logger.warning(f"Could not parse {key}:{value} for command: {cmd}") + + resources[cmd] = metrics + else: + logger.warning(f"Resource file not found for: {cmd}") + + # Log comprehensive resource usage table + logger.section("RESOURCE USAGE RESULTS") + for cmd, metrics in resources.items(): + logger.info(f"Command: {cmd}") + logger.info( + f" Time: {metrics.get('e', 0):.2f}s wall, " + f"{metrics.get('U', 0):.2f}s user, " + f"{metrics.get('S', 0):.2f}s system" + ) + logger.info( + f" Memory: {metrics.get('M', 0) / 1024:.1f} MB peak, " + f"{metrics.get('K', 0) / 1024:.1f} MB avg" + ) + logger.info(f" CPU: {metrics.get('P', 0):.1f}%") + + finally: + # Clean up temp files + if track_resources: + for resource_file in resource_files: + if resource_file.exists(): + resource_file.unlink() + + return resources if track_resources else None diff --git a/spd/utils/git_utils.py b/spd/utils/git_utils.py index d21bf240e..b9c0cf370 100644 --- a/spd/utils/git_utils.py +++ b/spd/utils/git_utils.py @@ -1,6 +1,5 @@ """Git utilities for creating code snapshots.""" -import datetime import subprocess import tempfile from pathlib import Path @@ -30,7 +29,32 @@ def repo_current_branch() -> str: 
return result.stdout.strip() -def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: +def repo_is_clean(catch_except_as_false: bool = False) -> bool: + """Return True if the current git repository has no uncommitted or untracked changes. + + # TODO: this may error in CI environments: https://github.com/goodfire-ai/spd/actions/runs/18560369066/job/52907611203 + `fatal: detected dubious ownership in repository at '/__w/spd/spd'` + + for now, if `catch_except_as_false` is True, we catch any exceptions and return False. + + """ + try: + status: str = subprocess.check_output(["git", "status", "--porcelain"], text=True).strip() + return status == "" + except Exception as e: + if catch_except_as_false: + return False + else: + raise e + + +def repo_current_commit_hash() -> str: + """Return the current commit hash of the active HEAD.""" + commit_hash: str = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + return commit_hash + + +def create_git_snapshot(run_id: str) -> tuple[str, str]: """Create a git snapshot branch with current changes. Creates a timestamped branch containing all current changes (staged and unstaged). 
Uses a @@ -44,13 +68,12 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: Raises: subprocess.CalledProcessError: If git commands fail (except for push) """ - # Generate timestamped branch name - timestamp_utc = datetime.datetime.now(datetime.UTC).strftime("%Y%m%d-%H%M%S") - snapshot_branch = f"{branch_name_prefix}-{timestamp_utc}" + # prefix branch name + snapshot_branch: str = f"snapshot/{run_id}" # Create temporary worktree path with tempfile.TemporaryDirectory() as temp_dir: - worktree_path = Path(temp_dir) / f"spd-snapshot-{timestamp_utc}" + worktree_path = Path(temp_dir) / f"spd-snapshot-{run_id}" try: # Create worktree with new branch @@ -87,7 +110,7 @@ def create_git_snapshot(branch_name_prefix: str) -> tuple[str, str]: # Commit changes if any exist if diff_result.returncode != 0: # Non-zero means there are changes subprocess.run( - ["git", "commit", "-m", f"Sweep snapshot {timestamp_utc}", "--no-verify"], + ["git", "commit", "-m", f"run id {run_id}", "--no-verify"], cwd=worktree_path, check=True, capture_output=True, diff --git a/spd/utils/run_utils.py b/spd/utils/run_utils.py index 89479f46f..db51affe3 100644 --- a/spd/utils/run_utils.py +++ b/spd/utils/run_utils.py @@ -6,13 +6,20 @@ import secrets import string from pathlib import Path -from typing import Any +from typing import Any, Final, Literal, NamedTuple import torch import wandb import yaml +from spd.log import logger from spd.settings import SPD_CACHE_DIR +from spd.utils.git_utils import ( + create_git_snapshot, + repo_current_branch, + repo_current_commit_hash, + repo_is_clean, +) # Fields that use discriminated union merging: field_name -> discriminator_field _DISCRIMINATED_LIST_FIELDS: dict[str, str] = { @@ -37,6 +44,7 @@ def get_local_run_id() -> str: return f"local-{random_suffix}" +# TODO: avoid using this function? def get_output_dir(use_wandb_id: bool = True) -> Path: """Get the output directory for a run. 
@@ -465,3 +473,89 @@ def generate_run_name(params: dict[str, Any]) -> str: parts.append(f"{param}-{value}") return "-".join(parts) + + +RunType = Literal["spd", "cluster", "ensemble"] + +RUN_TYPE_ABBREVIATIONS: Final[dict[RunType, str]] = { + "spd": "s", + "cluster": "c", + "ensemble": "e", +} + + +# TODO: This doesnt work in pytest but would in general be nice to enforce. hmm. +# _CREATED_RUN_ID: bool = False + + +class ExecutionStamp(NamedTuple): + run_id: str + snapshot_branch: str + commit_hash: str + run_type: RunType + + @staticmethod + def _generate_run_id(run_type: RunType) -> str: + """Generate a unique run identifier, + + Format: `{type_abbr}-{random_hex}` + """ + # global _CREATED_RUN_ID + # if _CREATED_RUN_ID: + # raise RuntimeError( + # "Run ID has already been generated for this process! You can only call this once." + # ) + type_abbr: str = RUN_TYPE_ABBREVIATIONS[run_type] + random_hex: str = secrets.token_hex(4) + # _CREATED_RUN_ID = True + return f"{type_abbr}-{random_hex}" + + @classmethod + def create( + cls, + run_type: RunType, + create_snapshot: bool, + ) -> "ExecutionStamp": + """create an execution stamp, possibly including a git snapshot branch""" + + run_id: str = ExecutionStamp._generate_run_id(run_type) + snapshot_branch: str + commit_hash: str + + if create_snapshot: + snapshot_branch, commit_hash = create_git_snapshot(run_id=run_id) + logger.info(f"Created git snapshot branch: {snapshot_branch} ({commit_hash[:8]})") + else: + snapshot_branch = repo_current_branch() + if repo_is_clean(catch_except_as_false=True): + commit_hash = repo_current_commit_hash() + logger.info(f"Using current branch: {snapshot_branch} ({commit_hash[:8]})") + else: + commit_hash = "none" + logger.info( + f"Using current branch: {snapshot_branch} (unpushed changes, no commit hash)" + ) + + return ExecutionStamp( + run_id=run_id, + snapshot_branch=snapshot_branch, + commit_hash=commit_hash, + run_type=run_type, + ) + + @property + def out_dir(self) -> Path: + 
"""Get the output directory for this execution stamp.""" + run_dir = SPD_CACHE_DIR / self.run_type / self.run_id + run_dir.mkdir(parents=True, exist_ok=True) + return run_dir + + +_NO_ARG_PARSSED_SENTINEL = object() + + +def read_noneable_str(value: str) -> str | None: + """Read a string that may be 'None' and convert to None.""" + if value == "None": + return None + return value diff --git a/spd/utils/slurm_utils.py b/spd/utils/slurm_utils.py index a5a9426ef..b9290061f 100644 --- a/spd/utils/slurm_utils.py +++ b/spd/utils/slurm_utils.py @@ -22,57 +22,42 @@ def format_runtime_str(runtime_minutes: int) -> str: return f"{hours}h{minutes}m" if hours > 0 else f"{minutes}m" -def create_slurm_array_script( +def _create_slurm_script_base( script_path: Path, job_name: str, - commands: list[str], snapshot_branch: str, - n_gpus_per_job: int, + n_gpus: int, partition: str, - time_limit: str = "72:00:00", - max_concurrent_tasks: int | None = None, + time_limit: str, + sbatch_directives: str, + work_dir_suffix: str, + command_block: str, ) -> None: - """Create a SLURM job array script with git snapshot for consistent code. + """Create a SLURM script with git snapshot for consistent code. Args: script_path: Path where the script should be written - job_name: Name for the SLURM job array - commands: List of commands to execute in each array job - snapshot_branch: Git branch to checkout. - n_gpus_per_job: Number of GPUs per job. If 0, use CPU jobs. - time_limit: Time limit for each job (default: 72:00:00) - max_concurrent_tasks: Maximum number of array tasks to run concurrently. If None, no limit. + job_name: Name for the SLURM job + snapshot_branch: Git branch to checkout + n_gpus: Number of GPUs. If 0, use CPU jobs. + partition: SLURM partition to use + time_limit: Time limit for the job + sbatch_directives: Additional SBATCH directives (e.g. --array, --dependency, --output) + work_dir_suffix: Suffix for the working directory (e.g. 
"${SLURM_JOB_ID}" or "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}") + command_block: The command(s) to execute """ - - slurm_logs_dir = Path.home() / "slurm_logs" - slurm_logs_dir.mkdir(exist_ok=True) - - # Create array range (SLURM arrays are 1-indexed) - if max_concurrent_tasks is not None: - array_range = f"1-{len(commands)}%{max_concurrent_tasks}" - else: - array_range = f"1-{len(commands)}" - - # Create case statement for commands - case_statements = [] - for i, command in enumerate(commands, 1): - case_statements.append(f"{i}) {command} ;;") - - case_block = "\n ".join(case_statements) - script_content = textwrap.dedent(f""" #!/bin/bash #SBATCH --nodes=1 - #SBATCH --gres=gpu:{n_gpus_per_job} + #SBATCH --gres=gpu:{n_gpus} #SBATCH --partition={partition} #SBATCH --time={time_limit} #SBATCH --job-name={job_name} - #SBATCH --array={array_range} + {sbatch_directives} #SBATCH --distribution=pack - #SBATCH --output={slurm_logs_dir}/slurm-%A_%a.out # Create job-specific working directory - WORK_DIR="/tmp/spd-gf-copy-${{SLURM_ARRAY_JOB_ID}}_${{SLURM_ARRAY_TASK_ID}}" + WORK_DIR="/tmp/spd-gf-copy-{work_dir_suffix}" # Clone the repository to the job-specific directory git clone {REPO_ROOT} $WORK_DIR @@ -93,38 +78,123 @@ def create_slurm_array_script( uv sync --no-dev --link-mode copy -q source .venv/bin/activate - # Execute the appropriate command based on array task ID - case $SLURM_ARRAY_TASK_ID in - {case_block} - esac + {command_block} """).strip() with open(script_path, "w") as f: f.write(script_content) - # Make script executable script_path.chmod(0o755) -def submit_slurm_array(script_path: Path) -> str: - """Submit a SLURM job array and return the array job ID. 
+def create_slurm_array_script( + script_path: Path, + job_name: str, + commands: list[str], + snapshot_branch: str, + n_gpus_per_job: int, + partition: str, + time_limit: str = "72:00:00", + max_concurrent_tasks: int | None = None, +) -> None: + """Create a SLURM job array script with git snapshot for consistent code. Args: - script_path: Path to SLURM batch script + script_path: Path where the script should be written + job_name: Name for the SLURM job array + commands: List of shell-safe command strings (built with shlex.join()) + snapshot_branch: Git branch to checkout. + n_gpus_per_job: Number of GPUs per job. If 0, use CPU jobs. + partition: SLURM partition to use + time_limit: Time limit for each job (default: 72:00:00) + max_concurrent_tasks: Maximum number of array tasks to run concurrently. If None, no limit. + """ + slurm_logs_dir = Path.home() / "slurm_logs" + slurm_logs_dir.mkdir(exist_ok=True) - Returns: - Array job ID from submitted job array + # Create array range (SLURM arrays are 1-indexed) + if max_concurrent_tasks is not None: + array_range = f"1-{len(commands)}%{max_concurrent_tasks}" + else: + array_range = f"1-{len(commands)}" + + # Create case statement for commands + case_statements = [] + for i, cmd in enumerate(commands, 1): + case_statements.append(f"{i}) {cmd} ;;") + + case_block = "\n ".join(case_statements) + + sbatch_directives = f"""#SBATCH --array={array_range} + #SBATCH --output={slurm_logs_dir}/slurm-%A_%a.out""" + + command_block = f"""# Execute the appropriate command based on array task ID + case $SLURM_ARRAY_TASK_ID in + {case_block} + esac""" + + _create_slurm_script_base( + script_path=script_path, + job_name=job_name, + snapshot_branch=snapshot_branch, + n_gpus=n_gpus_per_job, + partition=partition, + time_limit=time_limit, + sbatch_directives=sbatch_directives, + work_dir_suffix="${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}", + command_block=command_block, + ) + + +def create_slurm_script( + script_path: Path, + job_name: 
str, + command: str, + snapshot_branch: str, + n_gpus: int, + partition: str, + time_limit: str = "72:00:00", + dependency_job_id: str | None = None, +) -> None: + """Create a SLURM job script with git snapshot for consistent code. + + Args: + script_path: Path where the script should be written + job_name: Name for the SLURM job + command: Shell-safe command string (built with shlex.join()) + snapshot_branch: Git branch to checkout + n_gpus: Number of GPUs. If 0, use CPU job. + partition: SLURM partition to use + time_limit: Time limit for the job (default: 72:00:00) + dependency_job_id: Optional job ID to depend on (uses afterok) """ - result = subprocess.run( - ["sbatch", str(script_path)], capture_output=True, text=True, check=True + slurm_logs_dir = Path.home() / "slurm_logs" + slurm_logs_dir.mkdir(exist_ok=True) + + # Build SBATCH directives + directives = [f"#SBATCH --output={slurm_logs_dir}/slurm-%j.out"] + if dependency_job_id is not None: + directives.append(f"#SBATCH --dependency=afterok:{dependency_job_id}") + + sbatch_directives = "\n ".join(directives) + + command_block = f"# Execute the command\n {command}" + + _create_slurm_script_base( + script_path=script_path, + job_name=job_name, + snapshot_branch=snapshot_branch, + n_gpus=n_gpus, + partition=partition, + time_limit=time_limit, + sbatch_directives=sbatch_directives, + work_dir_suffix="${SLURM_JOB_ID}", + command_block=command_block, ) - # Extract job ID from sbatch output (format: "Submitted batch job 12345") - job_id = result.stdout.strip().split()[-1] - return job_id -def submit_slurm_job(script_path: Path) -> str: - """Submit a SLURM job and return the job ID. +def submit_slurm_script(script_path: Path) -> str: + """Submit a SLURM job (array or single) and return the job ID. 
Args: script_path: Path to SLURM batch script diff --git a/spd/utils/wandb_utils.py b/spd/utils/wandb_utils.py index e2cffeec5..d883785d5 100644 --- a/spd/utils/wandb_utils.py +++ b/spd/utils/wandb_utils.py @@ -7,9 +7,9 @@ import wandb_workspaces.reports.v2 as wr import wandb_workspaces.workspaces as ws from dotenv import load_dotenv -from pydantic import BaseModel from wandb.apis.public import File, Run +from spd.base_config import BaseConfig from spd.log import logger from spd.registry import EXPERIMENT_REGISTRY from spd.settings import REPO_ROOT @@ -137,7 +137,7 @@ def download_wandb_file(run: Run, wandb_run_dir: Path, file_name: str) -> Path: return path -def init_wandb[T_config: BaseModel]( +def init_wandb[T_config: BaseConfig]( config: T_config, project: str, name: str | None = None, tags: list[str] | None = None ) -> T_config: """Initialize Weights & Biases and return a config updated with sweep hyperparameters. @@ -153,6 +153,7 @@ def init_wandb[T_config: BaseModel]( """ load_dotenv(override=True) + # TODO: pass run id from ExecutionStamp wandb.init( project=project, entity=os.getenv("WANDB_ENTITY"), diff --git a/tests/clustering/math/test_perm_invariant_hamming.py b/tests/clustering/math/test_perm_invariant_hamming.py new file mode 100644 index 000000000..7d2bf4740 --- /dev/null +++ b/tests/clustering/math/test_perm_invariant_hamming.py @@ -0,0 +1,123 @@ +from itertools import permutations + +import numpy as np +import pytest + +from spd.clustering.math.perm_invariant_hamming import perm_invariant_hamming_matrix + +# pyright complains about the types when calling perm_invariant_hamming +# pyright: reportCallIssue=false + + +def brute_force_min_hamming(a: np.ndarray, b: np.ndarray) -> int: + """Exhaustive check for small k.""" + k = int(max(a.max(), b.max()) + 1) + best = len(a) + for perm in permutations(range(k)): + mapping = np.array(perm) + best = min(best, int((mapping[a] != b).sum())) + return best + + +def test_identity() -> None: + """a == b should 
give distance 0.""" + a = np.array([0, 1, 2, 1, 0]) + b = a.copy() + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + # Distance between row 1 and row 0 should be 0 + assert D[1, 0] == 0 + + +def test_all_one_group() -> None: + """All rows belong to one group in both arrays (possibly different labels).""" + a = np.zeros(10, dtype=int) + b = np.ones(10, dtype=int) # different label but identical grouping + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def test_permuted_labels() -> None: + a = np.array([0, 2, 1, 1, 0]) + b = np.array([1, 0, 0, 2, 1]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 1 + + +def test_swap_two_labels() -> None: + a = np.array([0, 0, 1, 1]) + b = np.array([1, 1, 0, 0]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 + + +def test_random_small_bruteforce() -> None: + rng = np.random.default_rng(0) + for _ in range(50): + n = 7 + k = 3 + a = rng.integers(0, k, size=n) + b = rng.integers(0, k, size=n) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + d_alg = D[1, 0] + d_true = brute_force_min_hamming(a, b) + assert d_alg == d_true + + +def test_shape_mismatch() -> None: + a = np.array([0, 1, 2]) + b = np.array([0, 1]) + with pytest.raises((ValueError, IndexError)): + # This should fail when trying to create the matrix due to shape mismatch + X = np.array([a, b]) + perm_invariant_hamming_matrix(X) + + +def test_matrix_multiple_pairs() -> None: + """Test the matrix function with multiple label vectors.""" + a = np.array([0, 0, 1, 1]) + b = np.array([2, 2, 3, 3]) # Should be distance 0 (perfect mapping) + c = np.array([0, 1, 0, 1]) # Should be distance 2 from both a and b + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + assert D[1, 0] == 0 # a and b should have distance 0 + assert D[2, 0] == 2 # a and c should have distance 2 + assert D[2, 1] == 2 # b and c should have distance 2 + + +def 
test_matrix_upper_triangle_nan() -> None: + """Test that upper triangle and diagonal are NaN.""" + a = np.array([0, 1, 0]) + b = np.array([1, 0, 1]) + c = np.array([0, 0, 1]) + X = np.array([a, b, c]) + D = perm_invariant_hamming_matrix(X) + + # Diagonal should be NaN + assert np.isnan(D[0, 0]) + assert np.isnan(D[1, 1]) + assert np.isnan(D[2, 2]) + + # Upper triangle should be NaN + assert np.isnan(D[0, 1]) + assert np.isnan(D[0, 2]) + assert np.isnan(D[1, 2]) + + # Lower triangle should have actual distances + assert not np.isnan(D[1, 0]) + assert not np.isnan(D[2, 0]) + assert not np.isnan(D[2, 1]) + + +def test_unused_labels() -> None: + """Test when arrays don't use all labels 0..k-1.""" + a = np.array([0, 0, 3, 3]) # skips 1, 2 + b = np.array([1, 1, 2, 2]) + X = np.array([a, b]) + D = perm_invariant_hamming_matrix(X) + assert D[1, 0] == 0 diff --git a/tests/clustering/scripts/cluster_resid_mlp.py b/tests/clustering/scripts/cluster_resid_mlp.py new file mode 100644 index 000000000..bbfb5259e --- /dev/null +++ b/tests/clustering/scripts/cluster_resid_mlp.py @@ -0,0 +1,192 @@ +# %% +from pathlib import Path +from typing import Any + +import matplotlib.pyplot as plt +import torch +from muutils.dbg import dbg_auto +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import ( + plot_dists_distribution, + plot_merge_iteration, +) +from spd.configs import Config +from spd.experiments.resid_mlp.resid_mlp_dataset import ResidMLPDataset +from spd.models.component_model import ComponentModel, SPDRunInfo +from spd.registry import EXPERIMENT_REGISTRY +from 
spd.utils.data_utils import DatasetGeneratedDataLoader + +DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" +TEMP_DIR: Path = Path( + "tests/.temp" +) # save to an actual dir that is gitignored, so users can view plots +TEMP_DIR.mkdir(parents=True, exist_ok=True) + + +# pyright: reportUnusedParameter=false + +# magic autoreload +# %load_ext autoreload +# %autoreload 2 + +# %% +# Load model +# ============================================================ +_CANONICAL_RUN: str | None = EXPERIMENT_REGISTRY["resid_mlp2"].canonical_run +assert _CANONICAL_RUN is not None, "No canonical run found for resid_mlp2 experiment" +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(_CANONICAL_RUN) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG: Config = SPD_RUN.config + +# %% +# Setup dataset and dataloader +# ============================================================ +N_SAMPLES: int = 128 + +DATASET: ResidMLPDataset = ResidMLPDataset( + n_features=MODEL.target_model.config.n_features, # pyright: ignore[reportAttributeAccessIssue, reportArgumentType], + feature_probability=SPD_CONFIG.task_config.feature_probability, # pyright: ignore[reportAttributeAccessIssue] + device=DEVICE, + calc_labels=False, + label_type=None, + act_fn_name=None, + label_fn_seed=None, + label_coeffs=None, + data_generation_type=SPD_CONFIG.task_config.data_generation_type, # pyright: ignore[reportAttributeAccessIssue] +) + +dbg_auto( + dict( + n_features=DATASET.n_features, + feature_probability=DATASET.feature_probability, + data_generation_type=DATASET.data_generation_type, + ) +) +DATALOADER = DatasetGeneratedDataLoader(DATASET, batch_size=N_SAMPLES, shuffle=False) + +# %% +# Get component activations +# ============================================================ +# Get a single batch from the dataloader +BATCH_DATA: tuple[Tensor, Tensor] = next(iter(DATALOADER)) +BATCH: Tensor = BATCH_DATA[0] + +COMPONENT_ACTS: dict[str, Tensor] = 
def _plot_func(
    current_coact: torch.Tensor,
    component_labels: ComponentLabels,
    current_merge: Any,
    costs: torch.Tensor,
    merge_history: MergeHistory,
    iter_idx: int,
    k_groups: int,
    merge_pair_cost: float,
    mdl_loss: float,
    mdl_loss_norm: float,
    diag_acts: torch.Tensor,
) -> None:
    """Log callback: render the merge state on iteration 1 and every 50th iteration."""
    should_plot: bool = iter_idx == 1 or (iter_idx > 0 and iter_idx % 50 == 0)
    if not should_plot:
        return
    plot_merge_iteration(
        current_merge=current_merge,
        current_coact=current_coact,
        costs=costs,
        iteration=iter_idx,
        component_labels=component_labels,
        show=True,  # Show the plot interactively
    )
+# Modern approach: run merge_iteration multiple times to create ensemble +ENSEMBLE_SIZE: int = 4 +HISTORIES: list[MergeHistory] = [] +for _i in range(ENSEMBLE_SIZE): + HISTORY: MergeHistory = merge_iteration( + merge_config=MERGE_CFG, + activations=PROCESSED_ACTIVATIONS.activations, + component_labels=PROCESSED_ACTIVATIONS.labels, + log_callback=None, + ) + HISTORIES.append(HISTORY) + +ENSEMBLE: MergeHistoryEnsemble = MergeHistoryEnsemble(data=HISTORIES) + +DISTANCES = ENSEMBLE.get_distances(method="perm_invariant_hamming") + +plot_dists_distribution( + distances=DISTANCES, + mode="points", + # label="v1" +) +plt.legend() diff --git a/tests/clustering/scripts/cluster_ss.py b/tests/clustering/scripts/cluster_ss.py new file mode 100644 index 000000000..0b7f8de97 --- /dev/null +++ b/tests/clustering/scripts/cluster_ss.py @@ -0,0 +1,133 @@ +# %% +import os + +# Suppress tokenizer parallelism warning when forking +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +from pathlib import Path + +import torch +from jaxtyping import Int +from muutils.dbg import dbg_auto +from torch import Tensor + +from spd.clustering.activations import ( + ProcessedActivations, + component_activations, + process_activations, +) +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.dataset import load_dataset +from spd.clustering.merge import merge_iteration +from spd.clustering.merge_config import MergeConfig +from spd.clustering.merge_history import MergeHistory, MergeHistoryEnsemble +from spd.clustering.plotting.activations import plot_activations +from spd.clustering.plotting.merge import plot_dists_distribution +from spd.models.component_model import ComponentModel, SPDRunInfo + +DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu" +TEMP_DIR: Path = Path( + "tests/.temp" +) # save to an actual dir that is gitignored, so users can view plots +TEMP_DIR.mkdir(parents=True, exist_ok=True) + +# magic autoreload +# %load_ext autoreload +# 
%autoreload 2 + +# %% +# Load model and dataset +# ============================================================ +MODEL_PATH: str = "wandb:goodfire/spd-pre-Sep-2025/runs/ioprgffh" + +SPD_RUN: SPDRunInfo = SPDRunInfo.from_path(MODEL_PATH) +MODEL: ComponentModel = ComponentModel.from_pretrained(SPD_RUN.checkpoint_path) +MODEL.to(DEVICE) +SPD_CONFIG = SPD_RUN.config + +# Use load_dataset with RunConfig to get real data +CONFIG: ClusteringRunConfig = ClusteringRunConfig( + merge_config=MergeConfig(), + model_path=MODEL_PATH, + batch_size=2, + dataset_seed=42, + dataset_streaming=True, # no effect since we do this manually +) + +DATA_BATCH: Int[Tensor, "batch_size n_ctx"] = load_dataset( + model_path=MODEL_PATH, + task_name="lm", + batch_size=CONFIG.batch_size, + seed=CONFIG.dataset_seed, + # config=CONFIG, + config_kwargs=dict(streaming=True), # see https://github.com/goodfire-ai/spd/pull/199 +) + + +# %% +# Get component activations +# ============================================================ +COMPONENT_ACTS: dict[str, Tensor] = component_activations( + model=MODEL, + batch=DATA_BATCH, + device=DEVICE, +) + +_ = dbg_auto(COMPONENT_ACTS) +# %% +# Process activations +# ============================================================ +FILTER_DEAD_THRESHOLD: float = 0.001 +FILTER_MODULES: str = "model.layers.0" + +PROCESSED_ACTIVATIONS: ProcessedActivations = process_activations( + activations=COMPONENT_ACTS, + filter_dead_threshold=FILTER_DEAD_THRESHOLD, + filter_modules=lambda x: x.startswith(FILTER_MODULES), + seq_mode="concat", +) + +plot_activations( + processed_activations=PROCESSED_ACTIVATIONS, + save_dir=TEMP_DIR, + n_samples_max=256, + wandb_run=None, +) + +# %% +# Compute ensemble merge iterations +# ============================================================ +MERGE_CFG: MergeConfig = MergeConfig( + activation_threshold=0.01, + alpha=0.01, + iters=2, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + 
def test_merge_history_normalization_happy_path():
    """Smoke-test ensemble normalization: it should run end to end without errors."""
    config = MergeConfig(
        iters=3,
        alpha=1.0,
        activation_threshold=None,
    )

    # Two fresh histories over the same four component labels
    histories = [
        MergeHistory.from_config(
            merge_config=config,
            labels=ComponentLabels([f"comp{j}" for j in range(4)]),
        )
        for _ in range(2)
    ]

    ensemble = MergeHistoryEnsemble(data=histories)
    assert len(ensemble.data) == 2

    # Normalization must produce both an array and its metadata
    normalized_array, metadata = ensemble.normalized()
    assert normalized_array is not None
    assert metadata is not None
def _run_and_check(cmd: list[str], what: str) -> None:
    """Run *cmd* as a subprocess; on failure, dump its output and fail the test.

    Extracted helper: the four tests below previously duplicated this
    run/print/assert boilerplate verbatim.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        # Surface the child's output so CI logs show why it failed.
        print(f"stdout: {result.stdout}")
        print(f"stderr: {result.stderr}")
    assert result.returncode == 0, f"{what} failed with return code {result.returncode}"


@pytest.mark.slow
def test_cluster_resid_mlp_notebook():
    """Test running the cluster_resid_mlp.py notebook-style script."""
    script_path = NOTEBOOK_DIR / "cluster_resid_mlp.py"
    assert script_path.exists(), f"Script not found: {script_path}"
    # Run the script as-is
    _run_and_check([sys.executable, str(script_path)], "Script")


@pytest.mark.slow
def test_clustering_with_resid_mlp1_config():
    """Test running clustering with test-resid_mlp1.json config."""
    config_path = CONFIG_DIR / "pipeline-test-resid_mlp1.yaml"
    assert config_path.exists(), f"Config not found: {config_path}"
    _run_and_check(
        [
            "spd-cluster",
            "--config",
            str(config_path),
            "--local",  # don't assume we have slurm in the test env
        ],
        "Clustering",
    )


@pytest.mark.slow
def test_cluster_ss_notebook():
    """Test running the cluster_ss.py notebook-style script."""
    script_path = NOTEBOOK_DIR / "cluster_ss.py"
    assert script_path.exists(), f"Script not found: {script_path}"
    # Run the script as-is
    _run_and_check([sys.executable, str(script_path)], "Script")


@pytest.mark.slow
def test_clustering_with_simplestories_config():
    """Test running clustering with test-simplestories.json config."""
    config_path = CONFIG_DIR / "pipeline-test-simplestories.yaml"
    assert config_path.exists(), f"Config not found: {config_path}"
    _run_and_check(
        [
            "spd-cluster",
            "--config",
            str(config_path),
            "--dataset-streaming",  # see https://github.com/goodfire-ai/spd/pull/199
            "--local",  # don't assume we have slurm in the test env
        ],
        "Clustering",
    )
class TestRegisterClusteringRun:
    """Test register_clustering_run() function."""

    def test_register_single_run(self, _temp_registry_db: Any):
        """A single registration gets index 0 and is persisted."""
        pipeline_id = "pipeline_001"
        run_id = "run_001"

        index = register_clustering_run(pipeline_id, run_id)
        assert index == 0

        # The run should be visible through the query API as well
        assert get_clustering_runs(pipeline_id) == [(0, "run_001")]

    def test_register_multiple_runs(self, _temp_registry_db: Any):
        """Sequential registrations receive consecutive auto-assigned indices."""
        pipeline_id = "pipeline_002"

        indices = [
            register_clustering_run(pipeline_id, run_id)
            for run_id in ("run_001", "run_002", "run_003")
        ]
        assert indices == [0, 1, 2]

        assert get_clustering_runs(pipeline_id) == [
            (0, "run_001"),
            (1, "run_002"),
            (2, "run_003"),
        ]

    def test_different_pipelines_independent(self, _temp_registry_db: Any):
        """Index sequences of distinct pipelines do not interfere."""
        pipeline_a = "pipeline_a"
        pipeline_b = "pipeline_b"

        # Interleave registrations: both pipelines start at 0 ...
        assert register_clustering_run(pipeline_a, "run_a1") == 0
        assert register_clustering_run(pipeline_b, "run_b1") == 0
        # ... and both increment independently
        assert register_clustering_run(pipeline_a, "run_a2") == 1
        assert register_clustering_run(pipeline_b, "run_b2") == 1

        assert get_clustering_runs(pipeline_a) == [(0, "run_a1"), (1, "run_a2")]
        assert get_clustering_runs(pipeline_b) == [(0, "run_b1"), (1, "run_b2")]
@pytest.mark.parametrize(
    "max_values,threshold,expected_alive_indices",
    [
        # No filtering when threshold is 0
        ([0.1, 0.2, 0.3], 0.0, [0, 1, 2]),
        # Filter all when all below threshold
        ([0.005, 0.003, 0.004], 0.01, []),
        # Filter some components
        ([0.0, 0.02, 0.0, 0.03, 0.0], 0.01, [1, 3]),
        # Boundary cases: at threshold is kept
        ([0.009, 0.01, 0.011], 0.01, [1, 2]),
        # High threshold filters everything
        ([0.1, 0.2, 0.3], 2.0, []),
        # Negative threshold filters nothing
        ([0.1, 0.2, 0.3], -0.01, [0, 1, 2]),
        # Single component above threshold
        ([0.5], 0.01, [0]),
    ],
)
def test_filter_dead_components_thresholds(
    max_values: list[float],
    threshold: float,
    expected_alive_indices: list[int],
) -> None:
    """Test filtering with various max values and thresholds."""
    n_steps: int = 10
    n_components: int = len(max_values)

    activations: Tensor
    labels: ComponentLabels
    if n_components == 0:
        activations = torch.zeros(n_steps, 0)
        labels = ComponentLabels([])
    else:
        # Row 0 carries each component's peak value; all other steps stay zero.
        activations = torch.zeros(n_steps, n_components)
        activations[0, :] = torch.tensor(max_values)
        labels = ComponentLabels([f"comp_{i}" for i in range(n_components)])

    result: FilteredActivations = filter_dead_components(
        activations=activations, labels=labels, filter_dead_threshold=threshold
    )

    n_alive_expected: int = len(expected_alive_indices)
    assert result.labels == [f"comp_{i}" for i in expected_alive_indices]
    assert result.n_alive == n_alive_expected
    assert result.n_dead == n_components - n_alive_expected
    assert result.activations.shape == (n_steps, n_alive_expected)

    if threshold <= 0 or all(v >= threshold for v in max_values):
        # No filtering occurred, so there are no dead-component labels
        assert result.dead_components_labels is None or result.dead_components_labels == []
    else:
        expected_dead: set[str] = {
            f"comp_{i}" for i in range(n_components) if i not in expected_alive_indices
        }
        assert result.dead_components_labels is not None
        assert set(result.dead_components_labels) == expected_dead
@pytest.mark.parametrize("threshold", [0.001, 0.01, 0.1, 0.5])
def test_linear_gradient_thresholds(threshold: float) -> None:
    """Test with linearly spaced activation values."""
    n_steps: int = 10
    n_components: int = 10

    # Peak values 0.0, 0.1, ..., 0.9, all stored in the first time step.
    # NOTE: per-element Python-float multiplication is kept on purpose so the
    # tensor values match the `i * 0.1` arithmetic in the expected count below.
    activations: Tensor = torch.zeros(n_steps, n_components)
    for idx in range(n_components):
        activations[0, idx] = idx * 0.1

    component_names: list[str] = [f"comp_{i}" for i in range(n_components)]

    result: FilteredActivations = filter_dead_components(
        activations=activations,
        labels=ComponentLabels(component_names),
        filter_dead_threshold=threshold,
    )

    # Components whose peak is >= threshold survive.
    expected_alive: int = sum(1 for i in range(n_components) if i * 0.1 >= threshold)

    assert result.n_alive == expected_alive
    assert result.n_dead == n_components - expected_alive
spd.clustering.merge_config import MergeConfig + + +class TestMergeConfigSampling: + """Test MergeConfig integration with sampling system.""" + + def test_default_config(self): + """Test default MergeConfig uses range sampler.""" + config = MergeConfig() + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.05} + + def test_range_sampler_config(self): + """Test MergeConfig with range sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + + assert config.merge_pair_sampling_method == "range" + assert config.merge_pair_sampling_kwargs == {"threshold": 0.1} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + + def test_mcmc_sampler_config(self): + """Test MergeConfig with MCMC sampler.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 2.0} + ) + + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 2.0} + + # Test that sampler works + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + + def test_invalid_sampler_method(self): + """Test that invalid sampler method raises error.""" + from pydantic import ValidationError + + # Pydantic validates at construction time + with pytest.raises(ValidationError): + _config = MergeConfig(merge_pair_sampling_method="invalid") # pyright: ignore[reportArgumentType] + + def test_config_with_all_parameters(self): + """Test MergeConfig with all parameters set.""" + config = MergeConfig( + 
activation_threshold=0.01, + alpha=1.5, + iters=200, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.5}, + filter_dead_threshold=0.001, + module_name_filter="model.layers", + ) + + assert config.activation_threshold == 0.01 + assert config.alpha == 1.5 + assert config.iters == 200 + assert config.merge_pair_sampling_method == "mcmc" + assert config.merge_pair_sampling_kwargs == {"temperature": 0.5} + assert config.filter_dead_threshold == 0.001 + assert config.module_name_filter == "model.layers" + + def test_config_serialization(self): + """Test that config can be serialized and deserialized.""" + config = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.5} + ) + + # Serialize to dict + config_dict = config.model_dump() + assert config_dict["merge_pair_sampling_method"] == "mcmc" + assert config_dict["merge_pair_sampling_kwargs"] == {"temperature": 1.5} + + # Deserialize from dict + config2 = MergeConfig(**config_dict) + assert config2.merge_pair_sampling_method == "mcmc" + assert config2.merge_pair_sampling_kwargs == {"temperature": 1.5} + + def test_config_json_serialization(self): + """Test JSON serialization of config.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + + # Serialize to JSON string + json_str = config.model_dump_json() + assert "range" in json_str + assert "0.2" in json_str + + # Parse back from JSON + import json + + config_dict = json.loads(json_str) + config2 = MergeConfig(**config_dict) + + assert config2.merge_pair_sampling_method == "range" + assert config2.merge_pair_sampling_kwargs == {"threshold": 0.2} + + def test_stable_hash_changes_with_sampling_params(self): + """Test that stable_hash changes when sampling parameters change.""" + config1 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + config2 = MergeConfig( + 
merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.2} + ) + config3 = MergeConfig( + merge_pair_sampling_method="mcmc", merge_pair_sampling_kwargs={"temperature": 1.0} + ) + + # Different configs should have different hashes + assert config1.stable_hash != config2.stable_hash + assert config1.stable_hash != config3.stable_hash + assert config2.stable_hash != config3.stable_hash + + # Same config should have same hash + config4 = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.1} + ) + assert config1.stable_hash == config4.stable_hash + + def test_empty_kwargs(self): + """Test that empty kwargs dict works.""" + config = MergeConfig(merge_pair_sampling_method="range", merge_pair_sampling_kwargs={}) + + # Should work with default parameters of the sampler + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Range sampler has default threshold=0.05 + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] + + def test_extra_kwargs_filtered(self): + """Test that only valid kwargs are used by sampler.""" + config = MergeConfig( + merge_pair_sampling_method="range", merge_pair_sampling_kwargs={"threshold": 0.3} + ) + + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Should work with config's method + pair = config.merge_pair_sample(costs) + + assert isinstance(pair, tuple) + assert pair[0] != pair[1] diff --git a/tests/clustering/test_merge_integration.py b/tests/clustering/test_merge_integration.py new file mode 100644 index 000000000..8492300de --- /dev/null +++ b/tests/clustering/test_merge_integration.py @@ -0,0 +1,151 @@ +"""Integration tests for the merge system with new samplers.""" + +import torch + +from spd.clustering.consts import ComponentLabels +from spd.clustering.merge import merge_iteration +from 
spd.clustering.merge_config import MergeConfig + + +class TestMergeIntegration: + """Test the full merge iteration with different samplers.""" + + def test_merge_with_range_sampler(self): + """Test merge iteration with range sampler.""" + # Create test data + n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with range sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.1}, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # Check results + assert history is not None + assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # After iterations, should have fewer groups (merges reduce count) + # Exact count depends on early stopping conditions + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 # Should stop before going below 2 + + def test_merge_with_mcmc_sampler(self): + """Test merge iteration with MCMC sampler.""" + # Create test data + n_samples = 100 + n_components = 10 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Configure with MCMC sampler + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=5, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 1.0}, + filter_dead_threshold=0.001, + ) + + # Run merge iteration + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # Check results + assert history is not None + 
assert len(history.merges.k_groups) > 0 + # First entry is after first merge, so should be n_components - 1 + assert history.merges.k_groups[0].item() == n_components - 1 + # Should have fewer groups after iterations + assert history.merges.k_groups[-1].item() < n_components + assert history.merges.k_groups[-1].item() >= 2 + + def test_merge_comparison_samplers(self): + """Compare behavior of different samplers with same data.""" + # Create test data with clear structure + n_samples = 100 + n_components = 8 + activations = torch.rand(n_samples, n_components) + + # Make some components more active to create cost structure + activations[:, 0] *= 2 # Component 0 is very active + activations[:, 1] *= 0.1 # Component 1 is rarely active + + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + # Run with range sampler (threshold=0 for deterministic minimum selection) + config_range = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.0}, # Always select minimum + ) + + history_range = merge_iteration( + activations=activations.clone(), + merge_config=config_range, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Run with MCMC sampler (low temperature for near-deterministic) + config_mcmc = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=3, + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 0.01}, # Very low temp + ) + + history_mcmc = merge_iteration( + activations=activations.clone(), + merge_config=config_mcmc, + component_labels=ComponentLabels(component_labels.copy()), + ) + + # Both should reduce groups from initial count + assert history_range.merges.k_groups[-1].item() < n_components + assert history_mcmc.merges.k_groups[-1].item() < n_components + assert history_range.merges.k_groups[-1].item() >= 2 + assert history_mcmc.merges.k_groups[-1].item() >= 2 + + def 
test_merge_with_small_components(self): + """Test merge with very few components.""" + # Edge case: only 3 components + n_samples = 50 + n_components = 3 + activations = torch.rand(n_samples, n_components) + component_labels = ComponentLabels([f"comp_{i}" for i in range(n_components)]) + + config = MergeConfig( + activation_threshold=0.1, + alpha=1.0, + iters=1, # Just one merge + merge_pair_sampling_method="mcmc", + merge_pair_sampling_kwargs={"temperature": 2.0}, + ) + + history = merge_iteration( + activations=activations, merge_config=config, component_labels=component_labels + ) + + # First entry is after first merge, so should be 3 - 1 = 2 + assert history.merges.k_groups[0].item() == 2 + # Early stopping may occur at 2 groups, so final count could be 2 or 3 + assert history.merges.k_groups[-1].item() >= 2 + assert history.merges.k_groups[-1].item() <= 3 diff --git a/tests/clustering/test_merge_pair_samplers.py b/tests/clustering/test_merge_pair_samplers.py new file mode 100644 index 000000000..66c59cb66 --- /dev/null +++ b/tests/clustering/test_merge_pair_samplers.py @@ -0,0 +1,257 @@ +"""Tests for merge pair sampling functionality.""" + +import pytest +import torch + +from spd.clustering.math.merge_pair_samplers import ( + MERGE_PAIR_SAMPLERS, + mcmc_sampler, + range_sampler, +) + + +class TestRangeSampler: + """Test range-based merge pair sampling.""" + + def test_range_sampler_basic(self): + """Test basic functionality of range sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 # Make symmetric + costs.fill_diagonal_(float("inf")) # No self-merges + + # Test with different thresholds + pair_low = range_sampler(costs, threshold=0.0) + pair_mid = range_sampler(costs, threshold=0.5) + pair_high = range_sampler(costs, threshold=1.0) + + # All should return valid pairs + assert pair_low[0] != pair_low[1] + assert pair_mid[0] != pair_mid[1] + assert pair_high[0] != pair_high[1] + + # All indices should be in valid range + for pair 
in [pair_low, pair_mid, pair_high]: + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + def test_range_sampler_threshold_zero(self): + """Test that threshold=0 always selects minimum cost pair.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find the true minimum + min_val = float("inf") + _min_pair = None + for i in range(k): + for j in range(k): + if i != j and costs[i, j] < min_val: + min_val = costs[i, j].item() + _min_pair = (i, j) + + # Sample multiple times with threshold=0 + for _ in range(10): + pair = range_sampler(costs, threshold=0.0) + # Should always get the minimum (or its symmetric equivalent) + assert costs[pair[0], pair[1]] == min_val or costs[pair[1], pair[0]] == min_val + + def test_range_sampler_threshold_one(self): + """Test that threshold=1 can select any non-diagonal pair.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times to check we get different pairs + pairs_seen = set() + for _ in range(100): + pair = range_sampler(costs, threshold=1.0) + # Normalize pair order for comparison + normalized = tuple(sorted(pair)) + pairs_seen.add(normalized) + + # With threshold=1, we should see multiple different pairs + assert len(pairs_seen) > 1 + + def test_range_sampler_small_matrix(self): + """Test range sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 1.0], [1.0, float("inf")]]) + + pair = range_sampler(costs, threshold=0.5) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + +class TestMCMCSampler: + """Test MCMC-based merge pair sampling.""" + + def test_mcmc_sampler_basic(self): + """Test basic functionality of MCMC sampler.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test with different temperatures + pair_low_temp = mcmc_sampler(costs, temperature=0.1) + pair_mid_temp = mcmc_sampler(costs, 
temperature=1.0) + pair_high_temp = mcmc_sampler(costs, temperature=10.0) + + # All should return valid pairs + for pair in [pair_low_temp, pair_mid_temp, pair_high_temp]: + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + def test_mcmc_sampler_low_temperature(self): + """Test that low temperature favors low-cost pairs.""" + k = 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Find minimum cost + min_val = float("inf") + for i in range(k): + for j in range(k): + if i != j: + min_val = min(min_val, costs[i, j].item()) + + # Sample many times with very low temperature + low_cost_count = 0 + n_samples = 100 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=0.01) + cost = costs[pair[0], pair[1]].item() + # Check if it's close to minimum + if abs(cost - min_val) < 0.5: # Within 0.5 of minimum + low_cost_count += 1 + + # Most samples should be near minimum with low temperature + assert low_cost_count > n_samples * 0.7 + + def test_mcmc_sampler_high_temperature(self): + """Test that high temperature gives more uniform sampling.""" + k = 4 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Sample many times with high temperature + pairs_count = {} + n_samples = 1000 + for _ in range(n_samples): + pair = mcmc_sampler(costs, temperature=100.0) + # Normalize pair order for counting + normalized = tuple(sorted(pair)) + pairs_count[normalized] = pairs_count.get(normalized, 0) + 1 + + # With high temperature, distribution should be relatively uniform + # There are k*(k-1)/2 unique pairs + expected_count = n_samples / (k * (k - 1) / 2) + for count in pairs_count.values(): + # Each pair count should be within reasonable range of expected + assert expected_count * 0.3 < count < expected_count * 1.7 + + def test_mcmc_sampler_small_matrix(self): + """Test MCMC sampler with 2x2 matrix.""" + costs = torch.tensor([[float("inf"), 
1.0], [1.0, float("inf")]]) + + pair = mcmc_sampler(costs, temperature=1.0) + # Only valid pair is (0, 1) or (1, 0) + assert set(pair) == {0, 1} + + def test_mcmc_sampler_extreme_costs(self): + """Test MCMC sampler with extreme cost differences.""" + k = 3 + # Create matrix with one very low cost and rest high + costs = torch.full((k, k), 1000.0) + costs[0, 1] = costs[1, 0] = 1.0 # One low-cost pair + costs.fill_diagonal_(float("inf")) + + # With low temperature, should almost always select the low-cost pair + low_cost_selected = 0 + for _ in range(100): + pair = mcmc_sampler(costs, temperature=0.1) + if set(pair) == {0, 1}: + low_cost_selected += 1 + + assert low_cost_selected > 95 # Should almost always select (0,1) + + +class TestSamplerRegistry: + """Test the sampler registry.""" + + def test_registry_contains_samplers(self): + """Test that registry contains expected samplers.""" + assert "range" in MERGE_PAIR_SAMPLERS + assert "mcmc" in MERGE_PAIR_SAMPLERS + assert MERGE_PAIR_SAMPLERS["range"] is range_sampler + assert MERGE_PAIR_SAMPLERS["mcmc"] is mcmc_sampler + + def test_registry_samplers_callable(self): + """Test that all registry samplers are callable with correct signature.""" + k = 3 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + for name, sampler in MERGE_PAIR_SAMPLERS.items(): + # Should be callable + assert callable(sampler) + + # Test with default kwargs + if name == "range": + pair = sampler(costs, threshold=0.5) + elif name == "mcmc": + pair = sampler(costs, temperature=1.0) + else: + pytest.fail(f"Unknown sampler {name}") + + # Should return valid pair + assert isinstance(pair, tuple) + assert len(pair) == 2 + assert pair[0] != pair[1] + assert 0 <= pair[0] < k + assert 0 <= pair[1] < k + + +class TestSamplerIntegration: + """Integration tests for samplers with edge cases.""" + + def test_samplers_deterministic_with_seed(self): + """Test that samplers are deterministic with fixed seed.""" + k 
= 5 + costs = torch.randn(k, k) + costs = (costs + costs.T) / 2 + costs.fill_diagonal_(float("inf")) + + # Test range sampler + torch.manual_seed(42) + pair1 = range_sampler(costs, threshold=0.5) + torch.manual_seed(42) + pair2 = range_sampler(costs, threshold=0.5) + # Can't guarantee exact match due to Python's random module + # but both should be valid + assert pair1[0] != pair1[1] + assert pair2[0] != pair2[1] + + # Test MCMC sampler + torch.manual_seed(42) + pair1 = mcmc_sampler(costs, temperature=1.0) + torch.manual_seed(42) + pair2 = mcmc_sampler(costs, temperature=1.0) + assert pair1 == pair2 # Should be deterministic with same seed + + def test_samplers_all_infinite_costs(self): + """Test samplers handle all-infinite costs gracefully.""" + k = 3 + costs = torch.full((k, k), float("inf")) + + # This is an edge case - no valid pairs exist + # Samplers should handle this without crashing + # (though the result may not be meaningful) + with pytest.raises((ValueError, RuntimeError, IndexError)): + range_sampler(costs, threshold=0.5) diff --git a/tests/clustering/test_pipeline_config.py b/tests/clustering/test_pipeline_config.py new file mode 100644 index 000000000..ca6bad6ee --- /dev/null +++ b/tests/clustering/test_pipeline_config.py @@ -0,0 +1,137 @@ +"""Tests for ClusteringPipelineConfig and ClusteringRunConfig with inline config support.""" + +from pathlib import Path + +import pydantic_core +import pytest + +from spd.clustering.clustering_run_config import ClusteringRunConfig +from spd.clustering.merge_config import MergeConfig +from spd.clustering.scripts.run_pipeline import ClusteringPipelineConfig +from spd.settings import REPO_ROOT + + +class TestClusteringRunConfigStableHash: + """Test ClusteringRunConfig.stable_hash_b64() method.""" + + def test_stable_hash_b64(self): + """Test that stable_hash_b64 is deterministic, unique, and URL-safe.""" + # Create 4 configs: 2 identical, 2 different + config1 = ClusteringRunConfig( + 
model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config2 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config3 = ClusteringRunConfig( + model_path="wandb:test/project/run2", # Different model_path + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig(), + ) + config4 = ClusteringRunConfig( + model_path="wandb:test/project/run1", + batch_size=32, + dataset_seed=0, + merge_config=MergeConfig( + activation_threshold=0.2 + ), # Different merge_config to test nested fields + ) + + hash1 = config1.stable_hash_b64() + hash2 = config2.stable_hash_b64() + hash3 = config3.stable_hash_b64() + hash4 = config4.stable_hash_b64() + + # Identical configs produce identical hashes + assert hash1 == hash2 + + # Different configs produce different hashes + assert hash1 != hash3 + assert hash1 != hash4 + assert hash3 != hash4 + + # Hashes are strings + assert isinstance(hash1, str) + assert len(hash1) > 0 + + # Hashes are URL-safe base64 (no padding, URL-safe chars only) + assert "=" not in hash1 + valid_chars = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_") + assert all(c in valid_chars for c in hash1) + + +class TestClusteringPipelineConfigValidation: + """Test ClusteringPipelineConfig validation logic.""" + + def test_error_when_path_does_not_exist(self): + """Test that error is raised when clustering_run_config_path does not exist.""" + with pytest.raises(pydantic_core._pydantic_core.ValidationError): + ClusteringPipelineConfig( + clustering_run_config_path=Path("nonexistent/path.json"), + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + slurm_job_name_prefix=None, + slurm_partition=None, + wandb_entity="test", + create_git_snapshot=False, + ) + + def test_valid_config_with_existing_path(self): + """Test that config is valid when path points to existing 
file.""" + expected_path = Path("spd/clustering/configs/crc/resid_mlp1.json") + + config = ClusteringPipelineConfig( + clustering_run_config_path=expected_path, + n_runs=2, + distances_methods=["perm_invariant_hamming"], + base_output_dir=Path("/tmp/test"), + wandb_entity="test", + create_git_snapshot=False, + ) + + assert config.clustering_run_config_path == expected_path + + +def _get_config_files(path: Path): + """Helper to get all config files.""" + pipeline_config_files = ( + list(path.glob("*.yaml")) + list(path.glob("*.yml")) + list(path.glob("*.json")) + ) + assert len(pipeline_config_files) > 0, f"No pipeline files found in {path}" + return pipeline_config_files + + +class TestAllConfigsValidation: + """Test that all existing config files can be loaded and validated.""" + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline(self, config_file: Path): + """Test that each pipeline config file is valid.""" + print(config_file) + _config = ClusteringPipelineConfig.from_file(config_file) + crc_path = _config.clustering_run_config_path + print(f"{crc_path = }") + assert crc_path.exists() + + @pytest.mark.parametrize( + "config_file", + _get_config_files(REPO_ROOT / "spd" / "clustering" / "configs" / "crc"), + ids=lambda p: p.stem, + ) + def test_config_validate_pipeline_clustering_run(self, config_file: Path): + """Test that each clustering run config file is valid.""" + print(config_file) + _config = ClusteringRunConfig.from_file(config_file) + assert isinstance(_config, ClusteringRunConfig) diff --git a/tests/clustering/test_run_clustering_happy_path.py b/tests/clustering/test_run_clustering_happy_path.py new file mode 100644 index 000000000..5e2cbbd1c --- /dev/null +++ b/tests/clustering/test_run_clustering_happy_path.py @@ -0,0 +1,38 @@ +import tempfile +from pathlib import Path + +import pytest + +from 
spd.clustering.clustering_run_config import ClusteringRunConfig, LoggingIntervals +from spd.clustering.merge_config import MergeConfig +from spd.clustering.scripts.run_clustering import main + + +@pytest.mark.slow +def test_run_clustering_happy_path(): + """Test that run_clustering.py runs without errors.""" + with tempfile.TemporaryDirectory() as temp_dir: + config = ClusteringRunConfig( + model_path="wandb:goodfire/spd/runs/zxbu57pt", # An ss_llama run + batch_size=4, + dataset_seed=0, + base_output_dir=Path(temp_dir), + ensemble_id=None, + merge_config=MergeConfig( + activation_threshold=0.01, + alpha=1.0, + iters=3, + merge_pair_sampling_method="range", + merge_pair_sampling_kwargs={"threshold": 0.05}, + ), + wandb_project=None, + wandb_entity="goodfire", + logging_intervals=LoggingIntervals( + stat=1, + tensor=100, + plot=100, + artifact=100, + ), + dataset_streaming=True, # tests in CI very slow without this, see https://github.com/goodfire-ai/spd/pull/199 + ) + main(config) diff --git a/tests/scripts_run/test_main.py b/tests/scripts_run/test_main.py index 8f8ce65e9..00a6ee044 100644 --- a/tests/scripts_run/test_main.py +++ b/tests/scripts_run/test_main.py @@ -36,7 +36,7 @@ class TestSPDRun: ("tms_5-2", True, 4, None), # Command count depends on sweep params ], ) - @patch("spd.scripts.run.submit_slurm_array") + @patch("spd.scripts.run.submit_slurm_script") @patch("spd.scripts.run.create_slurm_array_script") @patch("spd.scripts.run.load_sweep_params") def test_spd_run_not_local_no_sweep( @@ -88,6 +88,7 @@ def test_spd_run_not_local_no_sweep( # Verify command structure for cmd in commands: + assert isinstance(cmd, str) assert "python" in cmd assert "_decomposition.py" in cmd assert "json:" in cmd @@ -109,7 +110,7 @@ def test_spd_run_not_local_no_sweep( ("tms_5-2", True), ], ) - @patch("spd.scripts.run.subprocess.run") + @patch("spd.scripts.run.run_script_array_local") @patch("spd.scripts.run.load_sweep_params") def test_spd_run_local_no_sweep( self, @@ -133,30 
+134,31 @@ def test_spd_run_local_no_sweep( **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) - # Calculate expected number of subprocess calls + # Calculate expected number of commands num_experiments = len(experiments.split(",")) - expected_calls = num_experiments * 2 if sweep else num_experiments + expected_num_commands = num_experiments * 2 if sweep else num_experiments - # Assert subprocess.run was called the expected number of times - assert mock_subprocess.call_count == expected_calls + # Assert run_script_array_local was called exactly once + assert mock_subprocess.call_count == 1 - # Verify each subprocess call - for call in mock_subprocess.call_args_list: - args = call[0][0] # Get the command list + # Get the commands list from the call + commands = mock_subprocess.call_args[0][0] + assert len(commands) == expected_num_commands - # Should be a list of arguments - assert isinstance(args, list) - assert args[0] == "python" - assert "_decomposition.py" in args[1] + # Verify each command + for cmd in commands: + # Should be a string + assert isinstance(cmd, str) + assert "python" in cmd + assert "_decomposition.py" in cmd # Check for required arguments in the command - cmd_str = " ".join(args) - assert "json:" in cmd_str - assert "--sweep_id" in cmd_str - assert "--evals_id" in cmd_str + assert "json:" in cmd + assert "--sweep_id" in cmd + assert "--evals_id" in cmd if sweep: - assert "--sweep_params_json" in cmd_str + assert "--sweep_params_json" in cmd # No wandb functions should be called since use_wandb=False @@ -178,7 +180,7 @@ def test_invalid_experiment_name(self): **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) - @patch("spd.scripts.run.subprocess.run") + @patch("spd.scripts.run.run_script_array_local") def test_sweep_params_integration(self, mock_subprocess): """Test that sweep parameters are correctly integrated into commands. 
@@ -196,12 +198,17 @@ def test_sweep_params_integration(self, mock_subprocess): **self._DEFAULT_MAIN_KWARGS, # pyright: ignore[reportArgumentType] ) + # Assert run_script_array_local was called exactly once + assert mock_subprocess.call_count == 1 + + # Get the commands list + commands = mock_subprocess.call_args[0][0] + # Verify multiple commands were generated (sweep should create multiple runs) - assert mock_subprocess.call_count > 1 + assert len(commands) > 1 # Check that sweep parameters are in the commands - for call in mock_subprocess.call_args_list: - args = call[0][0] - cmd_str = " ".join(args) - assert "--sweep_params_json" in cmd_str - assert "json:" in cmd_str + for cmd in commands: + assert isinstance(cmd, str) + assert "--sweep_params_json" in cmd + assert "json:" in cmd diff --git a/uv.lock b/uv.lock index 0877583f6..79220adbb 100644 --- a/uv.lock +++ b/uv.lock @@ -17,7 +17,7 @@ wheels = [ [[package]] name = "aiohttp" -version = "3.13.0" +version = "3.13.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohappyeyeballs" }, @@ -28,25 +28,25 @@ dependencies = [ { name = "propcache" }, { name = "yarl" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/f1/8515650ac3121a9e55c7b217c60e7fae3e0134b5acfe65691781b5356929/aiohttp-3.13.0.tar.gz", hash = "sha256:378dbc57dd8cf341ce243f13fa1fa5394d68e2e02c15cd5f28eae35a70ec7f67", size = 7832348, upload-time = "2025-10-06T19:58:48.089Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/ce/3b83ebba6b3207a7135e5fcaba49706f8a4b6008153b4e30540c982fae26/aiohttp-3.13.2.tar.gz", hash = "sha256:40176a52c186aefef6eb3cad2cdd30cd06e3afbe88fe8ab2af9c0b90f228daca", size = 7837994, upload-time = "2025-10-28T20:59:39.937Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/86/2c/ac53efdc9c10e41399acc2395af98f835b86d0141d5c3820857eb9f6a14a/aiohttp-3.13.0-cp313-cp313-macosx_10_13_universal2.whl", hash = 
"sha256:00243e51f16f6ec0fb021659d4af92f675f3cf9f9b39efd142aa3ad641d8d1e6", size = 730090, upload-time = "2025-10-06T19:56:16.858Z" }, - { url = "https://files.pythonhosted.org/packages/13/18/1ac95683e1c1d48ef4503965c96f5401618a04c139edae12e200392daae8/aiohttp-3.13.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:059978d2fddc462e9211362cbc8446747ecd930537fa559d3d25c256f032ff54", size = 488041, upload-time = "2025-10-06T19:56:18.659Z" }, - { url = "https://files.pythonhosted.org/packages/fd/79/ef0d477c771a642d1a881b92d226314c43d3c74bc674c93e12e679397a97/aiohttp-3.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:564b36512a7da3b386143c611867e3f7cfb249300a1bf60889bd9985da67ab77", size = 486989, upload-time = "2025-10-06T19:56:20.371Z" }, - { url = "https://files.pythonhosted.org/packages/37/b4/0e440481a0e77a551d6c5dcab5d11f1ff6b2b2ddb8dedc24f54f5caad732/aiohttp-3.13.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4aa995b9156ae499393d949a456a7ab0b994a8241a96db73a3b73c7a090eff6a", size = 1718331, upload-time = "2025-10-06T19:56:22.188Z" }, - { url = "https://files.pythonhosted.org/packages/e6/59/76c421cc4a75bb1aceadb92f20ee6f05a990aa6960c64b59e8e0d340e3f5/aiohttp-3.13.0-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:55ca0e95a3905f62f00900255ed807c580775174252999286f283e646d675a49", size = 1686263, upload-time = "2025-10-06T19:56:24.393Z" }, - { url = "https://files.pythonhosted.org/packages/ec/ac/5095f12a79c7775f402cfc3e83651b6e0a92ade10ddf7f2c78c4fed79f71/aiohttp-3.13.0-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:49ce7525853a981fc35d380aa2353536a01a9ec1b30979ea4e35966316cace7e", size = 1754265, upload-time = "2025-10-06T19:56:26.365Z" }, - { url = 
"https://files.pythonhosted.org/packages/05/d7/a48e4989bd76cc70600c505bbdd0d90ca1ad7f9053eceeb9dbcf9345a9ec/aiohttp-3.13.0-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2117be9883501eaf95503bd313eb4c7a23d567edd44014ba15835a1e9ec6d852", size = 1856486, upload-time = "2025-10-06T19:56:28.438Z" }, - { url = "https://files.pythonhosted.org/packages/1e/02/45b388b49e37933f316e1fb39c0de6fb1d77384b0c8f4cf6af5f2cbe3ea6/aiohttp-3.13.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d169c47e40c911f728439da853b6fd06da83761012e6e76f11cb62cddae7282b", size = 1737545, upload-time = "2025-10-06T19:56:30.688Z" }, - { url = "https://files.pythonhosted.org/packages/6c/a7/4fde058f1605c34a219348a83a99f14724cc64e68a42480fc03cf40f9ea3/aiohttp-3.13.0-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:703ad3f742fc81e543638a7bebddd35acadaa0004a5e00535e795f4b6f2c25ca", size = 1552958, upload-time = "2025-10-06T19:56:32.528Z" }, - { url = "https://files.pythonhosted.org/packages/d1/12/0bac4d29231981e3aa234e88d1931f6ba38135ff4c2cf3afbb7895527630/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5bf635c3476f4119b940cc8d94ad454cbe0c377e61b4527f0192aabeac1e9370", size = 1681166, upload-time = "2025-10-06T19:56:34.81Z" }, - { url = "https://files.pythonhosted.org/packages/71/95/b829eb5f8ac1ca1d8085bb8df614c8acf3ff32e23ad5ad1173c7c9761daa/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:cfe6285ef99e7ee51cef20609be2bc1dd0e8446462b71c9db8bb296ba632810a", size = 1710516, upload-time = "2025-10-06T19:56:36.787Z" }, - { url = "https://files.pythonhosted.org/packages/47/6d/15ccf4ef3c254d899f62580e0c7fc717014f4d14a3ac31771e505d2c736c/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:34d8af6391c5f2e69749d7f037b614b8c5c42093c251f336bdbfa4b03c57d6c4", size = 1731354, upload-time = "2025-10-06T19:56:38.659Z" }, - { url = 
"https://files.pythonhosted.org/packages/46/6a/8acf6c57e03b6fdcc8b4c06392e66abaff3213ea275e41db3edb20738d91/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:12f5d820fadc5848d4559ea838aef733cf37ed2a1103bba148ac2f5547c14c29", size = 1548040, upload-time = "2025-10-06T19:56:40.578Z" }, - { url = "https://files.pythonhosted.org/packages/75/7d/fbfd59ab2a83fe2578ce79ac3db49727b81e9f4c3376217ad09c03c6d279/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:0f1338b61ea66f4757a0544ed8a02ccbf60e38d9cfb3225888888dd4475ebb96", size = 1756031, upload-time = "2025-10-06T19:56:42.492Z" }, - { url = "https://files.pythonhosted.org/packages/99/e7/cc9f0fdf06cab3ca61e6b62bff9a4b978b8ca736e9d76ddf54365673ab19/aiohttp-3.13.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:582770f82513419512da096e8df21ca44f86a2e56e25dc93c5ab4df0fe065bf0", size = 1714933, upload-time = "2025-10-06T19:56:45.542Z" }, - { url = "https://files.pythonhosted.org/packages/db/43/7abbe1de94748a58a71881163ee280fd3217db36e8344d109f63638fe16a/aiohttp-3.13.0-cp313-cp313-win32.whl", hash = "sha256:3194b8cab8dbc882f37c13ef1262e0a3d62064fa97533d3aa124771f7bf1ecee", size = 423799, upload-time = "2025-10-06T19:56:47.779Z" }, - { url = "https://files.pythonhosted.org/packages/c9/58/afab7f2b9e7df88c995995172eb78cae8a3d5a62d5681abaade86b3f0089/aiohttp-3.13.0-cp313-cp313-win_amd64.whl", hash = "sha256:7897298b3eedc790257fef8a6ec582ca04e9dbe568ba4a9a890913b925b8ea21", size = 450138, upload-time = "2025-10-06T19:56:49.49Z" }, + { url = "https://files.pythonhosted.org/packages/bf/78/7e90ca79e5aa39f9694dcfd74f4720782d3c6828113bb1f3197f7e7c4a56/aiohttp-3.13.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7519bdc7dfc1940d201651b52bf5e03f5503bda45ad6eacf64dda98be5b2b6be", size = 732139, upload-time = "2025-10-28T20:57:02.455Z" }, + { url = 
"https://files.pythonhosted.org/packages/db/ed/1f59215ab6853fbaa5c8495fa6cbc39edfc93553426152b75d82a5f32b76/aiohttp-3.13.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:088912a78b4d4f547a1f19c099d5a506df17eacec3c6f4375e2831ec1d995742", size = 490082, upload-time = "2025-10-28T20:57:04.784Z" }, + { url = "https://files.pythonhosted.org/packages/68/7b/fe0fe0f5e05e13629d893c760465173a15ad0039c0a5b0d0040995c8075e/aiohttp-3.13.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5276807b9de9092af38ed23ce120539ab0ac955547b38563a9ba4f5b07b95293", size = 489035, upload-time = "2025-10-28T20:57:06.894Z" }, + { url = "https://files.pythonhosted.org/packages/d2/04/db5279e38471b7ac801d7d36a57d1230feeee130bbe2a74f72731b23c2b1/aiohttp-3.13.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1237c1375eaef0db4dcd7c2559f42e8af7b87ea7d295b118c60c36a6e61cb811", size = 1720387, upload-time = "2025-10-28T20:57:08.685Z" }, + { url = "https://files.pythonhosted.org/packages/31/07/8ea4326bd7dae2bd59828f69d7fdc6e04523caa55e4a70f4a8725a7e4ed2/aiohttp-3.13.2-cp313-cp313-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:96581619c57419c3d7d78703d5b78c1e5e5fc0172d60f555bdebaced82ded19a", size = 1688314, upload-time = "2025-10-28T20:57:10.693Z" }, + { url = "https://files.pythonhosted.org/packages/48/ab/3d98007b5b87ffd519d065225438cc3b668b2f245572a8cb53da5dd2b1bc/aiohttp-3.13.2-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2713a95b47374169409d18103366de1050fe0ea73db358fc7a7acb2880422d4", size = 1756317, upload-time = "2025-10-28T20:57:12.563Z" }, + { url = "https://files.pythonhosted.org/packages/97/3d/801ca172b3d857fafb7b50c7c03f91b72b867a13abca982ed6b3081774ef/aiohttp-3.13.2-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:228a1cd556b3caca590e9511a89444925da87d35219a49ab5da0c36d2d943a6a", size = 
1858539, upload-time = "2025-10-28T20:57:14.623Z" }, + { url = "https://files.pythonhosted.org/packages/f7/0d/4764669bdf47bd472899b3d3db91fffbe925c8e3038ec591a2fd2ad6a14d/aiohttp-3.13.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ac6cde5fba8d7d8c6ac963dbb0256a9854e9fafff52fbcc58fdf819357892c3e", size = 1739597, upload-time = "2025-10-28T20:57:16.399Z" }, + { url = "https://files.pythonhosted.org/packages/c4/52/7bd3c6693da58ba16e657eb904a5b6decfc48ecd06e9ac098591653b1566/aiohttp-3.13.2-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f2bef8237544f4e42878c61cef4e2839fee6346dc60f5739f876a9c50be7fcdb", size = 1555006, upload-time = "2025-10-28T20:57:18.288Z" }, + { url = "https://files.pythonhosted.org/packages/48/30/9586667acec5993b6f41d2ebcf96e97a1255a85f62f3c653110a5de4d346/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:16f15a4eac3bc2d76c45f7ebdd48a65d41b242eb6c31c2245463b40b34584ded", size = 1683220, upload-time = "2025-10-28T20:57:20.241Z" }, + { url = "https://files.pythonhosted.org/packages/71/01/3afe4c96854cfd7b30d78333852e8e851dceaec1c40fd00fec90c6402dd2/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:bb7fb776645af5cc58ab804c58d7eba545a97e047254a52ce89c157b5af6cd0b", size = 1712570, upload-time = "2025-10-28T20:57:22.253Z" }, + { url = "https://files.pythonhosted.org/packages/11/2c/22799d8e720f4697a9e66fd9c02479e40a49de3de2f0bbe7f9f78a987808/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:e1b4951125ec10c70802f2cb09736c895861cd39fd9dcb35107b4dc8ae6220b8", size = 1733407, upload-time = "2025-10-28T20:57:24.37Z" }, + { url = "https://files.pythonhosted.org/packages/34/cb/90f15dd029f07cebbd91f8238a8b363978b530cd128488085b5703683594/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:550bf765101ae721ee1d37d8095f47b1f220650f85fe1af37a90ce75bab89d04", size = 1550093, upload-time = 
"2025-10-28T20:57:26.257Z" }, + { url = "https://files.pythonhosted.org/packages/69/46/12dce9be9d3303ecbf4d30ad45a7683dc63d90733c2d9fe512be6716cd40/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:fe91b87fc295973096251e2d25a811388e7d8adf3bd2b97ef6ae78bc4ac6c476", size = 1758084, upload-time = "2025-10-28T20:57:28.349Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c8/0932b558da0c302ffd639fc6362a313b98fdf235dc417bc2493da8394df7/aiohttp-3.13.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e0c8e31cfcc4592cb200160344b2fb6ae0f9e4effe06c644b5a125d4ae5ebe23", size = 1716987, upload-time = "2025-10-28T20:57:30.233Z" }, + { url = "https://files.pythonhosted.org/packages/5d/8b/f5bd1a75003daed099baec373aed678f2e9b34f2ad40d85baa1368556396/aiohttp-3.13.2-cp313-cp313-win32.whl", hash = "sha256:0740f31a60848d6edb296a0df827473eede90c689b8f9f2a4cdde74889eb2254", size = 425859, upload-time = "2025-10-28T20:57:32.105Z" }, + { url = "https://files.pythonhosted.org/packages/5d/28/a8a9fc6957b2cee8902414e41816b5ab5536ecf43c3b1843c10e82c559b2/aiohttp-3.13.2-cp313-cp313-win_amd64.whl", hash = "sha256:a88d13e7ca367394908f8a276b89d04a3652044612b9a408a0bb22a5ed976a1a", size = 452192, upload-time = "2025-10-28T20:57:34.166Z" }, ] [[package]] @@ -77,6 +77,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/aa/f3/0b6ced594e51cc95d8c1fc1640d3623770d01e4969d29c0bd09945fafefa/altair-5.5.0-py3-none-any.whl", hash = "sha256:91a310b926508d560fe0148d02a194f38b824122641ef528113d029fcd129f8c", size = 731200, upload-time = "2024-11-23T23:39:56.4Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/a6/dc46877b911e40c00d395771ea710d5e77b6de7bacd5fdcd78d70cc5a48f/annotated_doc-0.0.3.tar.gz", hash = "sha256:e18370014c70187422c33e945053ff4c286f453a984eba84d0dbfa0c935adeda", size = 5535, upload-time = "2025-10-24T14:57:10.718Z" } +wheels = 
[ + { url = "https://files.pythonhosted.org/packages/02/b7/cf592cb5de5cb3bade3357f8d2cf42bf103bbe39f459824b4939fd212911/annotated_doc-0.0.3-py3-none-any.whl", hash = "sha256:348ec6664a76f1fd3be81f43dffbee4c7e8ce931ba71ec67cc7f4ade7fbbb580", size = 5488, upload-time = "2025-10-24T14:57:09.462Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -138,6 +147,36 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, ] +[[package]] +name = "beautifulsoup4" +version = "4.14.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/e9/df2358efd7659577435e2177bfa69cba6c33216681af51a707193dec162a/beautifulsoup4-4.14.2.tar.gz", hash = "sha256:2a98ab9f944a11acee9cc848508ec28d9228abfd522ef0fad6a02a72e0ded69e", size = 625822, upload-time = "2025-09-29T10:05:42.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/fe/3aed5d0be4d404d12d36ab97e2f1791424d9ca39c2f754a6285d59a3b01d/beautifulsoup4-4.14.2-py3-none-any.whl", hash = "sha256:5ef6fa3a8cbece8488d66985560f97ed091e22bbc4e9c2338508a9d5de6d4515", size = 106392, upload-time = "2025-09-29T10:05:43.771Z" }, +] + +[[package]] +name = "bleach" +version = "6.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, +] + +[package.optional-dependencies] +css = [ + { name = "tinycss2" }, +] + [[package]] name = "blinker" version = "1.9.0" @@ -331,7 +370,7 @@ wheels = [ [[package]] name = "datasets" -version = "4.2.0" +version = "4.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dill" }, @@ -349,9 +388,9 @@ dependencies = [ { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/48/0186fbc4b86a4f9ecaf04eb01e877e78b53bfa0b03be9c84b2298431ba33/datasets-4.2.0.tar.gz", hash = "sha256:8333a7db9f3bb8044c1b819a35d4e3e2809596c837793b0921382efffdc36e78", size = 582256, upload-time = "2025-10-09T16:10:15.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2a/47/325206ac160f7699ed9f1798afa8f8f8d5189b03bf3815654859ac1d5cba/datasets-4.3.0.tar.gz", hash = "sha256:bc9118ed9afd92346c5be7ed3aaa00177eb907c25467f9d072a0d22777efbd2b", size = 582801, upload-time = "2025-10-23T16:31:51.547Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/91/9e/0bbbd09b116fd8ee2d3617e28e6598551d2f0f24d3a2ce99cc87ec85aeb0/datasets-4.2.0-py3-none-any.whl", hash = "sha256:fdc43aaf4a73b31f64f80f72f195ab413a1141ed15555d675b2fd17926f8b026", size = 506316, upload-time = "2025-10-09T16:10:13.375Z" }, + { url = "https://files.pythonhosted.org/packages/ca/51/409a8184ed35453d9cbb3d6b20d524b1115c2c2d117b85d5e9b06cd70b45/datasets-4.3.0-py3-none-any.whl", hash = "sha256:0ea157e72138b3ca6c7d2415f19a164ecf7d4c4fa72da2a570da286882e96903", size = 506846, upload-time = "2025-10-23T16:31:49.965Z" }, ] [[package]] @@ -376,6 +415,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520, upload-time = "2021-03-08T10:59:26.269Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604, upload-time = "2021-03-08T10:59:24.45Z" }, +] + [[package]] name = "dill" version = "0.4.0" @@ -423,16 +471,26 @@ wheels = [ [[package]] name = "fastapi" -version = "0.119.1" +version = "0.120.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "annotated-doc" }, { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a6/f4/152127681182e6413e7a89684c434e19e7414ed7ac0c632999c3c6980640/fastapi-0.119.1.tar.gz", hash = "sha256:a5e3426edce3fe221af4e1992c6d79011b247e3b03cc57999d697fe76cbf8ae0", size = 338616, upload-time = "2025-10-20T11:30:27.734Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/cc/28aff6e246ee85bd571b26e4a793b84d42700e3bdc3008c3d747eda7b06d/fastapi-0.120.1.tar.gz", hash = "sha256:b5c6217e9ddca6dfcf54c97986180d4a1955e10c693d74943fc5327700178bff", size = 337616, upload-time = "2025-10-27T17:53:42.954Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/7e/bb/1a74dbe87e9a595bf63052c886dfef965dc5b91d149456a8301eb3d41ce2/fastapi-0.120.1-py3-none-any.whl", hash = "sha256:0e8a2c328e96c117272d8c794d3a97d205f753cc2e69dd7ee387b7488a75601f", size = 108254, upload-time = "2025-10-27T17:53:40.076Z" }, +] + +[[package]] +name = "fastjsonschema" +version = "2.21.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/b5/23b216d9d985a956623b6bd12d4086b60f0059b27799f23016af04a74ea1/fastjsonschema-2.21.2.tar.gz", hash = "sha256:b1eb43748041c880796cd077f1a07c3d94e93ae84bba5ed36800a33554ae05de", size = 374130, upload-time = "2025-08-14T18:49:36.666Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/26/e6d959b4ac959fdb3e9c4154656fc160794db6af8e64673d52759456bf07/fastapi-0.119.1-py3-none-any.whl", hash = "sha256:0b8c2a2cce853216e150e9bd4faaed88227f8eb37de21cb200771f491586a27f", size = 108123, upload-time = "2025-10-20T11:30:26.185Z" }, + { url = "https://files.pythonhosted.org/packages/cb/a8/20d0723294217e47de6d9e2e40fd4a9d2f7c4b6ef974babd482a59743694/fastjsonschema-2.21.2-py3-none-any.whl", hash = "sha256:1c797122d0a86c5cace2e54bf4e819c36223b552017172f32c5c024a6b77e463", size = 24024, upload-time = "2025-08-14T18:49:34.776Z" }, ] [[package]] @@ -563,17 +621,24 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.1.10" +version = "1.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/74/31/feeddfce1748c4a233ec1aa5b7396161c07ae1aa9b7bdbc9a72c3c7dd768/hf_xet-1.1.10.tar.gz", hash = "sha256:408aef343800a2102374a883f283ff29068055c111f003ff840733d3b715bb97", size = 487910, upload-time = "2025-09-12T20:10:27.12Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, 
upload-time = "2025-10-24T19:04:32.129Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/a2/343e6d05de96908366bdc0081f2d8607d61200be2ac802769c4284cc65bd/hf_xet-1.1.10-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:686083aca1a6669bc85c21c0563551cbcdaa5cf7876a91f3d074a030b577231d", size = 2761466, upload-time = "2025-09-12T20:10:22.836Z" }, - { url = "https://files.pythonhosted.org/packages/31/f9/6215f948ac8f17566ee27af6430ea72045e0418ce757260248b483f4183b/hf_xet-1.1.10-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:71081925383b66b24eedff3013f8e6bbd41215c3338be4b94ba75fd75b21513b", size = 2623807, upload-time = "2025-09-12T20:10:21.118Z" }, - { url = "https://files.pythonhosted.org/packages/15/07/86397573efefff941e100367bbda0b21496ffcdb34db7ab51912994c32a2/hf_xet-1.1.10-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b6bceb6361c80c1cc42b5a7b4e3efd90e64630bcf11224dcac50ef30a47e435", size = 3186960, upload-time = "2025-09-12T20:10:19.336Z" }, - { url = "https://files.pythonhosted.org/packages/01/a7/0b2e242b918cc30e1f91980f3c4b026ff2eedaf1e2ad96933bca164b2869/hf_xet-1.1.10-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:eae7c1fc8a664e54753ffc235e11427ca61f4b0477d757cc4eb9ae374b69f09c", size = 3087167, upload-time = "2025-09-12T20:10:17.255Z" }, - { url = "https://files.pythonhosted.org/packages/4a/25/3e32ab61cc7145b11eee9d745988e2f0f4fafda81b25980eebf97d8cff15/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0a0005fd08f002180f7a12d4e13b22be277725bc23ed0529f8add5c7a6309c06", size = 3248612, upload-time = "2025-09-12T20:10:24.093Z" }, - { url = "https://files.pythonhosted.org/packages/2c/3d/ab7109e607ed321afaa690f557a9ada6d6d164ec852fd6bf9979665dc3d6/hf_xet-1.1.10-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:f900481cf6e362a6c549c61ff77468bd59d6dd082f3170a36acfef2eb6a6793f", size = 3353360, upload-time = "2025-09-12T20:10:25.563Z" }, - { url = 
"https://files.pythonhosted.org/packages/ee/0e/471f0a21db36e71a2f1752767ad77e92d8cde24e974e03d662931b1305ec/hf_xet-1.1.10-cp37-abi3-win_amd64.whl", hash = "sha256:5f54b19cc347c13235ae7ee98b330c26dd65ef1df47e5316ffb1e87713ca7045", size = 2804691, upload-time = "2025-09-12T20:10:28.433Z" }, + { url = "https://files.pythonhosted.org/packages/9e/a5/85ef910a0aa034a2abcfadc360ab5ac6f6bc4e9112349bd40ca97551cff0/hf_xet-1.2.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:ceeefcd1b7aed4956ae8499e2199607765fbd1c60510752003b6cc0b8413b649", size = 2861870, upload-time = "2025-10-24T19:04:11.422Z" }, + { url = "https://files.pythonhosted.org/packages/ea/40/e2e0a7eb9a51fe8828ba2d47fe22a7e74914ea8a0db68a18c3aa7449c767/hf_xet-1.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:b70218dd548e9840224df5638fdc94bd033552963cfa97f9170829381179c813", size = 2717584, upload-time = "2025-10-24T19:04:09.586Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7d/daf7f8bc4594fdd59a8a596f9e3886133fdc68e675292218a5e4c1b7e834/hf_xet-1.2.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d40b18769bb9a8bc82a9ede575ce1a44c75eb80e7375a01d76259089529b5dc", size = 3315004, upload-time = "2025-10-24T19:04:00.314Z" }, + { url = "https://files.pythonhosted.org/packages/b1/ba/45ea2f605fbf6d81c8b21e4d970b168b18a53515923010c312c06cd83164/hf_xet-1.2.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:cd3a6027d59cfb60177c12d6424e31f4b5ff13d8e3a1247b3a584bf8977e6df5", size = 3222636, upload-time = "2025-10-24T19:03:58.111Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1d/04513e3cab8f29ab8c109d309ddd21a2705afab9d52f2ba1151e0c14f086/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:6de1fc44f58f6dd937956c8d304d8c2dea264c80680bcfa61ca4a15e7b76780f", size = 3408448, upload-time = "2025-10-24T19:04:20.951Z" }, + { url = 
"https://files.pythonhosted.org/packages/f0/7c/60a2756d7feec7387db3a1176c632357632fbe7849fce576c5559d4520c7/hf_xet-1.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f182f264ed2acd566c514e45da9f2119110e48a87a327ca271027904c70c5832", size = 3503401, upload-time = "2025-10-24T19:04:22.549Z" }, + { url = "https://files.pythonhosted.org/packages/4e/64/48fffbd67fb418ab07451e4ce641a70de1c40c10a13e25325e24858ebe5a/hf_xet-1.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:293a7a3787e5c95d7be1857358a9130694a9c6021de3f27fa233f37267174382", size = 2900866, upload-time = "2025-10-24T19:04:33.461Z" }, + { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, + { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, + { url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, + { url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, + { url = 
"https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" }, + { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, + { url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" }, ] [[package]] @@ -606,7 +671,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.35.3" +version = "0.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -618,9 +683,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/7e/a0a97de7c73671863ca6b3f61fa12518caf35db37825e43d63a70956738c/huggingface_hub-0.35.3.tar.gz", hash = "sha256:350932eaa5cc6a4747efae85126ee220e4ef1b54e29d31c3b45c5612ddf0b32a", size = 461798, upload-time = "2025-09-29T14:29:58.625Z" } +sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/31/a0/651f93d154cb72323358bf2bbae3e642bdb5d2f1bfc874d096f7cb159fa0/huggingface_hub-0.35.3-py3-none-any.whl", hash = 
"sha256:0e3a01829c19d86d03793e4577816fe3bdfc1602ac62c7fb220d593d351224ba", size = 564262, upload-time = "2025-09-29T14:29:55.813Z" }, + { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" }, ] [[package]] @@ -643,16 +708,16 @@ wheels = [ [[package]] name = "iniconfig" -version = "2.1.0" +version = "2.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793, upload-time = "2025-03-19T20:09:59.721Z" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, upload-time = "2025-03-19T20:10:01.071Z" }, + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] [[package]] name = "ipykernel" -version = "7.0.1" +version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "appnope", marker = "sys_platform == 'darwin'" }, @@ -669,9 +734,9 @@ 
dependencies = [ { name = "tornado" }, { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a8/4c/9f0024c8457286c6bfd5405a15d650ec5ea36f420ef9bbc58b301f66cfc5/ipykernel-7.0.1.tar.gz", hash = "sha256:2d3fd7cdef22071c2abbad78f142b743228c5d59cd470d034871ae0ac359533c", size = 171460, upload-time = "2025-10-14T16:17:07.325Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/a4/4948be6eb88628505b83a1f2f40d90254cab66abf2043b3c40fa07dfce0f/ipykernel-7.1.0.tar.gz", hash = "sha256:58a3fc88533d5930c3546dc7eac66c6d288acde4f801e2001e65edc5dc9cf0db", size = 174579, upload-time = "2025-10-27T09:46:39.471Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/f7/761037905ffdec673533bfa43af8d4c31c859c778dfc3bbb71899875ec18/ipykernel-7.0.1-py3-none-any.whl", hash = "sha256:87182a8305e28954b6721087dec45b171712610111d494c17bb607befa1c4000", size = 118157, upload-time = "2025-10-14T16:17:05.606Z" }, + { url = "https://files.pythonhosted.org/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl", hash = "sha256:763b5ec6c5b7776f6a8d7ce09b267693b4e5ce75cb50ae696aaefb3c85e1ea4c", size = 117968, upload-time = "2025-10-27T09:46:37.805Z" }, ] [[package]] @@ -743,6 +808,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/a1/3d680cbfd5f4b8f15abc1d571870c5fc3e594bb582bc3b64ea099db13e56/jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67", size = 134899, upload-time = "2025-03-05T20:05:00.369Z" }, ] +[[package]] +name = "joblib" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -799,6 +873,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e7/e7/80988e32bf6f73919a113473a604f5a8f09094de312b9d52b79c2df7612b/jupyter_core-5.9.1-py3-none-any.whl", hash = "sha256:ebf87fdc6073d142e114c72c9e29a9d7ca03fad818c5d300ce2adc1fb0743407", size = 29032, upload-time = "2025-10-16T19:19:16.783Z" }, ] +[[package]] +name = "jupyterlab-pygments" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/51/9187be60d989df97f5f0aba133fa54e7300f17616e065d1ada7d7646b6d6/jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d", size = 512900, upload-time = "2023-11-23T09:26:37.44Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884, upload-time = "2023-11-23T09:26:34.325Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.9" @@ -897,14 +980,23 @@ wheels = [ [[package]] name = "matplotlib-inline" -version = "0.1.7" +version = "0.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "traitlets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } +sdist = { url = 
"https://files.pythonhosted.org/packages/c7/74/97e72a36efd4ae2bccb3463284300f8953f199b5ffbc04cbbb0ec78f74b1/matplotlib_inline-0.2.1.tar.gz", hash = "sha256:e1ee949c340d771fc39e241ea75683deb94762c8fa5f2927ec57c83c4dffa9fe", size = 8110, upload-time = "2025-10-23T09:00:22.126Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, + { url = "https://files.pythonhosted.org/packages/af/33/ee4519fa02ed11a94aef9559552f3b17bb863f2ecfe1a35dc7f548cde231/matplotlib_inline-0.2.1-py3-none-any.whl", hash = "sha256:d56ce5156ba6085e00a9d54fead6ed29a9c47e215cd1bba2e976ef39f5710a76", size = 9516, upload-time = "2025-10-23T09:00:20.675Z" }, +] + +[[package]] +name = "mistune" +version = "3.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/02/a7fb8b21d4d55ac93cdcde9d3638da5dd0ebdd3a4fed76c7725e10b81cbe/mistune-3.1.4.tar.gz", hash = "sha256:b5a7f801d389f724ec702840c11d8fc48f2b33519102fc7ee739e8177b672164", size = 94588, upload-time = "2025-08-29T07:20:43.594Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/f0/8282d9641415e9e33df173516226b404d367a0fc55e1a60424a152913abc/mistune-3.1.4-py3-none-any.whl", hash = "sha256:93691da911e5d9d2e23bc54472892aff676df27a75274962ff9edc210364266d", size = 53481, upload-time = "2025-08-29T07:20:42.218Z" }, ] [[package]] @@ -977,13 +1069,77 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] +[[package]] +name = "muutils" +version = "0.8.12" +source = { 
registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/c8/556e999e5e5662ca2d74aa486962b2e7a955e58723af6cadca293be0bd37/muutils-0.8.12.tar.gz", hash = "sha256:ffc0d2c5b0e3bbf4c442dd810880aec7d9f95995e7677e14dc72f0a5ef12b993", size = 3348223, upload-time = "2025-10-28T17:52:25.769Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0a/65/d6e07cbff0caf10b2c77fad77e9138973c689dbe50c5ecb3b96764630276/muutils-0.8.12-py3-none-any.whl", hash = "sha256:19ecc6f2cab6e162d6f84f6f0d96377dc387a0e7105334c0b6d8eb90934eaeea", size = 129087, upload-time = "2025-10-28T17:52:23.013Z" }, +] + [[package]] name = "narwhals" -version = "2.8.0" +version = "2.10.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/05/79a5b5a795f36c1aaa002d194c1ef71e5d95f7e1900155bbfde734815ab9/narwhals-2.8.0.tar.gz", hash = "sha256:52e0b22d54718264ae703bd9293af53b04abc995a1414908c3b807ba8c913858", size = 574277, upload-time = "2025-10-13T08:44:28.81Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/e5/ef07d31c2e07d99eecac8e14ace5c20aeb00ecba4ed5bb00343136380524/narwhals-2.10.0.tar.gz", hash = "sha256:1c05bbef2048a4045263de7d98c3d06140583eb13d796dd733b2157f05d24485", size = 582423, upload-time = "2025-10-27T17:55:55.632Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/86/ac808ecb94322a3f1ea31627d13ab3e50dd4333564d711e0e481ad0f4586/narwhals-2.8.0-py3-none-any.whl", hash = "sha256:6304856676ba4a79fd34148bda63aed8060dd6edb1227edf3659ce5e091de73c", size = 415852, upload-time = "2025-10-13T08:44:25.421Z" }, + { url = "https://files.pythonhosted.org/packages/29/13/024ae0586d901f8a6f99e2d29b4ae217e8ef11d3fd944cdfc3bbde5f2a08/narwhals-2.10.0-py3-none-any.whl", hash = "sha256:baed44e8fc38e800e3a585e3fa9843a7079a6fad5fbffbecee4348d6ac52298c", size = 418077, upload-time = "2025-10-27T17:55:53.709Z" }, +] + +[[package]] +name = "nbclient" +version = "0.10.2" +source 
= { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "nbformat" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/66/7ffd18d58eae90d5721f9f39212327695b749e23ad44b3881744eaf4d9e8/nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193", size = 62424, upload-time = "2024-12-19T10:32:27.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/6d/e7fa07f03a4a7b221d94b4d586edb754a9b0dc3c9e2c93353e9fa4e0d117/nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d", size = 25434, upload-time = "2024-12-19T10:32:24.139Z" }, +] + +[[package]] +name = "nbconvert" +version = "7.16.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4" }, + { name = "bleach", extra = ["css"] }, + { name = "defusedxml" }, + { name = "jinja2" }, + { name = "jupyter-core" }, + { name = "jupyterlab-pygments" }, + { name = "markupsafe" }, + { name = "mistune" }, + { name = "nbclient" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pandocfilters" }, + { name = "pygments" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/59/f28e15fc47ffb73af68a8d9b47367a8630d76e97ae85ad18271b9db96fdf/nbconvert-7.16.6.tar.gz", hash = "sha256:576a7e37c6480da7b8465eefa66c17844243816ce1ccc372633c6b71c3c0f582", size = 857715, upload-time = "2025-01-28T09:29:14.724Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/9a/cd673b2f773a12c992f41309ef81b99da1690426bd2f96957a7ade0d3ed7/nbconvert-7.16.6-py3-none-any.whl", hash = "sha256:1375a7b67e0c2883678c48e506dc320febb57685e5ee67faa51b18a90f3a712b", size = 258525, upload-time = "2025-01-28T09:29:12.551Z" }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } 
+dependencies = [ + { name = "fastjsonschema" }, + { name = "jsonschema" }, + { name = "jupyter-core" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749, upload-time = "2024-04-04T11:20:37.371Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", size = 78454, upload-time = "2024-04-04T11:20:34.895Z" }, ] [[package]] @@ -1096,7 +1252,7 @@ name = "nvidia-cudnn-cu12" version = "9.10.2.21" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/ba/51/e123d997aa098c61d029f76663dedbfb9bc8dcf8c60cbd6adbe42f76d049/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:949452be657fa16687d0930933f032835951ef0892b37d2d53824d1a84dc97a8", size = 706758467, upload-time = "2025-06-06T21:54:08.597Z" }, @@ -1107,7 +1263,7 @@ name = "nvidia-cufft-cu12" version = "11.3.3.83" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/1f/13/ee4e00f30e676b66ae65b4f08cb5bcbb8392c03f54f2d5413ea99a5d1c80/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d2dd21ec0b88cf61b62e6b43564355e5222e4a3fb394cac0db101f2dd0d4f74", size = 193118695, upload-time = "2025-03-07T01:45:27.821Z" }, @@ -1134,9 +1290,9 @@ name = "nvidia-cusolver-cu12" version = 
"11.7.3.90" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux'" }, - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-cublas-cu12" }, + { name = "nvidia-cusparse-cu12" }, + { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/85/48/9a13d2975803e8cf2777d5ed57b87a0b6ca2cc795f9a4f59796a910bfb80/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_x86_64.whl", hash = "sha256:4376c11ad263152bd50ea295c05370360776f8c3427b30991df774f9fb26c450", size = 267506905, upload-time = "2025-03-07T01:47:16.273Z" }, @@ -1147,7 +1303,7 @@ name = "nvidia-cusparse-cu12" version = "12.5.8.93" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux'" }, + { name = "nvidia-nvjitlink-cu12" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/c2/f5/e1854cb2f2bcd4280c44736c93550cc300ff4b8c95ebe370d0aa7d2b473d/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ec05d76bbbd8b61b06a80e1eaf8cf4959c3d4ce8e711b65ebd0443bb0ebb13b", size = 288216466, upload-time = "2025-03-07T01:48:13.779Z" }, @@ -1221,6 +1377,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bd/17/e756653095a083d8a37cbd816cb87148debcfcd920129b25f99dd8d04271/pandas-2.3.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c4fc4c21971a1a9f4bdb4c73978c7f7256caa3e62b323f70d6cb80db583350bc", size = 13199233, upload-time = "2025-09-29T23:24:24.876Z" }, ] +[[package]] +name = "pandocfilters" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/6f/3dd4940bbe001c06a65f88e36bad298bc7a0de5036115639926b0c5c0458/pandocfilters-1.5.1.tar.gz", hash = 
"sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e", size = 8454, upload-time = "2024-01-18T20:08:13.726Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc", size = 8663, upload-time = "2024-01-18T20:08:11.28Z" }, +] + [[package]] name = "parso" version = "0.8.5" @@ -1377,18 +1542,22 @@ wheels = [ [[package]] name = "psutil" -version = "7.1.0" +version = "7.1.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b3/31/4723d756b59344b643542936e37a31d1d3204bcdc42a7daa8ee9eb06fb50/psutil-7.1.0.tar.gz", hash = "sha256:655708b3c069387c8b77b072fc429a57d0e214221d01c0a772df7dfedcb3bcd2", size = 497660, upload-time = "2025-09-17T20:14:52.902Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/ec/7b8e6b9b1d22708138630ef34c53ab2b61032c04f16adfdbb96791c8c70c/psutil-7.1.2.tar.gz", hash = "sha256:aa225cdde1335ff9684708ee8c72650f6598d5ed2114b9a7c5802030b1785018", size = 487424, upload-time = "2025-10-25T10:46:34.931Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/62/ce4051019ee20ce0ed74432dd73a5bb087a6704284a470bb8adff69a0932/psutil-7.1.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:76168cef4397494250e9f4e73eb3752b146de1dd950040b29186d0cce1d5ca13", size = 245242, upload-time = "2025-09-17T20:14:56.126Z" }, - { url = "https://files.pythonhosted.org/packages/38/61/f76959fba841bf5b61123fbf4b650886dc4094c6858008b5bf73d9057216/psutil-7.1.0-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:5d007560c8c372efdff9e4579c2846d71de737e4605f611437255e81efcca2c5", size = 246682, upload-time = "2025-09-17T20:14:58.25Z" }, - { url = 
"https://files.pythonhosted.org/packages/88/7a/37c99d2e77ec30d63398ffa6a660450b8a62517cabe44b3e9bae97696e8d/psutil-7.1.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:22e4454970b32472ce7deaa45d045b34d3648ce478e26a04c7e858a0a6e75ff3", size = 287994, upload-time = "2025-09-17T20:14:59.901Z" }, - { url = "https://files.pythonhosted.org/packages/9d/de/04c8c61232f7244aa0a4b9a9fbd63a89d5aeaf94b2fc9d1d16e2faa5cbb0/psutil-7.1.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c70e113920d51e89f212dd7be06219a9b88014e63a4cec69b684c327bc474e3", size = 291163, upload-time = "2025-09-17T20:15:01.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/58/c4f976234bf6d4737bc8c02a81192f045c307b72cf39c9e5c5a2d78927f6/psutil-7.1.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7d4a113425c037300de3ac8b331637293da9be9713855c4fc9d2d97436d7259d", size = 293625, upload-time = "2025-09-17T20:15:04.492Z" }, - { url = "https://files.pythonhosted.org/packages/79/87/157c8e7959ec39ced1b11cc93c730c4fb7f9d408569a6c59dbd92ceb35db/psutil-7.1.0-cp37-abi3-win32.whl", hash = "sha256:09ad740870c8d219ed8daae0ad3b726d3bf9a028a198e7f3080f6a1888b99bca", size = 244812, upload-time = "2025-09-17T20:15:07.462Z" }, - { url = "https://files.pythonhosted.org/packages/bf/e9/b44c4f697276a7a95b8e94d0e320a7bf7f3318521b23de69035540b39838/psutil-7.1.0-cp37-abi3-win_amd64.whl", hash = "sha256:57f5e987c36d3146c0dd2528cd42151cf96cd359b9d67cfff836995cc5df9a3d", size = 247965, upload-time = "2025-09-17T20:15:09.673Z" }, - { url = "https://files.pythonhosted.org/packages/26/65/1070a6e3c036f39142c2820c4b52e9243246fcfc3f96239ac84472ba361e/psutil-7.1.0-cp37-abi3-win_arm64.whl", hash = "sha256:6937cb68133e7c97b6cc9649a570c9a18ba0efebed46d8c5dae4c07fa1b67a07", size = 244971, upload-time = "2025-09-17T20:15:12.262Z" }, + { url = 
"https://files.pythonhosted.org/packages/b8/d9/b56cc9f883140ac10021a8c9b0f4e16eed1ba675c22513cdcbce3ba64014/psutil-7.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0cc5c6889b9871f231ed5455a9a02149e388fffcb30b607fb7a8896a6d95f22e", size = 238575, upload-time = "2025-10-25T10:46:38.728Z" }, + { url = "https://files.pythonhosted.org/packages/36/eb/28d22de383888deb252c818622196e709da98816e296ef95afda33f1c0a2/psutil-7.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8e9e77a977208d84aa363a4a12e0f72189d58bbf4e46b49aae29a2c6e93ef206", size = 239297, upload-time = "2025-10-25T10:46:41.347Z" }, + { url = "https://files.pythonhosted.org/packages/89/5d/220039e2f28cc129626e54d63892ab05c0d56a29818bfe7268dcb5008932/psutil-7.1.2-cp313-cp313t-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7d9623a5e4164d2220ecceb071f4b333b3c78866141e8887c072129185f41278", size = 280420, upload-time = "2025-10-25T10:46:44.122Z" }, + { url = "https://files.pythonhosted.org/packages/ba/7a/286f0e1c167445b2ef4a6cbdfc8c59fdb45a5a493788950cf8467201dc73/psutil-7.1.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:364b1c10fe4ed59c89ec49e5f1a70da353b27986fa8233b4b999df4742a5ee2f", size = 283049, upload-time = "2025-10-25T10:46:47.095Z" }, + { url = "https://files.pythonhosted.org/packages/aa/cc/7eb93260794a42e39b976f3a4dde89725800b9f573b014fac142002a5c98/psutil-7.1.2-cp313-cp313t-win_amd64.whl", hash = "sha256:f101ef84de7e05d41310e3ccbdd65a6dd1d9eed85e8aaf0758405d022308e204", size = 248713, upload-time = "2025-10-25T10:46:49.573Z" }, + { url = "https://files.pythonhosted.org/packages/ab/1a/0681a92b53366e01f0a099f5237d0c8a2f79d322ac589cccde5e30c8a4e2/psutil-7.1.2-cp313-cp313t-win_arm64.whl", hash = "sha256:20c00824048a95de67f00afedc7b08b282aa08638585b0206a9fb51f28f1a165", size = 244644, upload-time = "2025-10-25T10:46:51.924Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/89/b9f8d47ddbc52d7301fc868e8224e5f44ed3c7f55e6d0f54ecaf5dd9ff5e/psutil-7.1.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:c9ba5c19f2d46203ee8c152c7b01df6eec87d883cfd8ee1af2ef2727f6b0f814", size = 237244, upload-time = "2025-10-25T10:47:07.086Z" }, + { url = "https://files.pythonhosted.org/packages/c8/7a/8628c2f6b240680a67d73d8742bb9ff39b1820a693740e43096d5dcb01e5/psutil-7.1.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:2a486030d2fe81bec023f703d3d155f4823a10a47c36784c84f1cc7f8d39bedb", size = 238101, upload-time = "2025-10-25T10:47:09.523Z" }, + { url = "https://files.pythonhosted.org/packages/30/28/5e27f4d5a0e347f8e3cc16cd7d35533dbce086c95807f1f0e9cd77e26c10/psutil-7.1.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3efd8fc791492e7808a51cb2b94889db7578bfaea22df931424f874468e389e3", size = 258675, upload-time = "2025-10-25T10:47:11.082Z" }, + { url = "https://files.pythonhosted.org/packages/e5/5c/79cf60c9acf36d087f0db0f82066fca4a780e97e5b3a2e4c38209c03d170/psutil-7.1.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e2aeb9b64f481b8eabfc633bd39e0016d4d8bbcd590d984af764d80bf0851b8a", size = 260203, upload-time = "2025-10-25T10:47:13.226Z" }, + { url = "https://files.pythonhosted.org/packages/f7/03/0a464404c51685dcb9329fdd660b1721e076ccd7b3d97dee066bcc9ffb15/psutil-7.1.2-cp37-abi3-win_amd64.whl", hash = "sha256:8e17852114c4e7996fe9da4745c2bdef001ebbf2f260dec406290e66628bdb91", size = 246714, upload-time = "2025-10-25T10:47:15.093Z" }, + { url = "https://files.pythonhosted.org/packages/6a/32/97ca2090f2f1b45b01b6aa7ae161cfe50671de097311975ca6eea3e7aabc/psutil-7.1.2-cp37-abi3-win_arm64.whl", hash = "sha256:3e988455e61c240cc879cb62a008c2699231bf3e3d061d7fce4234463fd2abb4", size = 243742, upload-time = "2025-10-25T10:47:17.302Z" }, ] [[package]] @@ -1411,24 +1580,24 @@ wheels = [ [[package]] name = "pyarrow" -version = 
"21.0.0" +version = "22.0.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ef/c2/ea068b8f00905c06329a3dfcd40d0fcc2b7d0f2e355bdb25b65e0a0e4cd4/pyarrow-21.0.0.tar.gz", hash = "sha256:5051f2dccf0e283ff56335760cbc8622cf52264d67e359d5569541ac11b6d5bc", size = 1133487, upload-time = "2025-07-18T00:57:31.761Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/53/04a7fdc63e6056116c9ddc8b43bc28c12cdd181b85cbeadb79278475f3ae/pyarrow-22.0.0.tar.gz", hash = "sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9", size = 1151151, upload-time = "2025-10-24T12:30:00.762Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/16/ca/c7eaa8e62db8fb37ce942b1ea0c6d7abfe3786ca193957afa25e71b81b66/pyarrow-21.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e99310a4ebd4479bcd1964dff9e14af33746300cb014aa4a3781738ac63baf4a", size = 31154306, upload-time = "2025-07-18T00:56:04.42Z" }, - { url = "https://files.pythonhosted.org/packages/ce/e8/e87d9e3b2489302b3a1aea709aaca4b781c5252fcb812a17ab6275a9a484/pyarrow-21.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:d2fe8e7f3ce329a71b7ddd7498b3cfac0eeb200c2789bd840234f0dc271a8efe", size = 32680622, upload-time = "2025-07-18T00:56:07.505Z" }, - { url = "https://files.pythonhosted.org/packages/84/52/79095d73a742aa0aba370c7942b1b655f598069489ab387fe47261a849e1/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:f522e5709379d72fb3da7785aa489ff0bb87448a9dc5a75f45763a795a089ebd", size = 41104094, upload-time = "2025-07-18T00:56:10.994Z" }, - { url = "https://files.pythonhosted.org/packages/89/4b/7782438b551dbb0468892a276b8c789b8bbdb25ea5c5eb27faadd753e037/pyarrow-21.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:69cbbdf0631396e9925e048cfa5bce4e8c3d3b41562bbd70c685a8eb53a91e61", size = 42825576, upload-time = "2025-07-18T00:56:15.569Z" }, - { url = 
"https://files.pythonhosted.org/packages/b3/62/0f29de6e0a1e33518dec92c65be0351d32d7ca351e51ec5f4f837a9aab91/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:731c7022587006b755d0bdb27626a1a3bb004bb56b11fb30d98b6c1b4718579d", size = 43368342, upload-time = "2025-07-18T00:56:19.531Z" }, - { url = "https://files.pythonhosted.org/packages/90/c7/0fa1f3f29cf75f339768cc698c8ad4ddd2481c1742e9741459911c9ac477/pyarrow-21.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:dc56bc708f2d8ac71bd1dcb927e458c93cec10b98eb4120206a4091db7b67b99", size = 45131218, upload-time = "2025-07-18T00:56:23.347Z" }, - { url = "https://files.pythonhosted.org/packages/01/63/581f2076465e67b23bc5a37d4a2abff8362d389d29d8105832e82c9c811c/pyarrow-21.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:186aa00bca62139f75b7de8420f745f2af12941595bbbfa7ed3870ff63e25636", size = 26087551, upload-time = "2025-07-18T00:56:26.758Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ab/357d0d9648bb8241ee7348e564f2479d206ebe6e1c47ac5027c2e31ecd39/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:a7a102574faa3f421141a64c10216e078df467ab9576684d5cd696952546e2da", size = 31290064, upload-time = "2025-07-18T00:56:30.214Z" }, - { url = "https://files.pythonhosted.org/packages/3f/8a/5685d62a990e4cac2043fc76b4661bf38d06efed55cf45a334b455bd2759/pyarrow-21.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:1e005378c4a2c6db3ada3ad4c217b381f6c886f0a80d6a316fe586b90f77efd7", size = 32727837, upload-time = "2025-07-18T00:56:33.935Z" }, - { url = "https://files.pythonhosted.org/packages/fc/de/c0828ee09525c2bafefd3e736a248ebe764d07d0fd762d4f0929dbc516c9/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:65f8e85f79031449ec8706b74504a316805217b35b6099155dd7e227eef0d4b6", size = 41014158, upload-time = "2025-07-18T00:56:37.528Z" }, - { url = 
"https://files.pythonhosted.org/packages/6e/26/a2865c420c50b7a3748320b614f3484bfcde8347b2639b2b903b21ce6a72/pyarrow-21.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:3a81486adc665c7eb1a2bde0224cfca6ceaba344a82a971ef059678417880eb8", size = 42667885, upload-time = "2025-07-18T00:56:41.483Z" }, - { url = "https://files.pythonhosted.org/packages/0a/f9/4ee798dc902533159250fb4321267730bc0a107d8c6889e07c3add4fe3a5/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:fc0d2f88b81dcf3ccf9a6ae17f89183762c8a94a5bdcfa09e05cfe413acf0503", size = 43276625, upload-time = "2025-07-18T00:56:48.002Z" }, - { url = "https://files.pythonhosted.org/packages/5a/da/e02544d6997037a4b0d22d8e5f66bc9315c3671371a8b18c79ade1cefe14/pyarrow-21.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:6299449adf89df38537837487a4f8d3bd91ec94354fdd2a7d30bc11c48ef6e79", size = 44951890, upload-time = "2025-07-18T00:56:52.568Z" }, - { url = "https://files.pythonhosted.org/packages/e5/4e/519c1bc1876625fe6b71e9a28287c43ec2f20f73c658b9ae1d485c0c206e/pyarrow-21.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:222c39e2c70113543982c6b34f3077962b44fca38c0bd9e68bb6781534425c10", size = 26371006, upload-time = "2025-07-18T00:56:56.379Z" }, + { url = "https://files.pythonhosted.org/packages/a6/d6/d0fac16a2963002fc22c8fa75180a838737203d558f0ed3b564c4a54eef5/pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a", size = 34204629, upload-time = "2025-10-24T10:06:20.274Z" }, + { url = "https://files.pythonhosted.org/packages/c6/9c/1d6357347fbae062ad3f17082f9ebc29cc733321e892c0d2085f42a2212b/pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901", size = 35985783, upload-time = "2025-10-24T10:06:27.301Z" }, + { url = 
"https://files.pythonhosted.org/packages/ff/c0/782344c2ce58afbea010150df07e3a2f5fdad299cd631697ae7bd3bac6e3/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691", size = 45020999, upload-time = "2025-10-24T10:06:35.387Z" }, + { url = "https://files.pythonhosted.org/packages/1b/8b/5362443737a5307a7b67c1017c42cd104213189b4970bf607e05faf9c525/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e0a15757fccb38c410947df156f9749ae4a3c89b2393741a50521f39a8cf202a", size = 47724601, upload-time = "2025-10-24T10:06:43.551Z" }, + { url = "https://files.pythonhosted.org/packages/69/4d/76e567a4fc2e190ee6072967cb4672b7d9249ac59ae65af2d7e3047afa3b/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:cedb9dd9358e4ea1d9bce3665ce0797f6adf97ff142c8e25b46ba9cdd508e9b6", size = 48001050, upload-time = "2025-10-24T10:06:52.284Z" }, + { url = "https://files.pythonhosted.org/packages/01/5e/5653f0535d2a1aef8223cee9d92944cb6bccfee5cf1cd3f462d7cb022790/pyarrow-22.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:252be4a05f9d9185bb8c18e83764ebcfea7185076c07a7a662253af3a8c07941", size = 50307877, upload-time = "2025-10-24T10:07:02.405Z" }, + { url = "https://files.pythonhosted.org/packages/2d/f8/1d0bd75bf9328a3b826e24a16e5517cd7f9fbf8d34a3184a4566ef5a7f29/pyarrow-22.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:a4893d31e5ef780b6edcaf63122df0f8d321088bb0dee4c8c06eccb1ca28d145", size = 27977099, upload-time = "2025-10-24T10:08:07.259Z" }, + { url = "https://files.pythonhosted.org/packages/90/81/db56870c997805bf2b0f6eeeb2d68458bf4654652dccdcf1bf7a42d80903/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:f7fe3dbe871294ba70d789be16b6e7e52b418311e166e0e3cba9522f0f437fb1", size = 34336685, upload-time = "2025-10-24T10:07:11.47Z" }, + { url = 
"https://files.pythonhosted.org/packages/1c/98/0727947f199aba8a120f47dfc229eeb05df15bcd7a6f1b669e9f882afc58/pyarrow-22.0.0-cp313-cp313t-macosx_12_0_x86_64.whl", hash = "sha256:ba95112d15fd4f1105fb2402c4eab9068f0554435e9b7085924bcfaac2cc306f", size = 36032158, upload-time = "2025-10-24T10:07:18.626Z" }, + { url = "https://files.pythonhosted.org/packages/96/b4/9babdef9c01720a0785945c7cf550e4acd0ebcd7bdd2e6f0aa7981fa85e2/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c064e28361c05d72eed8e744c9605cbd6d2bb7481a511c74071fd9b24bc65d7d", size = 44892060, upload-time = "2025-10-24T10:07:26.002Z" }, + { url = "https://files.pythonhosted.org/packages/f8/ca/2f8804edd6279f78a37062d813de3f16f29183874447ef6d1aadbb4efa0f/pyarrow-22.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:6f9762274496c244d951c819348afbcf212714902742225f649cf02823a6a10f", size = 47504395, upload-time = "2025-10-24T10:07:34.09Z" }, + { url = "https://files.pythonhosted.org/packages/b9/f0/77aa5198fd3943682b2e4faaf179a674f0edea0d55d326d83cb2277d9363/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:a9d9ffdc2ab696f6b15b4d1f7cec6658e1d788124418cb30030afbae31c64746", size = 48066216, upload-time = "2025-10-24T10:07:43.528Z" }, + { url = "https://files.pythonhosted.org/packages/79/87/a1937b6e78b2aff18b706d738c9e46ade5bfcf11b294e39c87706a0089ac/pyarrow-22.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:ec1a15968a9d80da01e1d30349b2b0d7cc91e96588ee324ce1b5228175043e95", size = 50288552, upload-time = "2025-10-24T10:07:53.519Z" }, + { url = "https://files.pythonhosted.org/packages/60/ae/b5a5811e11f25788ccfdaa8f26b6791c9807119dffcf80514505527c384c/pyarrow-22.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:bba208d9c7decf9961998edf5c65e3ea4355d5818dd6cd0f6809bec1afb951cc", size = 28262504, upload-time = "2025-10-24T10:08:00.932Z" }, ] [[package]] @@ -1571,11 +1740,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.1.1" +version = "1.2.1" 
source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/b0/4bc07ccd3572a2f9df7e6782f52b0c6c90dcbb803ac4a167702d7d0dfe1e/python_dotenv-1.1.1.tar.gz", hash = "sha256:a8a6399716257f45be6a007360200409fce5cda2661e3dec71d23dc15f6189ab", size = 41978, upload-time = "2025-06-24T04:21:07.341Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl", hash = "sha256:31f23644fe2602f88ff55e1f5c79ba497e01224ee7737937930c448e4d0e24dc", size = 20556, upload-time = "2025-06-24T04:21:06.073Z" }, + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, ] [[package]] @@ -1653,38 +1822,38 @@ wheels = [ [[package]] name = "regex" -version = "2025.9.18" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/49/d3/eaa0d28aba6ad1827ad1e716d9a93e1ba963ada61887498297d3da715133/regex-2025.9.18.tar.gz", hash = "sha256:c5ba23274c61c6fef447ba6a39333297d0c247f53059dba0bca415cac511edc4", size = 400917, upload-time = "2025-09-19T00:38:35.79Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/c7/5c48206a60ce33711cf7dcaeaed10dd737733a3569dc7e1dce324dd48f30/regex-2025.9.18-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2a40f929cd907c7e8ac7566ac76225a77701a6221bca937bdb70d56cb61f57b2", size = 485955, upload-time = 
"2025-09-19T00:36:26.822Z" }, - { url = "https://files.pythonhosted.org/packages/e9/be/74fc6bb19a3c491ec1ace943e622b5a8539068771e8705e469b2da2306a7/regex-2025.9.18-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c90471671c2cdf914e58b6af62420ea9ecd06d1554d7474d50133ff26ae88feb", size = 289583, upload-time = "2025-09-19T00:36:28.577Z" }, - { url = "https://files.pythonhosted.org/packages/25/c4/9ceaa433cb5dc515765560f22a19578b95b92ff12526e5a259321c4fc1a0/regex-2025.9.18-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1a351aff9e07a2dabb5022ead6380cff17a4f10e4feb15f9100ee56c4d6d06af", size = 287000, upload-time = "2025-09-19T00:36:30.161Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e6/68bc9393cb4dc68018456568c048ac035854b042bc7c33cb9b99b0680afa/regex-2025.9.18-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bc4b8e9d16e20ddfe16430c23468a8707ccad3365b06d4536142e71823f3ca29", size = 797535, upload-time = "2025-09-19T00:36:31.876Z" }, - { url = "https://files.pythonhosted.org/packages/6a/1c/ebae9032d34b78ecfe9bd4b5e6575b55351dc8513485bb92326613732b8c/regex-2025.9.18-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4b8cdbddf2db1c5e80338ba2daa3cfa3dec73a46fff2a7dda087c8efbf12d62f", size = 862603, upload-time = "2025-09-19T00:36:33.344Z" }, - { url = "https://files.pythonhosted.org/packages/3b/74/12332c54b3882557a4bcd2b99f8be581f5c6a43cf1660a85b460dd8ff468/regex-2025.9.18-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a276937d9d75085b2c91fb48244349c6954f05ee97bba0963ce24a9d915b8b68", size = 910829, upload-time = "2025-09-19T00:36:34.826Z" }, - { url = "https://files.pythonhosted.org/packages/86/70/ba42d5ed606ee275f2465bfc0e2208755b06cdabd0f4c7c4b614d51b57ab/regex-2025.9.18-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:92a8e375ccdc1256401c90e9dc02b8642894443d549ff5e25e36d7cf8a80c783", size = 802059, upload-time = "2025-09-19T00:36:36.664Z" }, - { url = "https://files.pythonhosted.org/packages/da/c5/fcb017e56396a7f2f8357412638d7e2963440b131a3ca549be25774b3641/regex-2025.9.18-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:0dc6893b1f502d73037cf807a321cdc9be29ef3d6219f7970f842475873712ac", size = 786781, upload-time = "2025-09-19T00:36:38.168Z" }, - { url = "https://files.pythonhosted.org/packages/c6/ee/21c4278b973f630adfb3bcb23d09d83625f3ab1ca6e40ebdffe69901c7a1/regex-2025.9.18-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:a61e85bfc63d232ac14b015af1261f826260c8deb19401c0597dbb87a864361e", size = 856578, upload-time = "2025-09-19T00:36:40.129Z" }, - { url = "https://files.pythonhosted.org/packages/87/0b/de51550dc7274324435c8f1539373ac63019b0525ad720132866fff4a16a/regex-2025.9.18-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:1ef86a9ebc53f379d921fb9a7e42b92059ad3ee800fcd9e0fe6181090e9f6c23", size = 849119, upload-time = "2025-09-19T00:36:41.651Z" }, - { url = "https://files.pythonhosted.org/packages/60/52/383d3044fc5154d9ffe4321696ee5b2ee4833a28c29b137c22c33f41885b/regex-2025.9.18-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:d3bc882119764ba3a119fbf2bd4f1b47bc56c1da5d42df4ed54ae1e8e66fdf8f", size = 788219, upload-time = "2025-09-19T00:36:43.575Z" }, - { url = "https://files.pythonhosted.org/packages/20/bd/2614fc302671b7359972ea212f0e3a92df4414aaeacab054a8ce80a86073/regex-2025.9.18-cp313-cp313-win32.whl", hash = "sha256:3810a65675845c3bdfa58c3c7d88624356dd6ee2fc186628295e0969005f928d", size = 264517, upload-time = "2025-09-19T00:36:45.503Z" }, - { url = "https://files.pythonhosted.org/packages/07/0f/ab5c1581e6563a7bffdc1974fb2d25f05689b88e2d416525271f232b1946/regex-2025.9.18-cp313-cp313-win_amd64.whl", hash = "sha256:16eaf74b3c4180ede88f620f299e474913ab6924d5c4b89b3833bc2345d83b3d", size = 275481, upload-time = "2025-09-19T00:36:46.965Z" }, - { 
url = "https://files.pythonhosted.org/packages/49/22/ee47672bc7958f8c5667a587c2600a4fba8b6bab6e86bd6d3e2b5f7cac42/regex-2025.9.18-cp313-cp313-win_arm64.whl", hash = "sha256:4dc98ba7dd66bd1261927a9f49bd5ee2bcb3660f7962f1ec02617280fc00f5eb", size = 268598, upload-time = "2025-09-19T00:36:48.314Z" }, - { url = "https://files.pythonhosted.org/packages/e8/83/6887e16a187c6226cb85d8301e47d3b73ecc4505a3a13d8da2096b44fd76/regex-2025.9.18-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:fe5d50572bc885a0a799410a717c42b1a6b50e2f45872e2b40f4f288f9bce8a2", size = 489765, upload-time = "2025-09-19T00:36:49.996Z" }, - { url = "https://files.pythonhosted.org/packages/51/c5/e2f7325301ea2916ff301c8d963ba66b1b2c1b06694191df80a9c4fea5d0/regex-2025.9.18-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:1b9d9a2d6cda6621551ca8cf7a06f103adf72831153f3c0d982386110870c4d3", size = 291228, upload-time = "2025-09-19T00:36:51.654Z" }, - { url = "https://files.pythonhosted.org/packages/91/60/7d229d2bc6961289e864a3a3cfebf7d0d250e2e65323a8952cbb7e22d824/regex-2025.9.18-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:13202e4c4ac0ef9a317fff817674b293c8f7e8c68d3190377d8d8b749f566e12", size = 289270, upload-time = "2025-09-19T00:36:53.118Z" }, - { url = "https://files.pythonhosted.org/packages/3c/d7/b4f06868ee2958ff6430df89857fbf3d43014bbf35538b6ec96c2704e15d/regex-2025.9.18-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:874ff523b0fecffb090f80ae53dc93538f8db954c8bb5505f05b7787ab3402a0", size = 806326, upload-time = "2025-09-19T00:36:54.631Z" }, - { url = "https://files.pythonhosted.org/packages/d6/e4/bca99034a8f1b9b62ccf337402a8e5b959dd5ba0e5e5b2ead70273df3277/regex-2025.9.18-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d13ab0490128f2bb45d596f754148cd750411afc97e813e4b3a61cf278a23bb6", size = 871556, upload-time = "2025-09-19T00:36:56.208Z" }, - { url = 
"https://files.pythonhosted.org/packages/6d/df/e06ffaf078a162f6dd6b101a5ea9b44696dca860a48136b3ae4a9caf25e2/regex-2025.9.18-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:05440bc172bc4b4b37fb9667e796597419404dbba62e171e1f826d7d2a9ebcef", size = 913817, upload-time = "2025-09-19T00:36:57.807Z" }, - { url = "https://files.pythonhosted.org/packages/9e/05/25b05480b63292fd8e84800b1648e160ca778127b8d2367a0a258fa2e225/regex-2025.9.18-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5514b8e4031fdfaa3d27e92c75719cbe7f379e28cacd939807289bce76d0e35a", size = 811055, upload-time = "2025-09-19T00:36:59.762Z" }, - { url = "https://files.pythonhosted.org/packages/70/97/7bc7574655eb651ba3a916ed4b1be6798ae97af30104f655d8efd0cab24b/regex-2025.9.18-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:65d3c38c39efce73e0d9dc019697b39903ba25b1ad45ebbd730d2cf32741f40d", size = 794534, upload-time = "2025-09-19T00:37:01.405Z" }, - { url = "https://files.pythonhosted.org/packages/b4/c2/d5da49166a52dda879855ecdba0117f073583db2b39bb47ce9a3378a8e9e/regex-2025.9.18-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:ae77e447ebc144d5a26d50055c6ddba1d6ad4a865a560ec7200b8b06bc529368", size = 866684, upload-time = "2025-09-19T00:37:03.441Z" }, - { url = "https://files.pythonhosted.org/packages/bd/2d/0a5c4e6ec417de56b89ff4418ecc72f7e3feca806824c75ad0bbdae0516b/regex-2025.9.18-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:e3ef8cf53dc8df49d7e28a356cf824e3623764e9833348b655cfed4524ab8a90", size = 853282, upload-time = "2025-09-19T00:37:04.985Z" }, - { url = "https://files.pythonhosted.org/packages/f4/8e/d656af63e31a86572ec829665d6fa06eae7e144771e0330650a8bb865635/regex-2025.9.18-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9feb29817df349c976da9a0debf775c5c33fc1c8ad7b9f025825da99374770b7", size = 797830, upload-time = "2025-09-19T00:37:06.697Z" }, - { url = 
"https://files.pythonhosted.org/packages/db/ce/06edc89df8f7b83ffd321b6071be4c54dc7332c0f77860edc40ce57d757b/regex-2025.9.18-cp313-cp313t-win32.whl", hash = "sha256:168be0d2f9b9d13076940b1ed774f98595b4e3c7fc54584bba81b3cc4181742e", size = 267281, upload-time = "2025-09-19T00:37:08.568Z" }, - { url = "https://files.pythonhosted.org/packages/83/9a/2b5d9c8b307a451fd17068719d971d3634ca29864b89ed5c18e499446d4a/regex-2025.9.18-cp313-cp313t-win_amd64.whl", hash = "sha256:d59ecf3bb549e491c8104fea7313f3563c7b048e01287db0a90485734a70a730", size = 278724, upload-time = "2025-09-19T00:37:10.023Z" }, - { url = "https://files.pythonhosted.org/packages/3d/70/177d31e8089a278a764f8ec9a3faac8d14a312d622a47385d4b43905806f/regex-2025.9.18-cp313-cp313t-win_arm64.whl", hash = "sha256:dbef80defe9fb21310948a2595420b36c6d641d9bea4c991175829b2cc4bc06a", size = 269771, upload-time = "2025-09-19T00:37:13.041Z" }, +version = "2025.10.23" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/c8/1d2160d36b11fbe0a61acb7c3c81ab032d9ec8ad888ac9e0a61b85ab99dd/regex-2025.10.23.tar.gz", hash = "sha256:8cbaf8ceb88f96ae2356d01b9adf5e6306fa42fa6f7eab6b97794e37c959ac26", size = 401266, upload-time = "2025-10-21T15:58:20.23Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/c6/195a6217a43719d5a6a12cc192a22d12c40290cecfa577f00f4fb822f07d/regex-2025.10.23-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:b7690f95404a1293923a296981fd943cca12c31a41af9c21ba3edd06398fc193", size = 488956, upload-time = "2025-10-21T15:55:42.887Z" }, + { url = "https://files.pythonhosted.org/packages/4c/93/181070cd1aa2fa541ff2d3afcf763ceecd4937b34c615fa92765020a6c90/regex-2025.10.23-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1a32d77aeaea58a13230100dd8797ac1a84c457f3af2fdf0d81ea689d5a9105b", size = 290997, upload-time = "2025-10-21T15:55:44.53Z" }, + { url = 
"https://files.pythonhosted.org/packages/b6/c5/9d37fbe3a40ed8dda78c23e1263002497540c0d1522ed75482ef6c2000f0/regex-2025.10.23-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b24b29402f264f70a3c81f45974323b41764ff7159655360543b7cabb73e7d2f", size = 288686, upload-time = "2025-10-21T15:55:46.186Z" }, + { url = "https://files.pythonhosted.org/packages/5f/e7/db610ff9f10c2921f9b6ac0c8d8be4681b28ddd40fc0549429366967e61f/regex-2025.10.23-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:563824a08c7c03d96856d84b46fdb3bbb7cfbdf79da7ef68725cda2ce169c72a", size = 798466, upload-time = "2025-10-21T15:55:48.24Z" }, + { url = "https://files.pythonhosted.org/packages/90/10/aab883e1fa7fe2feb15ac663026e70ca0ae1411efa0c7a4a0342d9545015/regex-2025.10.23-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a0ec8bdd88d2e2659c3518087ee34b37e20bd169419ffead4240a7004e8ed03b", size = 863996, upload-time = "2025-10-21T15:55:50.478Z" }, + { url = "https://files.pythonhosted.org/packages/a2/b0/8f686dd97a51f3b37d0238cd00a6d0f9ccabe701f05b56de1918571d0d61/regex-2025.10.23-cp313-cp313-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b577601bfe1d33913fcd9276d7607bbac827c4798d9e14d04bf37d417a6c41cb", size = 912145, upload-time = "2025-10-21T15:55:52.215Z" }, + { url = "https://files.pythonhosted.org/packages/a3/ca/639f8cd5b08797bca38fc5e7e07f76641a428cf8c7fca05894caf045aa32/regex-2025.10.23-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7c9f2c68ac6cb3de94eea08a437a75eaa2bd33f9e97c84836ca0b610a5804368", size = 803370, upload-time = "2025-10-21T15:55:53.944Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1e/a40725bb76959eddf8abc42a967bed6f4851b39f5ac4f20e9794d7832aa5/regex-2025.10.23-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:89f8b9ea3830c79468e26b0e21c3585f69f105157c2154a36f6b7839f8afb351", 
size = 787767, upload-time = "2025-10-21T15:55:56.004Z" }, + { url = "https://files.pythonhosted.org/packages/3d/d8/8ee9858062936b0f99656dce390aa667c6e7fb0c357b1b9bf76fb5e2e708/regex-2025.10.23-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:98fd84c4e4ea185b3bb5bf065261ab45867d8875032f358a435647285c722673", size = 858335, upload-time = "2025-10-21T15:55:58.185Z" }, + { url = "https://files.pythonhosted.org/packages/d8/0a/ed5faaa63fa8e3064ab670e08061fbf09e3a10235b19630cf0cbb9e48c0a/regex-2025.10.23-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:1e11d3e5887b8b096f96b4154dfb902f29c723a9556639586cd140e77e28b313", size = 850402, upload-time = "2025-10-21T15:56:00.023Z" }, + { url = "https://files.pythonhosted.org/packages/79/14/d05f617342f4b2b4a23561da500ca2beab062bfcc408d60680e77ecaf04d/regex-2025.10.23-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f13450328a6634348d47a88367e06b64c9d84980ef6a748f717b13f8ce64e87", size = 789739, upload-time = "2025-10-21T15:56:01.967Z" }, + { url = "https://files.pythonhosted.org/packages/f9/7b/e8ce8eef42a15f2c3461f8b3e6e924bbc86e9605cb534a393aadc8d3aff8/regex-2025.10.23-cp313-cp313-win32.whl", hash = "sha256:37be9296598a30c6a20236248cb8b2c07ffd54d095b75d3a2a2ee5babdc51df1", size = 266054, upload-time = "2025-10-21T15:56:05.291Z" }, + { url = "https://files.pythonhosted.org/packages/71/2d/55184ed6be6473187868d2f2e6a0708195fc58270e62a22cbf26028f2570/regex-2025.10.23-cp313-cp313-win_amd64.whl", hash = "sha256:ea7a3c283ce0f06fe789365841e9174ba05f8db16e2fd6ae00a02df9572c04c0", size = 276917, upload-time = "2025-10-21T15:56:07.303Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d4/927eced0e2bd45c45839e556f987f8c8f8683268dd3c00ad327deb3b0172/regex-2025.10.23-cp313-cp313-win_arm64.whl", hash = "sha256:d9a4953575f300a7bab71afa4cd4ac061c7697c89590a2902b536783eeb49a4f", size = 270105, upload-time = "2025-10-21T15:56:09.857Z" }, + { url = 
"https://files.pythonhosted.org/packages/3e/b3/95b310605285573341fc062d1d30b19a54f857530e86c805f942c4ff7941/regex-2025.10.23-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:7d6606524fa77b3912c9ef52a42ef63c6cfbfc1077e9dc6296cd5da0da286044", size = 491850, upload-time = "2025-10-21T15:56:11.685Z" }, + { url = "https://files.pythonhosted.org/packages/a4/8f/207c2cec01e34e56db1eff606eef46644a60cf1739ecd474627db90ad90b/regex-2025.10.23-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c037aadf4d64bdc38af7db3dbd34877a057ce6524eefcb2914d6d41c56f968cc", size = 292537, upload-time = "2025-10-21T15:56:13.963Z" }, + { url = "https://files.pythonhosted.org/packages/98/3b/025240af4ada1dc0b5f10d73f3e5122d04ce7f8908ab8881e5d82b9d61b6/regex-2025.10.23-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:99018c331fb2529084a0c9b4c713dfa49fafb47c7712422e49467c13a636c656", size = 290904, upload-time = "2025-10-21T15:56:16.016Z" }, + { url = "https://files.pythonhosted.org/packages/81/8e/104ac14e2d3450c43db18ec03e1b96b445a94ae510b60138f00ce2cb7ca1/regex-2025.10.23-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fd8aba965604d70306eb90a35528f776e59112a7114a5162824d43b76fa27f58", size = 807311, upload-time = "2025-10-21T15:56:17.818Z" }, + { url = "https://files.pythonhosted.org/packages/19/63/78aef90141b7ce0be8a18e1782f764f6997ad09de0e05251f0d2503a914a/regex-2025.10.23-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:238e67264b4013e74136c49f883734f68656adf8257bfa13b515626b31b20f8e", size = 873241, upload-time = "2025-10-21T15:56:19.941Z" }, + { url = "https://files.pythonhosted.org/packages/b3/a8/80eb1201bb49ae4dba68a1b284b4211ed9daa8e74dc600018a10a90399fb/regex-2025.10.23-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b2eb48bd9848d66fd04826382f5e8491ae633de3233a3d64d58ceb4ecfa2113a", size = 914794, upload-time = 
"2025-10-21T15:56:22.488Z" }, + { url = "https://files.pythonhosted.org/packages/f0/d5/1984b6ee93281f360a119a5ca1af6a8ca7d8417861671388bf750becc29b/regex-2025.10.23-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d36591ce06d047d0c0fe2fc5f14bfbd5b4525d08a7b6a279379085e13f0e3d0e", size = 812581, upload-time = "2025-10-21T15:56:24.319Z" }, + { url = "https://files.pythonhosted.org/packages/c4/39/11ebdc6d9927172a64ae237d16763145db6bd45ebb4055c17b88edab72a7/regex-2025.10.23-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b5d4ece8628d6e364302006366cea3ee887db397faebacc5dacf8ef19e064cf8", size = 795346, upload-time = "2025-10-21T15:56:26.232Z" }, + { url = "https://files.pythonhosted.org/packages/3b/b4/89a591bcc08b5e436af43315284bd233ba77daf0cf20e098d7af12f006c1/regex-2025.10.23-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:39a7e8083959cb1c4ff74e483eecb5a65d3b3e1d821b256e54baf61782c906c6", size = 868214, upload-time = "2025-10-21T15:56:28.597Z" }, + { url = "https://files.pythonhosted.org/packages/3d/ff/58ba98409c1dbc8316cdb20dafbc63ed267380a07780cafecaf5012dabc9/regex-2025.10.23-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:842d449a8fefe546f311656cf8c0d6729b08c09a185f1cad94c756210286d6a8", size = 854540, upload-time = "2025-10-21T15:56:30.875Z" }, + { url = "https://files.pythonhosted.org/packages/9a/f2/4a9e9338d67626e2071b643f828a482712ad15889d7268e11e9a63d6f7e9/regex-2025.10.23-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d614986dc68506be8f00474f4f6960e03e4ca9883f7df47744800e7d7c08a494", size = 799346, upload-time = "2025-10-21T15:56:32.725Z" }, + { url = "https://files.pythonhosted.org/packages/63/be/543d35c46bebf6f7bf2be538cca74d6585f25714700c36f37f01b92df551/regex-2025.10.23-cp313-cp313t-win32.whl", hash = "sha256:a5b7a26b51a9df473ec16a1934d117443a775ceb7b39b78670b2e21893c330c9", size = 268657, upload-time = "2025-10-21T15:56:34.577Z" }, + { url = 
"https://files.pythonhosted.org/packages/14/9f/4dd6b7b612037158bb2c9bcaa710e6fb3c40ad54af441b9c53b3a137a9f1/regex-2025.10.23-cp313-cp313t-win_amd64.whl", hash = "sha256:ce81c5544a5453f61cb6f548ed358cfb111e3b23f3cd42d250a4077a6be2a7b6", size = 280075, upload-time = "2025-10-21T15:56:36.767Z" }, + { url = "https://files.pythonhosted.org/packages/81/7a/5bd0672aa65d38c8da6747c17c8b441bdb53d816c569e3261013af8e83cf/regex-2025.10.23-cp313-cp313t-win_arm64.whl", hash = "sha256:e9bf7f6699f490e4e43c44757aa179dab24d1960999c84ab5c3d5377714ed473", size = 271219, upload-time = "2025-10-21T15:56:39.033Z" }, ] [[package]] @@ -1704,65 +1873,65 @@ wheels = [ [[package]] name = "rpds-py" -version = "0.27.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e9/dd/2c0cbe774744272b0ae725f44032c77bdcab6e8bcf544bffa3b6e70c8dba/rpds_py-0.27.1.tar.gz", hash = "sha256:26a1c73171d10b7acccbded82bf6a586ab8203601e565badc74bbbf8bc5a10f8", size = 27479, upload-time = "2025-08-27T12:16:36.024Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/77/610aeee8d41e39080c7e14afa5387138e3c9fa9756ab893d09d99e7d8e98/rpds_py-0.27.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e4b9fcfbc021633863a37e92571d6f91851fa656f0180246e84cbd8b3f6b329b", size = 361741, upload-time = "2025-08-27T12:13:31.039Z" }, - { url = "https://files.pythonhosted.org/packages/3a/fc/c43765f201c6a1c60be2043cbdb664013def52460a4c7adace89d6682bf4/rpds_py-0.27.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1441811a96eadca93c517d08df75de45e5ffe68aa3089924f963c782c4b898cf", size = 345574, upload-time = "2025-08-27T12:13:32.902Z" }, - { url = "https://files.pythonhosted.org/packages/20/42/ee2b2ca114294cd9847d0ef9c26d2b0851b2e7e00bf14cc4c0b581df0fc3/rpds_py-0.27.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55266dafa22e672f5a4f65019015f90336ed31c6383bd53f5e7826d21a0e0b83", size = 385051, upload-time = 
"2025-08-27T12:13:34.228Z" }, - { url = "https://files.pythonhosted.org/packages/fd/e8/1e430fe311e4799e02e2d1af7c765f024e95e17d651612425b226705f910/rpds_py-0.27.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d78827d7ac08627ea2c8e02c9e5b41180ea5ea1f747e9db0915e3adf36b62dcf", size = 398395, upload-time = "2025-08-27T12:13:36.132Z" }, - { url = "https://files.pythonhosted.org/packages/82/95/9dc227d441ff2670651c27a739acb2535ccaf8b351a88d78c088965e5996/rpds_py-0.27.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae92443798a40a92dc5f0b01d8a7c93adde0c4dc965310a29ae7c64d72b9fad2", size = 524334, upload-time = "2025-08-27T12:13:37.562Z" }, - { url = "https://files.pythonhosted.org/packages/87/01/a670c232f401d9ad461d9a332aa4080cd3cb1d1df18213dbd0d2a6a7ab51/rpds_py-0.27.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c46c9dd2403b66a2a3b9720ec4b74d4ab49d4fabf9f03dfdce2d42af913fe8d0", size = 407691, upload-time = "2025-08-27T12:13:38.94Z" }, - { url = "https://files.pythonhosted.org/packages/03/36/0a14aebbaa26fe7fab4780c76f2239e76cc95a0090bdb25e31d95c492fcd/rpds_py-0.27.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2efe4eb1d01b7f5f1939f4ef30ecea6c6b3521eec451fb93191bf84b2a522418", size = 386868, upload-time = "2025-08-27T12:13:40.192Z" }, - { url = "https://files.pythonhosted.org/packages/3b/03/8c897fb8b5347ff6c1cc31239b9611c5bf79d78c984430887a353e1409a1/rpds_py-0.27.1-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:15d3b4d83582d10c601f481eca29c3f138d44c92187d197aff663a269197c02d", size = 405469, upload-time = "2025-08-27T12:13:41.496Z" }, - { url = "https://files.pythonhosted.org/packages/da/07/88c60edc2df74850d496d78a1fdcdc7b54360a7f610a4d50008309d41b94/rpds_py-0.27.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:4ed2e16abbc982a169d30d1a420274a709949e2cbdef119fe2ec9d870b42f274", size = 422125, upload-time = 
"2025-08-27T12:13:42.802Z" }, - { url = "https://files.pythonhosted.org/packages/6b/86/5f4c707603e41b05f191a749984f390dabcbc467cf833769b47bf14ba04f/rpds_py-0.27.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a75f305c9b013289121ec0f1181931975df78738cdf650093e6b86d74aa7d8dd", size = 562341, upload-time = "2025-08-27T12:13:44.472Z" }, - { url = "https://files.pythonhosted.org/packages/b2/92/3c0cb2492094e3cd9baf9e49bbb7befeceb584ea0c1a8b5939dca4da12e5/rpds_py-0.27.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:67ce7620704745881a3d4b0ada80ab4d99df390838839921f99e63c474f82cf2", size = 592511, upload-time = "2025-08-27T12:13:45.898Z" }, - { url = "https://files.pythonhosted.org/packages/10/bb/82e64fbb0047c46a168faa28d0d45a7851cd0582f850b966811d30f67ad8/rpds_py-0.27.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:9d992ac10eb86d9b6f369647b6a3f412fc0075cfd5d799530e84d335e440a002", size = 557736, upload-time = "2025-08-27T12:13:47.408Z" }, - { url = "https://files.pythonhosted.org/packages/00/95/3c863973d409210da7fb41958172c6b7dbe7fc34e04d3cc1f10bb85e979f/rpds_py-0.27.1-cp313-cp313-win32.whl", hash = "sha256:4f75e4bd8ab8db624e02c8e2fc4063021b58becdbe6df793a8111d9343aec1e3", size = 221462, upload-time = "2025-08-27T12:13:48.742Z" }, - { url = "https://files.pythonhosted.org/packages/ce/2c/5867b14a81dc217b56d95a9f2a40fdbc56a1ab0181b80132beeecbd4b2d6/rpds_py-0.27.1-cp313-cp313-win_amd64.whl", hash = "sha256:f9025faafc62ed0b75a53e541895ca272815bec18abe2249ff6501c8f2e12b83", size = 232034, upload-time = "2025-08-27T12:13:50.11Z" }, - { url = "https://files.pythonhosted.org/packages/c7/78/3958f3f018c01923823f1e47f1cc338e398814b92d83cd278364446fac66/rpds_py-0.27.1-cp313-cp313-win_arm64.whl", hash = "sha256:ed10dc32829e7d222b7d3b93136d25a406ba9788f6a7ebf6809092da1f4d279d", size = 222392, upload-time = "2025-08-27T12:13:52.587Z" }, - { url = 
"https://files.pythonhosted.org/packages/01/76/1cdf1f91aed5c3a7bf2eba1f1c4e4d6f57832d73003919a20118870ea659/rpds_py-0.27.1-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:92022bbbad0d4426e616815b16bc4127f83c9a74940e1ccf3cfe0b387aba0228", size = 358355, upload-time = "2025-08-27T12:13:54.012Z" }, - { url = "https://files.pythonhosted.org/packages/c3/6f/bf142541229374287604caf3bb2a4ae17f0a580798fd72d3b009b532db4e/rpds_py-0.27.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:47162fdab9407ec3f160805ac3e154df042e577dd53341745fc7fb3f625e6d92", size = 342138, upload-time = "2025-08-27T12:13:55.791Z" }, - { url = "https://files.pythonhosted.org/packages/1a/77/355b1c041d6be40886c44ff5e798b4e2769e497b790f0f7fd1e78d17e9a8/rpds_py-0.27.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb89bec23fddc489e5d78b550a7b773557c9ab58b7946154a10a6f7a214a48b2", size = 380247, upload-time = "2025-08-27T12:13:57.683Z" }, - { url = "https://files.pythonhosted.org/packages/d6/a4/d9cef5c3946ea271ce2243c51481971cd6e34f21925af2783dd17b26e815/rpds_py-0.27.1-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e48af21883ded2b3e9eb48cb7880ad8598b31ab752ff3be6457001d78f416723", size = 390699, upload-time = "2025-08-27T12:13:59.137Z" }, - { url = "https://files.pythonhosted.org/packages/3a/06/005106a7b8c6c1a7e91b73169e49870f4af5256119d34a361ae5240a0c1d/rpds_py-0.27.1-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6f5b7bd8e219ed50299e58551a410b64daafb5017d54bbe822e003856f06a802", size = 521852, upload-time = "2025-08-27T12:14:00.583Z" }, - { url = "https://files.pythonhosted.org/packages/e5/3e/50fb1dac0948e17a02eb05c24510a8fe12d5ce8561c6b7b7d1339ab7ab9c/rpds_py-0.27.1-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:08f1e20bccf73b08d12d804d6e1c22ca5530e71659e6673bce31a6bb71c1e73f", size = 402582, upload-time = "2025-08-27T12:14:02.034Z" }, - { url = 
"https://files.pythonhosted.org/packages/cb/b0/f4e224090dc5b0ec15f31a02d746ab24101dd430847c4d99123798661bfc/rpds_py-0.27.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dc5dceeaefcc96dc192e3a80bbe1d6c410c469e97bdd47494a7d930987f18b2", size = 384126, upload-time = "2025-08-27T12:14:03.437Z" }, - { url = "https://files.pythonhosted.org/packages/54/77/ac339d5f82b6afff1df8f0fe0d2145cc827992cb5f8eeb90fc9f31ef7a63/rpds_py-0.27.1-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:d76f9cc8665acdc0c9177043746775aa7babbf479b5520b78ae4002d889f5c21", size = 399486, upload-time = "2025-08-27T12:14:05.443Z" }, - { url = "https://files.pythonhosted.org/packages/d6/29/3e1c255eee6ac358c056a57d6d6869baa00a62fa32eea5ee0632039c50a3/rpds_py-0.27.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:134fae0e36022edad8290a6661edf40c023562964efea0cc0ec7f5d392d2aaef", size = 414832, upload-time = "2025-08-27T12:14:06.902Z" }, - { url = "https://files.pythonhosted.org/packages/3f/db/6d498b844342deb3fa1d030598db93937a9964fcf5cb4da4feb5f17be34b/rpds_py-0.27.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:eb11a4f1b2b63337cfd3b4d110af778a59aae51c81d195768e353d8b52f88081", size = 557249, upload-time = "2025-08-27T12:14:08.37Z" }, - { url = "https://files.pythonhosted.org/packages/60/f3/690dd38e2310b6f68858a331399b4d6dbb9132c3e8ef8b4333b96caf403d/rpds_py-0.27.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:13e608ac9f50a0ed4faec0e90ece76ae33b34c0e8656e3dceb9a7db994c692cd", size = 587356, upload-time = "2025-08-27T12:14:10.034Z" }, - { url = "https://files.pythonhosted.org/packages/86/e3/84507781cccd0145f35b1dc32c72675200c5ce8d5b30f813e49424ef68fc/rpds_py-0.27.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dd2135527aa40f061350c3f8f89da2644de26cd73e4de458e79606384f4f68e7", size = 555300, upload-time = "2025-08-27T12:14:11.783Z" }, - { url = 
"https://files.pythonhosted.org/packages/e5/ee/375469849e6b429b3516206b4580a79e9ef3eb12920ddbd4492b56eaacbe/rpds_py-0.27.1-cp313-cp313t-win32.whl", hash = "sha256:3020724ade63fe320a972e2ffd93b5623227e684315adce194941167fee02688", size = 216714, upload-time = "2025-08-27T12:14:13.629Z" }, - { url = "https://files.pythonhosted.org/packages/21/87/3fc94e47c9bd0742660e84706c311a860dcae4374cf4a03c477e23ce605a/rpds_py-0.27.1-cp313-cp313t-win_amd64.whl", hash = "sha256:8ee50c3e41739886606388ba3ab3ee2aae9f35fb23f833091833255a31740797", size = 228943, upload-time = "2025-08-27T12:14:14.937Z" }, +version = "0.28.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/48/dc/95f074d43452b3ef5d06276696ece4b3b5d696e7c9ad7173c54b1390cd70/rpds_py-0.28.0.tar.gz", hash = "sha256:abd4df20485a0983e2ca334a216249b6186d6e3c1627e106651943dbdb791aea", size = 27419, upload-time = "2025-10-22T22:24:29.327Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d3/03/ce566d92611dfac0085c2f4b048cd53ed7c274a5c05974b882a908d540a2/rpds_py-0.28.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e9e184408a0297086f880556b6168fa927d677716f83d3472ea333b42171ee3b", size = 366235, upload-time = "2025-10-22T22:22:28.397Z" }, + { url = "https://files.pythonhosted.org/packages/00/34/1c61da1b25592b86fd285bd7bd8422f4c9d748a7373b46126f9ae792a004/rpds_py-0.28.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:edd267266a9b0448f33dc465a97cfc5d467594b600fe28e7fa2f36450e03053a", size = 348241, upload-time = "2025-10-22T22:22:30.171Z" }, + { url = "https://files.pythonhosted.org/packages/fc/00/ed1e28616848c61c493a067779633ebf4b569eccaacf9ccbdc0e7cba2b9d/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85beb8b3f45e4e32f6802fb6cd6b17f615ef6c6a52f265371fb916fae02814aa", size = 378079, upload-time = "2025-10-22T22:22:31.644Z" }, + { url = 
"https://files.pythonhosted.org/packages/11/b2/ccb30333a16a470091b6e50289adb4d3ec656fd9951ba8c5e3aaa0746a67/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d2412be8d00a1b895f8ad827cc2116455196e20ed994bb704bf138fe91a42724", size = 393151, upload-time = "2025-10-22T22:22:33.453Z" }, + { url = "https://files.pythonhosted.org/packages/8c/d0/73e2217c3ee486d555cb84920597480627d8c0240ff3062005c6cc47773e/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cf128350d384b777da0e68796afdcebc2e9f63f0e9f242217754e647f6d32491", size = 517520, upload-time = "2025-10-22T22:22:34.949Z" }, + { url = "https://files.pythonhosted.org/packages/c4/91/23efe81c700427d0841a4ae7ea23e305654381831e6029499fe80be8a071/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a2036d09b363aa36695d1cc1a97b36865597f4478470b0697b5ee9403f4fe399", size = 408699, upload-time = "2025-10-22T22:22:36.584Z" }, + { url = "https://files.pythonhosted.org/packages/ca/ee/a324d3198da151820a326c1f988caaa4f37fc27955148a76fff7a2d787a9/rpds_py-0.28.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8e1e9be4fa6305a16be628959188e4fd5cd6f1b0e724d63c6d8b2a8adf74ea6", size = 385720, upload-time = "2025-10-22T22:22:38.014Z" }, + { url = "https://files.pythonhosted.org/packages/19/ad/e68120dc05af8b7cab4a789fccd8cdcf0fe7e6581461038cc5c164cd97d2/rpds_py-0.28.0-cp313-cp313-manylinux_2_31_riscv64.whl", hash = "sha256:0a403460c9dd91a7f23fc3188de6d8977f1d9603a351d5db6cf20aaea95b538d", size = 401096, upload-time = "2025-10-22T22:22:39.869Z" }, + { url = "https://files.pythonhosted.org/packages/99/90/c1e070620042459d60df6356b666bb1f62198a89d68881816a7ed121595a/rpds_py-0.28.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d7366b6553cdc805abcc512b849a519167db8f5e5c3472010cd1228b224265cb", size = 411465, upload-time = "2025-10-22T22:22:41.395Z" }, + { url = 
"https://files.pythonhosted.org/packages/68/61/7c195b30d57f1b8d5970f600efee72a4fad79ec829057972e13a0370fd24/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:5b43c6a3726efd50f18d8120ec0551241c38785b68952d240c45ea553912ac41", size = 558832, upload-time = "2025-10-22T22:22:42.871Z" }, + { url = "https://files.pythonhosted.org/packages/b0/3d/06f3a718864773f69941d4deccdf18e5e47dd298b4628062f004c10f3b34/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:0cb7203c7bc69d7c1585ebb33a2e6074492d2fc21ad28a7b9d40457ac2a51ab7", size = 583230, upload-time = "2025-10-22T22:22:44.877Z" }, + { url = "https://files.pythonhosted.org/packages/66/df/62fc783781a121e77fee9a21ead0a926f1b652280a33f5956a5e7833ed30/rpds_py-0.28.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7a52a5169c664dfb495882adc75c304ae1d50df552fbd68e100fdc719dee4ff9", size = 553268, upload-time = "2025-10-22T22:22:46.441Z" }, + { url = "https://files.pythonhosted.org/packages/84/85/d34366e335140a4837902d3dea89b51f087bd6a63c993ebdff59e93ee61d/rpds_py-0.28.0-cp313-cp313-win32.whl", hash = "sha256:2e42456917b6687215b3e606ab46aa6bca040c77af7df9a08a6dcfe8a4d10ca5", size = 217100, upload-time = "2025-10-22T22:22:48.342Z" }, + { url = "https://files.pythonhosted.org/packages/3c/1c/f25a3f3752ad7601476e3eff395fe075e0f7813fbb9862bd67c82440e880/rpds_py-0.28.0-cp313-cp313-win_amd64.whl", hash = "sha256:e0a0311caedc8069d68fc2bf4c9019b58a2d5ce3cd7cb656c845f1615b577e1e", size = 227759, upload-time = "2025-10-22T22:22:50.219Z" }, + { url = "https://files.pythonhosted.org/packages/e0/d6/5f39b42b99615b5bc2f36ab90423ea404830bdfee1c706820943e9a645eb/rpds_py-0.28.0-cp313-cp313-win_arm64.whl", hash = "sha256:04c1b207ab8b581108801528d59ad80aa83bb170b35b0ddffb29c20e411acdc1", size = 217326, upload-time = "2025-10-22T22:22:51.647Z" }, + { url = 
"https://files.pythonhosted.org/packages/5c/8b/0c69b72d1cee20a63db534be0df271effe715ef6c744fdf1ff23bb2b0b1c/rpds_py-0.28.0-cp313-cp313t-macosx_10_12_x86_64.whl", hash = "sha256:f296ea3054e11fc58ad42e850e8b75c62d9a93a9f981ad04b2e5ae7d2186ff9c", size = 355736, upload-time = "2025-10-22T22:22:53.211Z" }, + { url = "https://files.pythonhosted.org/packages/f7/6d/0c2ee773cfb55c31a8514d2cece856dd299170a49babd50dcffb15ddc749/rpds_py-0.28.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:5a7306c19b19005ad98468fcefeb7100b19c79fc23a5f24a12e06d91181193fa", size = 342677, upload-time = "2025-10-22T22:22:54.723Z" }, + { url = "https://files.pythonhosted.org/packages/e2/1c/22513ab25a27ea205144414724743e305e8153e6abe81833b5e678650f5a/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5d9b86aa501fed9862a443c5c3116f6ead8bc9296185f369277c42542bd646b", size = 371847, upload-time = "2025-10-22T22:22:56.295Z" }, + { url = "https://files.pythonhosted.org/packages/60/07/68e6ccdb4b05115ffe61d31afc94adef1833d3a72f76c9632d4d90d67954/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e5bbc701eff140ba0e872691d573b3d5d30059ea26e5785acba9132d10c8c31d", size = 381800, upload-time = "2025-10-22T22:22:57.808Z" }, + { url = "https://files.pythonhosted.org/packages/73/bf/6d6d15df80781d7f9f368e7c1a00caf764436518c4877fb28b029c4624af/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5690671cd672a45aa8616d7374fdf334a1b9c04a0cac3c854b1136e92374fe", size = 518827, upload-time = "2025-10-22T22:22:59.826Z" }, + { url = "https://files.pythonhosted.org/packages/7b/d3/2decbb2976cc452cbf12a2b0aaac5f1b9dc5dd9d1f7e2509a3ee00421249/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9f1d92ecea4fa12f978a367c32a5375a1982834649cdb96539dcdc12e609ab1a", size = 399471, upload-time = "2025-10-22T22:23:01.968Z" }, + { url = 
"https://files.pythonhosted.org/packages/b1/2c/f30892f9e54bd02e5faca3f6a26d6933c51055e67d54818af90abed9748e/rpds_py-0.28.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d252db6b1a78d0a3928b6190156042d54c93660ce4d98290d7b16b5296fb7cc", size = 377578, upload-time = "2025-10-22T22:23:03.52Z" }, + { url = "https://files.pythonhosted.org/packages/f0/5d/3bce97e5534157318f29ac06bf2d279dae2674ec12f7cb9c12739cee64d8/rpds_py-0.28.0-cp313-cp313t-manylinux_2_31_riscv64.whl", hash = "sha256:d61b355c3275acb825f8777d6c4505f42b5007e357af500939d4a35b19177259", size = 390482, upload-time = "2025-10-22T22:23:05.391Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f0/886bd515ed457b5bd93b166175edb80a0b21a210c10e993392127f1e3931/rpds_py-0.28.0-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:acbe5e8b1026c0c580d0321c8aae4b0a1e1676861d48d6e8c6586625055b606a", size = 402447, upload-time = "2025-10-22T22:23:06.93Z" }, + { url = "https://files.pythonhosted.org/packages/42/b5/71e8777ac55e6af1f4f1c05b47542a1eaa6c33c1cf0d300dca6a1c6e159a/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:8aa23b6f0fc59b85b4c7d89ba2965af274346f738e8d9fc2455763602e62fd5f", size = 552385, upload-time = "2025-10-22T22:23:08.557Z" }, + { url = "https://files.pythonhosted.org/packages/5d/cb/6ca2d70cbda5a8e36605e7788c4aa3bea7c17d71d213465a5a675079b98d/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7b14b0c680286958817c22d76fcbca4800ddacef6f678f3a7c79a1fe7067fe37", size = 575642, upload-time = "2025-10-22T22:23:10.348Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d4/407ad9960ca7856d7b25c96dcbe019270b5ffdd83a561787bc682c797086/rpds_py-0.28.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:bcf1d210dfee61a6c86551d67ee1031899c0fdbae88b2d44a569995d43797712", size = 544507, upload-time = "2025-10-22T22:23:12.434Z" }, + { url = 
"https://files.pythonhosted.org/packages/51/31/2f46fe0efcac23fbf5797c6b6b7e1c76f7d60773e525cb65fcbc582ee0f2/rpds_py-0.28.0-cp313-cp313t-win32.whl", hash = "sha256:3aa4dc0fdab4a7029ac63959a3ccf4ed605fee048ba67ce89ca3168da34a1342", size = 205376, upload-time = "2025-10-22T22:23:13.979Z" }, + { url = "https://files.pythonhosted.org/packages/92/e4/15947bda33cbedfc134490a41841ab8870a72a867a03d4969d886f6594a2/rpds_py-0.28.0-cp313-cp313t-win_amd64.whl", hash = "sha256:7b7d9d83c942855e4fdcfa75d4f96f6b9e272d42fffcb72cd4bb2577db2e2907", size = 215907, upload-time = "2025-10-22T22:23:15.5Z" }, ] [[package]] name = "ruff" -version = "0.14.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9e/58/6ca66896635352812de66f71cdf9ff86b3a4f79071ca5730088c0cd0fc8d/ruff-0.14.1.tar.gz", hash = "sha256:1dd86253060c4772867c61791588627320abcb6ed1577a90ef432ee319729b69", size = 5513429, upload-time = "2025-10-16T18:05:41.766Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/39/9cc5ab181478d7a18adc1c1e051a84ee02bec94eb9bdfd35643d7c74ca31/ruff-0.14.1-py3-none-linux_armv6l.whl", hash = "sha256:083bfc1f30f4a391ae09c6f4f99d83074416b471775b59288956f5bc18e82f8b", size = 12445415, upload-time = "2025-10-16T18:04:48.227Z" }, - { url = "https://files.pythonhosted.org/packages/ef/2e/1226961855ccd697255988f5a2474890ac7c5863b080b15bd038df820818/ruff-0.14.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:f6fa757cd717f791009f7669fefb09121cc5f7d9bd0ef211371fad68c2b8b224", size = 12784267, upload-time = "2025-10-16T18:04:52.515Z" }, - { url = "https://files.pythonhosted.org/packages/c1/ea/fd9e95863124ed159cd0667ec98449ae461de94acda7101f1acb6066da00/ruff-0.14.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d6191903d39ac156921398e9c86b7354d15e3c93772e7dbf26c9fcae59ceccd5", size = 11781872, upload-time = "2025-10-16T18:04:55.396Z" }, - { url = 
"https://files.pythonhosted.org/packages/1e/5a/e890f7338ff537dba4589a5e02c51baa63020acfb7c8cbbaea4831562c96/ruff-0.14.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ed04f0e04f7a4587244e5c9d7df50e6b5bf2705d75059f409a6421c593a35896", size = 12226558, upload-time = "2025-10-16T18:04:58.166Z" }, - { url = "https://files.pythonhosted.org/packages/a6/7a/8ab5c3377f5bf31e167b73651841217542bcc7aa1c19e83030835cc25204/ruff-0.14.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5c9e6cf6cd4acae0febbce29497accd3632fe2025c0c583c8b87e8dbdeae5f61", size = 12187898, upload-time = "2025-10-16T18:05:01.455Z" }, - { url = "https://files.pythonhosted.org/packages/48/8d/ba7c33aa55406955fc124e62c8259791c3d42e3075a71710fdff9375134f/ruff-0.14.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a6fa2458527794ecdfbe45f654e42c61f2503a230545a91af839653a0a93dbc6", size = 12939168, upload-time = "2025-10-16T18:05:04.397Z" }, - { url = "https://files.pythonhosted.org/packages/b4/c2/70783f612b50f66d083380e68cbd1696739d88e9b4f6164230375532c637/ruff-0.14.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:39f1c392244e338b21d42ab29b8a6392a722c5090032eb49bb4d6defcdb34345", size = 14386942, upload-time = "2025-10-16T18:05:07.102Z" }, - { url = "https://files.pythonhosted.org/packages/48/44/cd7abb9c776b66d332119d67f96acf15830d120f5b884598a36d9d3f4d83/ruff-0.14.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7382fa12a26cce1f95070ce450946bec357727aaa428983036362579eadcc5cf", size = 13990622, upload-time = "2025-10-16T18:05:09.882Z" }, - { url = "https://files.pythonhosted.org/packages/eb/56/4259b696db12ac152fe472764b4f78bbdd9b477afd9bc3a6d53c01300b37/ruff-0.14.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dd0bf2be3ae8521e1093a487c4aa3b455882f139787770698530d28ed3fbb37c", size = 13431143, upload-time = "2025-10-16T18:05:13.46Z" }, - { url = 
"https://files.pythonhosted.org/packages/e0/35/266a80d0eb97bd224b3265b9437bd89dde0dcf4faf299db1212e81824e7e/ruff-0.14.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cabcaa9ccf8089fb4fdb78d17cc0e28241520f50f4c2e88cb6261ed083d85151", size = 13132844, upload-time = "2025-10-16T18:05:16.1Z" }, - { url = "https://files.pythonhosted.org/packages/65/6e/d31ce218acc11a8d91ef208e002a31acf315061a85132f94f3df7a252b18/ruff-0.14.1-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:747d583400f6125ec11a4c14d1c8474bf75d8b419ad22a111a537ec1a952d192", size = 13401241, upload-time = "2025-10-16T18:05:19.395Z" }, - { url = "https://files.pythonhosted.org/packages/9f/b5/dbc4221bf0b03774b3b2f0d47f39e848d30664157c15b965a14d890637d2/ruff-0.14.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5a6e74c0efd78515a1d13acbfe6c90f0f5bd822aa56b4a6d43a9ffb2ae6e56cd", size = 12132476, upload-time = "2025-10-16T18:05:22.163Z" }, - { url = "https://files.pythonhosted.org/packages/98/4b/ac99194e790ccd092d6a8b5f341f34b6e597d698e3077c032c502d75ea84/ruff-0.14.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:0ea6a864d2fb41a4b6d5b456ed164302a0d96f4daac630aeba829abfb059d020", size = 12139749, upload-time = "2025-10-16T18:05:25.162Z" }, - { url = "https://files.pythonhosted.org/packages/47/26/7df917462c3bb5004e6fdfcc505a49e90bcd8a34c54a051953118c00b53a/ruff-0.14.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:0826b8764f94229604fa255918d1cc45e583e38c21c203248b0bfc9a0e930be5", size = 12544758, upload-time = "2025-10-16T18:05:28.018Z" }, - { url = "https://files.pythonhosted.org/packages/64/d0/81e7f0648e9764ad9b51dd4be5e5dac3fcfff9602428ccbae288a39c2c22/ruff-0.14.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:cbc52160465913a1a3f424c81c62ac8096b6a491468e7d872cb9444a860bc33d", size = 13221811, upload-time = "2025-10-16T18:05:30.707Z" }, - { url = 
"https://files.pythonhosted.org/packages/c3/07/3c45562c67933cc35f6d5df4ca77dabbcd88fddaca0d6b8371693d29fd56/ruff-0.14.1-py3-none-win32.whl", hash = "sha256:e037ea374aaaff4103240ae79168c0945ae3d5ae8db190603de3b4012bd1def6", size = 12319467, upload-time = "2025-10-16T18:05:33.261Z" }, - { url = "https://files.pythonhosted.org/packages/02/88/0ee4ca507d4aa05f67e292d2e5eb0b3e358fbcfe527554a2eda9ac422d6b/ruff-0.14.1-py3-none-win_amd64.whl", hash = "sha256:59d599cdff9c7f925a017f6f2c256c908b094e55967f93f2821b1439928746a1", size = 13401123, upload-time = "2025-10-16T18:05:35.984Z" }, - { url = "https://files.pythonhosted.org/packages/b8/81/4b6387be7014858d924b843530e1b2a8e531846807516e9bea2ee0936bf7/ruff-0.14.1-py3-none-win_arm64.whl", hash = "sha256:e3b443c4c9f16ae850906b8d0a707b2a4c16f8d2f0a7fe65c475c5886665ce44", size = 12436636, upload-time = "2025-10-16T18:05:38.995Z" }, +version = "0.14.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/34/8218a19b2055b80601e8fd201ec723c74c7fe1ca06d525a43ed07b6d8e85/ruff-0.14.2.tar.gz", hash = "sha256:98da787668f239313d9c902ca7c523fe11b8ec3f39345553a51b25abc4629c96", size = 5539663, upload-time = "2025-10-23T19:37:00.956Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/dd/23eb2db5ad9acae7c845700493b72d3ae214dce0b226f27df89216110f2b/ruff-0.14.2-py3-none-linux_armv6l.whl", hash = "sha256:7cbe4e593505bdec5884c2d0a4d791a90301bc23e49a6b1eb642dd85ef9c64f1", size = 12533390, upload-time = "2025-10-23T19:36:18.044Z" }, + { url = "https://files.pythonhosted.org/packages/5a/8c/5f9acff43ddcf3f85130d0146d0477e28ccecc495f9f684f8f7119b74c0d/ruff-0.14.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8d54b561729cee92f8d89c316ad7a3f9705533f5903b042399b6ae0ddfc62e11", size = 12887187, upload-time = "2025-10-23T19:36:22.664Z" }, + { url = 
"https://files.pythonhosted.org/packages/99/fa/047646491479074029665022e9f3dc6f0515797f40a4b6014ea8474c539d/ruff-0.14.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5c8753dfa44ebb2cde10ce5b4d2ef55a41fb9d9b16732a2c5df64620dbda44a3", size = 11925177, upload-time = "2025-10-23T19:36:24.778Z" }, + { url = "https://files.pythonhosted.org/packages/15/8b/c44cf7fe6e59ab24a9d939493a11030b503bdc2a16622cede8b7b1df0114/ruff-0.14.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d0bbeffb8d9f4fccf7b5198d566d0bad99a9cb622f1fc3467af96cb8773c9e3", size = 12358285, upload-time = "2025-10-23T19:36:26.979Z" }, + { url = "https://files.pythonhosted.org/packages/45/01/47701b26254267ef40369aea3acb62a7b23e921c27372d127e0f3af48092/ruff-0.14.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7047f0c5a713a401e43a88d36843d9c83a19c584e63d664474675620aaa634a8", size = 12303832, upload-time = "2025-10-23T19:36:29.192Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5c/ae7244ca4fbdf2bee9d6405dcd5bc6ae51ee1df66eb7a9884b77b8af856d/ruff-0.14.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3bf8d2f9aa1602599217d82e8e0af7fd33e5878c4d98f37906b7c93f46f9a839", size = 13036995, upload-time = "2025-10-23T19:36:31.861Z" }, + { url = "https://files.pythonhosted.org/packages/27/4c/0860a79ce6fd4c709ac01173f76f929d53f59748d0dcdd662519835dae43/ruff-0.14.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1c505b389e19c57a317cf4b42db824e2fca96ffb3d86766c1c9f8b96d32048a7", size = 14512649, upload-time = "2025-10-23T19:36:33.915Z" }, + { url = "https://files.pythonhosted.org/packages/7f/7f/d365de998069720a3abfc250ddd876fc4b81a403a766c74ff9bde15b5378/ruff-0.14.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a307fc45ebd887b3f26b36d9326bb70bf69b01561950cdcc6c0bdf7bb8e0f7cc", size = 14088182, upload-time = "2025-10-23T19:36:36.983Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/ea/d8e3e6b209162000a7be1faa41b0a0c16a133010311edc3329753cc6596a/ruff-0.14.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:61ae91a32c853172f832c2f40bd05fd69f491db7289fb85a9b941ebdd549781a", size = 13599516, upload-time = "2025-10-23T19:36:39.208Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ea/c7810322086db68989fb20a8d5221dd3b79e49e396b01badca07b433ab45/ruff-0.14.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1967e40286f63ee23c615e8e7e98098dedc7301568bd88991f6e544d8ae096", size = 13272690, upload-time = "2025-10-23T19:36:41.453Z" }, + { url = "https://files.pythonhosted.org/packages/a9/39/10b05acf8c45786ef501d454e00937e1b97964f846bf28883d1f9619928a/ruff-0.14.2-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:2877f02119cdebf52a632d743a2e302dea422bfae152ebe2f193d3285a3a65df", size = 13496497, upload-time = "2025-10-23T19:36:43.61Z" }, + { url = "https://files.pythonhosted.org/packages/59/a1/1f25f8301e13751c30895092485fada29076e5e14264bdacc37202e85d24/ruff-0.14.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e681c5bc777de5af898decdcb6ba3321d0d466f4cb43c3e7cc2c3b4e7b843a05", size = 12266116, upload-time = "2025-10-23T19:36:45.625Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fa/0029bfc9ce16ae78164e6923ef392e5f173b793b26cc39aa1d8b366cf9dc/ruff-0.14.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:e21be42d72e224736f0c992cdb9959a2fa53c7e943b97ef5d081e13170e3ffc5", size = 12281345, upload-time = "2025-10-23T19:36:47.618Z" }, + { url = "https://files.pythonhosted.org/packages/a5/ab/ece7baa3c0f29b7683be868c024f0838770c16607bea6852e46b202f1ff6/ruff-0.14.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:b8264016f6f209fac16262882dbebf3f8be1629777cf0f37e7aff071b3e9b92e", size = 12629296, upload-time = "2025-10-23T19:36:49.789Z" }, + { url = 
"https://files.pythonhosted.org/packages/a4/7f/638f54b43f3d4e48c6a68062794e5b367ddac778051806b9e235dfb7aa81/ruff-0.14.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5ca36b4cb4db3067a3b24444463ceea5565ea78b95fe9a07ca7cb7fd16948770", size = 13371610, upload-time = "2025-10-23T19:36:51.882Z" }, + { url = "https://files.pythonhosted.org/packages/8d/35/3654a973ebe5b32e1fd4a08ed2d46755af7267da7ac710d97420d7b8657d/ruff-0.14.2-py3-none-win32.whl", hash = "sha256:41775927d287685e08f48d8eb3f765625ab0b7042cc9377e20e64f4eb0056ee9", size = 12415318, upload-time = "2025-10-23T19:36:53.961Z" }, + { url = "https://files.pythonhosted.org/packages/71/30/3758bcf9e0b6a4193a6f51abf84254aba00887dfa8c20aba18aa366c5f57/ruff-0.14.2-py3-none-win_amd64.whl", hash = "sha256:0df3424aa5c3c08b34ed8ce099df1021e3adaca6e90229273496b839e5a7e1af", size = 13565279, upload-time = "2025-10-23T19:36:56.578Z" }, + { url = "https://files.pythonhosted.org/packages/2e/5d/aa883766f8ef9ffbe6aa24f7192fb71632f31a30e77eb39aa2b0dc4290ac/ruff-0.14.2-py3-none-win_arm64.whl", hash = "sha256:ea9d635e83ba21569fbacda7e78afbfeb94911c9434aff06192d9bc23fd5495a", size = 12554956, upload-time = "2025-10-23T19:36:58.714Z" }, ] [[package]] @@ -1787,17 +1956,72 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/c3/c0be1135726618dc1e28d181b8c442403d8dbb9e273fd791de2d4384bcdd/safetensors-0.6.2-cp38-abi3-win_amd64.whl", hash = "sha256:c7b214870df923cbc1593c3faee16bec59ea462758699bd3fee399d00aac072c", size = 320192, upload-time = "2025-08-08T13:13:59.467Z" }, ] +[[package]] +name = "scikit-learn" +version = "1.7.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "joblib" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "threadpoolctl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/98/c2/a7855e41c9d285dfe86dc50b250978105dce513d6e459ea66a6aeb0e1e0c/scikit_learn-1.7.2.tar.gz", hash = "sha256:20e9e49ecd130598f1ca38a1d85090e1a600147b9c02fa6f15d69cb53d968fda", 
size = 7193136, upload-time = "2025-09-09T08:21:29.075Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/93/a3038cb0293037fd335f77f31fe053b89c72f17b1c8908c576c29d953e84/scikit_learn-1.7.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0b7dacaa05e5d76759fb071558a8b5130f4845166d88654a0f9bdf3eb57851b7", size = 9212382, upload-time = "2025-09-09T08:20:54.731Z" }, + { url = "https://files.pythonhosted.org/packages/40/dd/9a88879b0c1104259136146e4742026b52df8540c39fec21a6383f8292c7/scikit_learn-1.7.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:abebbd61ad9e1deed54cca45caea8ad5f79e1b93173dece40bb8e0c658dbe6fe", size = 8592042, upload-time = "2025-09-09T08:20:57.313Z" }, + { url = "https://files.pythonhosted.org/packages/46/af/c5e286471b7d10871b811b72ae794ac5fe2989c0a2df07f0ec723030f5f5/scikit_learn-1.7.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:502c18e39849c0ea1a5d681af1dbcf15f6cce601aebb657aabbfe84133c1907f", size = 9434180, upload-time = "2025-09-09T08:20:59.671Z" }, + { url = "https://files.pythonhosted.org/packages/f1/fd/df59faa53312d585023b2da27e866524ffb8faf87a68516c23896c718320/scikit_learn-1.7.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7a4c328a71785382fe3fe676a9ecf2c86189249beff90bf85e22bdb7efaf9ae0", size = 9283660, upload-time = "2025-09-09T08:21:01.71Z" }, + { url = "https://files.pythonhosted.org/packages/a7/c7/03000262759d7b6f38c836ff9d512f438a70d8a8ddae68ee80de72dcfb63/scikit_learn-1.7.2-cp313-cp313-win_amd64.whl", hash = "sha256:63a9afd6f7b229aad94618c01c252ce9e6fa97918c5ca19c9a17a087d819440c", size = 8702057, upload-time = "2025-09-09T08:21:04.234Z" }, + { url = "https://files.pythonhosted.org/packages/55/87/ef5eb1f267084532c8e4aef98a28b6ffe7425acbfd64b5e2f2e066bc29b3/scikit_learn-1.7.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9acb6c5e867447b4e1390930e3944a005e2cb115922e693c08a323421a6966e8", size = 9558731, upload-time = 
"2025-09-09T08:21:06.381Z" }, + { url = "https://files.pythonhosted.org/packages/93/f8/6c1e3fc14b10118068d7938878a9f3f4e6d7b74a8ddb1e5bed65159ccda8/scikit_learn-1.7.2-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:2a41e2a0ef45063e654152ec9d8bcfc39f7afce35b08902bfe290c2498a67a6a", size = 9038852, upload-time = "2025-09-09T08:21:08.628Z" }, + { url = "https://files.pythonhosted.org/packages/83/87/066cafc896ee540c34becf95d30375fe5cbe93c3b75a0ee9aa852cd60021/scikit_learn-1.7.2-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:98335fb98509b73385b3ab2bd0639b1f610541d3988ee675c670371d6a87aa7c", size = 9527094, upload-time = "2025-09-09T08:21:11.486Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2b/4903e1ccafa1f6453b1ab78413938c8800633988c838aa0be386cbb33072/scikit_learn-1.7.2-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:191e5550980d45449126e23ed1d5e9e24b2c68329ee1f691a3987476e115e09c", size = 9367436, upload-time = "2025-09-09T08:21:13.602Z" }, + { url = "https://files.pythonhosted.org/packages/b5/aa/8444be3cfb10451617ff9d177b3c190288f4563e6c50ff02728be67ad094/scikit_learn-1.7.2-cp313-cp313t-win_amd64.whl", hash = "sha256:57dc4deb1d3762c75d685507fbd0bc17160144b2f2ba4ccea5dc285ab0d0e973", size = 9275749, upload-time = "2025-09-09T08:21:15.96Z" }, +] + +[[package]] +name = "scipy" +version = "1.16.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0a/ca/d8ace4f98322d01abcd52d381134344bf7b431eba7ed8b42bdea5a3c2ac9/scipy-1.16.3.tar.gz", hash = "sha256:01e87659402762f43bd2fee13370553a17ada367d42e7487800bf2916535aecb", size = 30597883, upload-time = "2025-10-28T17:38:54.068Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/f1/57e8327ab1508272029e27eeef34f2302ffc156b69e7e233e906c2a5c379/scipy-1.16.3-cp313-cp313-macosx_10_14_x86_64.whl", hash = 
"sha256:d2ec56337675e61b312179a1ad124f5f570c00f920cc75e1000025451b88241c", size = 36617856, upload-time = "2025-10-28T17:33:31.375Z" }, + { url = "https://files.pythonhosted.org/packages/44/13/7e63cfba8a7452eb756306aa2fd9b37a29a323b672b964b4fdeded9a3f21/scipy-1.16.3-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:16b8bc35a4cc24db80a0ec836a9286d0e31b2503cb2fd7ff7fb0e0374a97081d", size = 28874306, upload-time = "2025-10-28T17:33:36.516Z" }, + { url = "https://files.pythonhosted.org/packages/15/65/3a9400efd0228a176e6ec3454b1fa998fbbb5a8defa1672c3f65706987db/scipy-1.16.3-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:5803c5fadd29de0cf27fa08ccbfe7a9e5d741bf63e4ab1085437266f12460ff9", size = 20865371, upload-time = "2025-10-28T17:33:42.094Z" }, + { url = "https://files.pythonhosted.org/packages/33/d7/eda09adf009a9fb81827194d4dd02d2e4bc752cef16737cc4ef065234031/scipy-1.16.3-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:b81c27fc41954319a943d43b20e07c40bdcd3ff7cf013f4fb86286faefe546c4", size = 23524877, upload-time = "2025-10-28T17:33:48.483Z" }, + { url = "https://files.pythonhosted.org/packages/7d/6b/3f911e1ebc364cb81320223a3422aab7d26c9c7973109a9cd0f27c64c6c0/scipy-1.16.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0c3b4dd3d9b08dbce0f3440032c52e9e2ab9f96ade2d3943313dfe51a7056959", size = 33342103, upload-time = "2025-10-28T17:33:56.495Z" }, + { url = "https://files.pythonhosted.org/packages/21/f6/4bfb5695d8941e5c570a04d9fcd0d36bce7511b7d78e6e75c8f9791f82d0/scipy-1.16.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7dc1360c06535ea6116a2220f760ae572db9f661aba2d88074fe30ec2aa1ff88", size = 35697297, upload-time = "2025-10-28T17:34:04.722Z" }, + { url = "https://files.pythonhosted.org/packages/04/e1/6496dadbc80d8d896ff72511ecfe2316b50313bfc3ebf07a3f580f08bd8c/scipy-1.16.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:663b8d66a8748051c3ee9c96465fb417509315b99c71550fda2591d7dd634234", size = 
36021756, upload-time = "2025-10-28T17:34:13.482Z" }, + { url = "https://files.pythonhosted.org/packages/fe/bd/a8c7799e0136b987bda3e1b23d155bcb31aec68a4a472554df5f0937eef7/scipy-1.16.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:eab43fae33a0c39006a88096cd7b4f4ef545ea0447d250d5ac18202d40b6611d", size = 38696566, upload-time = "2025-10-28T17:34:22.384Z" }, + { url = "https://files.pythonhosted.org/packages/cd/01/1204382461fcbfeb05b6161b594f4007e78b6eba9b375382f79153172b4d/scipy-1.16.3-cp313-cp313-win_amd64.whl", hash = "sha256:062246acacbe9f8210de8e751b16fc37458213f124bef161a5a02c7a39284304", size = 38529877, upload-time = "2025-10-28T17:35:51.076Z" }, + { url = "https://files.pythonhosted.org/packages/7f/14/9d9fbcaa1260a94f4bb5b64ba9213ceb5d03cd88841fe9fd1ffd47a45b73/scipy-1.16.3-cp313-cp313-win_arm64.whl", hash = "sha256:50a3dbf286dbc7d84f176f9a1574c705f277cb6565069f88f60db9eafdbe3ee2", size = 25455366, upload-time = "2025-10-28T17:35:59.014Z" }, + { url = "https://files.pythonhosted.org/packages/e2/a3/9ec205bd49f42d45d77f1730dbad9ccf146244c1647605cf834b3a8c4f36/scipy-1.16.3-cp313-cp313t-macosx_10_14_x86_64.whl", hash = "sha256:fb4b29f4cf8cc5a8d628bc8d8e26d12d7278cd1f219f22698a378c3d67db5e4b", size = 37027931, upload-time = "2025-10-28T17:34:31.451Z" }, + { url = "https://files.pythonhosted.org/packages/25/06/ca9fd1f3a4589cbd825b1447e5db3a8ebb969c1eaf22c8579bd286f51b6d/scipy-1.16.3-cp313-cp313t-macosx_12_0_arm64.whl", hash = "sha256:8d09d72dc92742988b0e7750bddb8060b0c7079606c0d24a8cc8e9c9c11f9079", size = 29400081, upload-time = "2025-10-28T17:34:39.087Z" }, + { url = "https://files.pythonhosted.org/packages/6a/56/933e68210d92657d93fb0e381683bc0e53a965048d7358ff5fbf9e6a1b17/scipy-1.16.3-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:03192a35e661470197556de24e7cb1330d84b35b94ead65c46ad6f16f6b28f2a", size = 21391244, upload-time = "2025-10-28T17:34:45.234Z" }, + { url = 
"https://files.pythonhosted.org/packages/a8/7e/779845db03dc1418e215726329674b40576879b91814568757ff0014ad65/scipy-1.16.3-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:57d01cb6f85e34f0946b33caa66e892aae072b64b034183f3d87c4025802a119", size = 23929753, upload-time = "2025-10-28T17:34:51.793Z" }, + { url = "https://files.pythonhosted.org/packages/4c/4b/f756cf8161d5365dcdef9e5f460ab226c068211030a175d2fc7f3f41ca64/scipy-1.16.3-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:96491a6a54e995f00a28a3c3badfff58fd093bf26cd5fb34a2188c8c756a3a2c", size = 33496912, upload-time = "2025-10-28T17:34:59.8Z" }, + { url = "https://files.pythonhosted.org/packages/09/b5/222b1e49a58668f23839ca1542a6322bb095ab8d6590d4f71723869a6c2c/scipy-1.16.3-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:cd13e354df9938598af2be05822c323e97132d5e6306b83a3b4ee6724c6e522e", size = 35802371, upload-time = "2025-10-28T17:35:08.173Z" }, + { url = "https://files.pythonhosted.org/packages/c1/8d/5964ef68bb31829bde27611f8c9deeac13764589fe74a75390242b64ca44/scipy-1.16.3-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:63d3cdacb8a824a295191a723ee5e4ea7768ca5ca5f2838532d9f2e2b3ce2135", size = 36190477, upload-time = "2025-10-28T17:35:16.7Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f2/b31d75cb9b5fa4dd39a0a931ee9b33e7f6f36f23be5ef560bf72e0f92f32/scipy-1.16.3-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e7efa2681ea410b10dde31a52b18b0154d66f2485328830e45fdf183af5aefc6", size = 38796678, upload-time = "2025-10-28T17:35:26.354Z" }, + { url = "https://files.pythonhosted.org/packages/b4/1e/b3723d8ff64ab548c38d87055483714fefe6ee20e0189b62352b5e015bb1/scipy-1.16.3-cp313-cp313t-win_amd64.whl", hash = "sha256:2d1ae2cf0c350e7705168ff2429962a89ad90c2d49d1dd300686d8b2a5af22fc", size = 38640178, upload-time = "2025-10-28T17:35:35.304Z" }, + { url = 
"https://files.pythonhosted.org/packages/8e/f3/d854ff38789aca9b0cc23008d607ced9de4f7ab14fa1ca4329f86b3758ca/scipy-1.16.3-cp313-cp313t-win_arm64.whl", hash = "sha256:0c623a54f7b79dd88ef56da19bc2873afec9673a48f3b85b18e4d402bdd29a5a", size = 25803246, upload-time = "2025-10-28T17:35:42.155Z" }, +] + [[package]] name = "sentry-sdk" -version = "2.42.0" +version = "2.42.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/b2/7481156cf42b7f66cffb371e504b7ace12b4f016b8872ffcf0873ae9534b/sentry_sdk-2.42.0.tar.gz", hash = "sha256:91c69c9372fb5fb4df0ac39456ccf7286f0428b3ee1cdd389f9dd36c04e0f5c9", size = 351242, upload-time = "2025-10-15T07:41:15.577Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/04/ec8c1dd9250847303d98516e917978cb1c7083024770d86d657d2ccb5a70/sentry_sdk-2.42.1.tar.gz", hash = "sha256:8598cc6edcfe74cb8074ba6a7c15338cdee93d63d3eb9b9943b4b568354ad5b6", size = 354839, upload-time = "2025-10-20T12:38:40.45Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/4a/9810a246ec5d1df2ae066efefeecfa91d3c548fa2bd5390184e016112887/sentry_sdk-2.42.0-py2.py3-none-any.whl", hash = "sha256:1a7986e638306ff158f52dd47d9480a4055e6c289388caa90628acb2563fe7bd", size = 379496, upload-time = "2025-10-15T07:41:13.802Z" }, + { url = "https://files.pythonhosted.org/packages/0f/cb/c21b96ff379923310b4fb2c06e8d560d801e24aeb300faa72a04776868fc/sentry_sdk-2.42.1-py2.py3-none-any.whl", hash = "sha256:f8716b50c927d3beb41bc88439dc6bcd872237b596df5b14613e2ade104aee02", size = 380952, upload-time = "2025-10-20T12:38:38.88Z" }, ] [[package]] @@ -1856,6 +2080,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = 
"2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "soupsieve" +version = "2.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6d/e6/21ccce3262dd4889aa3332e5a119a3491a95e8f60939870a3a035aabac0d/soupsieve-2.8.tar.gz", hash = "sha256:e2dd4a40a628cb5f28f6d4b0db8800b8f581b65bb380b97de22ba5ca8d72572f", size = 103472, upload-time = "2025-08-27T15:39:51.78Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl", hash = "sha256:0cc76456a30e20f5d7f2e14a98a4ae2ee4e5abdc7c5ea0aafe795f344bc7984c", size = 36679, upload-time = "2025-08-27T15:39:50.179Z" }, +] + [[package]] name = "spd" version = "0.0.1" @@ -1868,9 +2101,12 @@ dependencies = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, + { name = "muutils" }, { name = "numpy" }, { name = "pydantic" }, { name = "python-dotenv" }, + { name = "scikit-learn" }, + { name = "scipy" }, { name = "simple-stories-train" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -1887,6 +2123,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "basedpyright" }, + { name = "nbconvert" }, { name = "pre-commit" }, { name = "pytest" }, { name = "pytest-cov" }, @@ -1903,9 +2140,12 @@ requires-dist = [ { name = "ipykernel" }, { name = "jaxtyping" }, { name = "matplotlib" }, + { name = "muutils" }, { name = "numpy" }, { name = "pydantic", specifier = "<2.12" }, { name = "python-dotenv" }, + { name = "scikit-learn" }, + { name = "scipy", specifier = ">=1.14.1" }, { name = "simple-stories-train", git = "https://github.com/goodfire-ai/simple_stories_train.git?rev=dev" }, { name = "streamlit" }, { name = "streamlit-antd-components" }, @@ -1922,6 +2162,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "basedpyright", specifier = "<1.32.0" }, + { name = "nbconvert" }, { name = "pre-commit" }, { name = 
"pytest" }, { name = "pytest-cov" }, @@ -1945,14 +2186,14 @@ wheels = [ [[package]] name = "starlette" -version = "0.48.0" +version = "0.49.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a7/a5/d6f429d43394057b67a6b5bbe6eae2f77a6bf7459d961fdb224bf206eee6/starlette-0.48.0.tar.gz", hash = "sha256:7e8cee469a8ab2352911528110ce9088fdc6a37d9876926e73da7ce4aa4c7a46", size = 2652949, upload-time = "2025-09-13T08:41:05.699Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/72/2db2f49247d0a18b4f1bb9a5a39a0162869acf235f3a96418363947b3d46/starlette-0.48.0-py3-none-any.whl", hash = "sha256:0764ca97b097582558ecb498132ed0c7d942f233f365b86ba37770e026510659", size = 73736, upload-time = "2025-09-13T08:41:03.869Z" }, + { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, ] [[package]] @@ -2018,11 +2259,20 @@ wheels = [ [[package]] name = "termcolor" -version = "3.1.0" +version = "3.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/87/56/ab275c2b56a5e2342568838f0d5e3e66a32354adcc159b495e374cda43f5/termcolor-3.2.0.tar.gz", hash = "sha256:610e6456feec42c4bcd28934a8c87a06c3fa28b01561d46aa09a9881b8622c58", size = 14423, upload-time = "2025-10-25T19:11:42.586Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f9/d5/141f53d7c1eb2a80e6d3e9a390228c3222c27705cbe7f048d3623053f3ca/termcolor-3.2.0-py3-none-any.whl", hash = "sha256:a10343879eba4da819353c55cb8049b0933890c2ebf9ad5d3ecd2bb32ea96ea6", size = 7698, upload-time = "2025-10-25T19:11:41.536Z" }, +] + +[[package]] +name = "threadpoolctl" +version = "3.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/6c/3d75c196ac07ac8749600b60b03f4f6094d54e132c4d94ebac6ee0e0add0/termcolor-3.1.0.tar.gz", hash = "sha256:6a6dd7fbee581909eeec6a756cff1d7f7c376063b14e4a298dc4980309e55970", size = 14324, upload-time = "2025-04-30T11:37:53.791Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274, upload-time = "2025-03-13T13:49:23.031Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4f/bd/de8d508070629b6d84a30d01d57e4a65c69aa7f5abe7560b8fad3b50ea59/termcolor-3.1.0-py3-none-any.whl", hash = "sha256:591dd26b5c2ce03b9e43f391264626557873ce1d379019786f99b0c2bee140aa", size = 7684, upload-time = "2025-04-30T11:37:52.382Z" }, + { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, ] [[package]] @@ -2051,6 +2301,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/93/e0/6cc82a562bc6365785a3ff0af27a2a092d57c47d7a81d9e2295d8c36f011/tiktoken-0.12.0-cp313-cp313t-win_amd64.whl", hash = "sha256:dc2dd125a62cb2b3d858484d6c614d136b5b848976794edfb63688d539b8b93f", size = 878777, upload-time = "2025-10-06T20:22:18.036Z" }, ] +[[package]] +name = "tinycss2" +version = "1.4.0" +source = { registry 
= "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085, upload-time = "2024-10-24T14:58:29.895Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610, upload-time = "2024-10-24T14:58:28.029Z" }, +] + [[package]] name = "tokenizers" version = "0.22.1" @@ -2210,7 +2472,7 @@ name = "triton" version = "3.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "setuptools", marker = "sys_platform == 'linux'" }, + { name = "setuptools" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/30/7b/0a685684ed5322d2af0bddefed7906674f67974aa88b0fae6e82e3b766f6/triton-3.4.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:00be2964616f4c619193cb0d1b29a99bd4b001d7dc333816073f92cf2a8ccdeb", size = 155569223, upload-time = "2025-07-30T19:58:44.017Z" }, @@ -2271,16 +2533,16 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.35.3" +version = "20.35.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, { name = "filelock" }, { name = "platformdirs" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/d5/b0ccd381d55c8f45d46f77df6ae59fbc23d19e901e2d523395598e5f4c93/virtualenv-20.35.3.tar.gz", hash = "sha256:4f1a845d131133bdff10590489610c98c168ff99dc75d6c96853801f7f67af44", size = 6002907, upload-time = "2025-10-10T21:23:33.178Z" } +sdist = { url = "https://files.pythonhosted.org/packages/20/28/e6f1a6f655d620846bd9df527390ecc26b3805a0c5989048c210e22c5ca9/virtualenv-20.35.4.tar.gz", hash = 
"sha256:643d3914d73d3eeb0c552cbb12d7e82adf0e504dbf86a3182f8771a153a1971c", size = 6028799, upload-time = "2025-10-29T06:57:40.511Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/27/73/d9a94da0e9d470a543c1b9d3ccbceb0f59455983088e727b8a1824ed90fb/virtualenv-20.35.3-py3-none-any.whl", hash = "sha256:63d106565078d8c8d0b206d48080f938a8b25361e19432d2c9db40d2899c810a", size = 5981061, upload-time = "2025-10-10T21:23:30.433Z" }, + { url = "https://files.pythonhosted.org/packages/79/0c/c05523fa3181fdf0c9c52a6ba91a23fbf3246cc095f26f6516f9c60e6771/virtualenv-20.35.4-py3-none-any.whl", hash = "sha256:c21c9cede36c9753eeade68ba7d523529f228a403463376cf821eaae2b650f1b", size = 6005095, upload-time = "2025-10-29T06:57:37.598Z" }, ] [[package]] @@ -2294,7 +2556,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.22.2" +version = "0.22.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -2308,17 +2570,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/a8/680bd77e11a278e6c14a2cb4646e8ab9525b2baaa81c3d12dc0f616aa4aa/wandb-0.22.2.tar.gz", hash = "sha256:510f5a1ac30d16921c36c3b932da852f046641d4aee98a86a7f5ec03a6e95bda", size = 41401439, upload-time = "2025-10-07T19:54:21.88Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c1/d1/6b70f365ed86bd69debba8ad55dec8606fc21006e7ca703a5a091bd3b719/wandb-0.22.3.tar.gz", hash = "sha256:04468a8ab2769a46f5e384c9c4ada5da0dced005ca689a8424e4b8b5cb2a0291", size = 44337368, upload-time = "2025-10-28T23:59:10.275Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/b3/8c637fb594cfd574ce9c9f7d0ac2f2d12742eb38ec59dcbb713beae95343/wandb-0.22.2-py3-none-macosx_12_0_arm64.whl", hash = "sha256:2e29c9fa4462b5411b2cd2175ae33eff4309c91de7c426bca6bc8e7abc7e5dec", size = 18677549, upload-time = "2025-10-07T19:54:00.839Z" }, - { url = 
"https://files.pythonhosted.org/packages/d3/f3/e309a726eaebddad6b8d9a73a50891e5796962ec8a091bb6a61d31692d1e/wandb-0.22.2-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:c42d594cd7a9da4fd39ecdb0abbc081b61f304123277b2b6c4ba84283956fd21", size = 19715188, upload-time = "2025-10-07T19:54:03.805Z" }, - { url = "https://files.pythonhosted.org/packages/f9/73/fad59910215876008f4781b57d828d1b19b3677c9b46af615e7229746435/wandb-0.22.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5188d84e66d3fd584f3b3ae4d2a70e78f29403c0528e6aecaa4188a1fcf54d8", size = 18463148, upload-time = "2025-10-07T19:54:05.676Z" }, - { url = "https://files.pythonhosted.org/packages/87/11/572c1913b5b92e4c519f735adfae572b46f2d79d99ede63eec0d6a272d6e/wandb-0.22.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:88ccd484af9f21cfc127976793c3cf66cfe1acd75bd8cd650086a64e88bac4bf", size = 19908645, upload-time = "2025-10-07T19:54:07.693Z" }, - { url = "https://files.pythonhosted.org/packages/6d/0d/133aa82f5a505ba638b4fda5014cefddfe7f1f6238ef4afc0871ec61c41f/wandb-0.22.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:abf0ed175e791af64110e0a0b99ce02bbbbd1017722bc32d3bc328efb86450cd", size = 18501348, upload-time = "2025-10-07T19:54:10.234Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d5/776203be2601872f01dacc6a5b4274106ec0db7cd3bf2cdb3b741f8fc932/wandb-0.22.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:44e77c56403b90bf3473a7ca3bfc4d42c636b7c0e31a5fb9cd0382f08302f74b", size = 20001756, upload-time = "2025-10-07T19:54:12.452Z" }, - { url = "https://files.pythonhosted.org/packages/30/43/ae3fa46e20b1d9a6508dd9abe716d57205c038ed4661c5c98ace48a60eac/wandb-0.22.2-py3-none-win32.whl", hash = "sha256:44d12bd379dbe15be5ceed6bdf23803d42f648ba0dd111297b4c47a3c7be6dbd", size = 19075950, upload-time = "2025-10-07T19:54:14.892Z" }, - { url = 
"https://files.pythonhosted.org/packages/09/59/c174321e868205f7a659d1e5ec51f546e62267296d6f4179bb9119294964/wandb-0.22.2-py3-none-win_amd64.whl", hash = "sha256:c95eb221bf316c0872f7ac55071856b9f25f95a2de983ada48acf653ce259386", size = 19075953, upload-time = "2025-10-07T19:54:16.837Z" }, - { url = "https://files.pythonhosted.org/packages/7a/a2/c7c24fda78513cab5686949d8cb36459dbbccbbb4b2b6fc67237ece31a00/wandb-0.22.2-py3-none-win_arm64.whl", hash = "sha256:20d2ab9aa10445aab3d60914a980f002a4f66566e28b0cd156b1e462f0080a0d", size = 17383217, upload-time = "2025-10-07T19:54:19.384Z" }, + { url = "https://files.pythonhosted.org/packages/23/02/87fb60f587ec249f784a40bd91c30de1b2b24d691ee72675d5b66c3d0728/wandb-0.22.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:81b3b6e405f38342b0a080898b7d00c5b9375432f5ba358942a09e65cdcfe781", size = 18758047, upload-time = "2025-10-28T23:58:46.56Z" }, + { url = "https://files.pythonhosted.org/packages/26/88/64081740ef2b2efc7fbcb2139a07a849e42bcb09ae0c56ae50c41bd0ad63/wandb-0.22.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:d29c16817cca6401b4919069ec7570c781eacb67dc0b1ff2e0096a9a59581720", size = 19798011, upload-time = "2025-10-28T23:58:49.718Z" }, + { url = "https://files.pythonhosted.org/packages/19/72/c4f922b33dbb84d1c81ee045ff8791dd14e26d79e1e9bbafff964b7043e2/wandb-0.22.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb955d73a4ba55df9adc61fafbabef5556784d33fc39c7b5c8165d2694ddeb3b", size = 18542713, upload-time = "2025-10-28T23:58:51.927Z" }, + { url = "https://files.pythonhosted.org/packages/ad/98/3ce5f6e2086d91b0c51b38ae7ff591109e7da2bb25fe1a12eec0cdbaa494/wandb-0.22.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23f3ebe41a26506117a098fdfd2706ed0e50b37899bfbefe3a0628fcbd70c69d", size = 19984910, upload-time = "2025-10-28T23:58:54.641Z" }, + { url = 
"https://files.pythonhosted.org/packages/5e/57/e68cb38427b60490d6ddf1b992e6c7f36be83be1079d291ce87a8d347f48/wandb-0.22.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2973462bed5d4a653b1a97cf9fc350673bb200fb356a2f4eba34beae9b87e0aa", size = 18581776, upload-time = "2025-10-28T23:58:56.975Z" }, + { url = "https://files.pythonhosted.org/packages/66/6d/543f907ce0c6b6da13628b23d19ca7282c559fd73eb47b04977b9a61d0c6/wandb-0.22.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c5c2bd18f95c1639863c527da0a5818ac6b0e5194f9c691426b265908ddd8b2c", size = 20078800, upload-time = "2025-10-28T23:58:59.217Z" }, + { url = "https://files.pythonhosted.org/packages/da/91/1decaf1a6ac2017481c782e0fad7f90bc9ae4057f3d76d478cb6527f3dd3/wandb-0.22.3-py3-none-win32.whl", hash = "sha256:09ca1edfe0fd6dc30447d368acddb825668e60ee705c98594a6bbfd30d34d47e", size = 19160297, upload-time = "2025-10-28T23:59:01.536Z" }, + { url = "https://files.pythonhosted.org/packages/4c/ba/3b092634279994b0c79fe05220532822be09f3a353ae95c54e7142769db8/wandb-0.22.3-py3-none-win_amd64.whl", hash = "sha256:55403bf93872c9978433d101324f51e43e78c70c809bf6d06ca7b2760e39f497", size = 19160300, upload-time = "2025-10-28T23:59:04.06Z" }, + { url = "https://files.pythonhosted.org/packages/7f/80/4662fce9eebcc8c71f5083e9152ccaf7d43d4ca9c446e1422f9aa784a51c/wandb-0.22.3-py3-none-win_arm64.whl", hash = "sha256:49f66b05882abfa53816cc8d01b3c2435a89c5a090176802fa6928b5979d34d9", size = 17461959, upload-time = "2025-10-28T23:59:07.059Z" }, ] [[package]] @@ -2361,6 +2623,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, ] +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721, upload-time = "2017-04-05T20:21:34.189Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774, upload-time = "2017-04-05T20:21:32.581Z" }, +] + [[package]] name = "xxhash" version = "3.6.0"