diff --git a/spd/scripts/plot_component_activations/__init__.py b/spd/scripts/plot_component_activations/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/spd/scripts/plot_component_activations.py b/spd/scripts/plot_component_activations/plot_component_activations.py similarity index 67% rename from spd/scripts/plot_component_activations.py rename to spd/scripts/plot_component_activations/plot_component_activations.py index 8ab4a2d0d..2c30d3994 100644 --- a/spd/scripts/plot_component_activations.py +++ b/spd/scripts/plot_component_activations/plot_component_activations.py @@ -6,22 +6,27 @@ - Filter: Only plots datapoints where CI > threshold Usage: - python -m spd.scripts.plot_component_activations s-7884efcc - python -m spd.scripts.plot_component_activations s-7884efcc --ci-threshold 0.0 + python -m spd.scripts.plot_component_activations.plot_component_activations \ + wandb:goodfire/spd/runs/ """ -import argparse from collections import defaultdict from pathlib import Path +import fire import matplotlib.pyplot as plt import numpy as np from spd.harvest.repo import HarvestRepo from spd.harvest.schemas import ComponentData +from spd.log import logger +from spd.spd_types import ModelPath +from spd.utils.wandb_utils import parse_wandb_run_path +SCRIPT_DIR = Path(__file__).parent -def extract_activations( + +def _extract_activations( components: list[ComponentData], ci_threshold: float, ) -> tuple[dict[str, dict[str, list[float]]], dict[str, dict[str, list[float]]]]: @@ -47,7 +52,7 @@ def extract_activations( return dict(all_activations), dict(filtered_activations) -def normalize_per_component( +def _normalize_per_component( all_activations: dict[str, list[float]], filtered_activations: dict[str, list[float]], ) -> dict[str, np.ndarray]: @@ -67,14 +72,14 @@ def normalize_per_component( return normalized -def order_by_median(normalized: dict[str, np.ndarray]) -> list[str]: +def _order_by_median(normalized: dict[str, np.ndarray]) -> list[str]: """Order component keys by median of their normalized activations (descending).""" medians = [(key, np.median(acts)) for key, acts in normalized.items()] medians.sort(key=lambda x: x[1], reverse=True) return [key for key, _ in medians] -def order_by_frequency( +def _order_by_frequency( normalized: dict[str, np.ndarray], firing_counts: dict[str, int] ) -> list[str]: """Order component keys by pre-calculated firing counts (descending).""" @@ -83,14 +88,14 @@ def order_by_frequency( return [key for key, _ in freqs] -def create_layer_scatter_plot( +def _create_layer_scatter_plot( normalized_by_key: dict[str, np.ndarray], ordered_keys: list[str], layer_name: str, run_id: str, output_path: Path, - x_label: str = "Component Rank (by median activation)", - y_label: str = "Normalized Component Activation", + x_label: str, + y_label: str, ) -> None: """Create scatter plot for a single layer.""" x_vals = [] @@ -123,86 +128,87 @@ def create_layer_scatter_plot( plt.close(fig) -def main(): - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("run_id", help="WandB run ID (e.g., 's-7884efcc')") - parser.add_argument( - "--ci-threshold", - type=float, - default=0.1, - help="Minimum CI value to include (default: 0.1)", - ) - args = parser.parse_args() +def plot_component_activations( + wandb_path: ModelPath, + ci_threshold: float = 0.1, +) -> None: + _entity, _project, run_id = parse_wandb_run_path(str(wandb_path)) - base_output_dir = Path(__file__).parent / "out" / args.run_id / "component-act-scatter" - output_dir_median = base_output_dir / "order-by-median" - output_dir_freq = base_output_dir / "order-by-freq" - output_dir_median.mkdir(parents=True, exist_ok=True) - output_dir_freq.mkdir(parents=True, exist_ok=True) + out_dir = SCRIPT_DIR / "out" / run_id + out_dir_median = out_dir / "order-by-median" + out_dir_freq = out_dir / "order-by-freq" + out_dir_median.mkdir(parents=True, exist_ok=True) + out_dir_freq.mkdir(parents=True, exist_ok=True) - repo = HarvestRepo.open(args.run_id) - assert repo is not None, f"No harvest data for {args.run_id}" + repo = HarvestRepo.open(run_id) + assert repo is not None, f"No harvest data for {run_id}" - print(f"Loading components for run {args.run_id}...") + logger.info(f"Loading components for run {run_id}...") components = repo.get_all_components() - print(f"Loaded {len(components)} components") + logger.info(f"Loaded {len(components)} components") - print("Loading firing counts...") + logger.info("Loading firing counts...") token_stats = repo.get_token_stats() - assert token_stats is not None, f"No token stats found for run {args.run_id}" + assert token_stats is not None, f"No token stats found for run {run_id}" firing_counts = { key: int(count) for key, count in zip(token_stats.component_keys, token_stats.firing_counts, strict=True) } - print("Extracting activations...") - all_by_layer, filtered_by_layer = extract_activations(components, args.ci_threshold) + logger.info("Extracting activations...") + all_by_layer, filtered_by_layer = _extract_activations(components, ci_threshold) n_layers = len(filtered_by_layer) n_total = sum(sum(len(v) for v in layer.values()) for layer in filtered_by_layer.values()) - print(f"Found {n_total} datapoints across {n_layers} layers with CI > {args.ci_threshold}") + logger.info(f"Found {n_total} datapoints across {n_layers} layers with CI > {ci_threshold}") - if n_total == 0: - print("No datapoints found above threshold. Try lowering --ci-threshold.") - return + assert n_total > 0, "No datapoints found above threshold. Try lowering ci_threshold." - print(f"Creating per-layer plots (ordered by median) in {output_dir_median}/...") + logger.info(f"Creating per-layer plots (ordered by median) in {out_dir_median}/...") for layer_name in sorted(all_by_layer.keys()): all_acts = all_by_layer[layer_name] filtered_acts = filtered_by_layer.get(layer_name, {}) - normalized = normalize_per_component(all_acts, filtered_acts) + normalized = _normalize_per_component(all_acts, filtered_acts) if not normalized: continue - ordered_keys = order_by_median(normalized) + ordered_keys = _order_by_median(normalized) safe_name = layer_name.replace(".", "_") - output_path = output_dir_median / f"{safe_name}.png" - create_layer_scatter_plot(normalized, ordered_keys, layer_name, args.run_id, output_path) - print(f" {output_path}") + output_path = out_dir_median / f"{safe_name}.png" + _create_layer_scatter_plot( + normalized, + ordered_keys, + layer_name, + run_id, + output_path, + x_label="Component Rank (by median activation)", + y_label="Normalized Component Activation", + ) + logger.info(f" Saved {output_path}") - print(f"Creating per-layer plots (ordered by frequency) in {output_dir_freq}/...") + logger.info(f"Creating per-layer plots (ordered by frequency) in {out_dir_freq}/...") for layer_name in sorted(all_by_layer.keys()): all_acts = all_by_layer[layer_name] filtered_acts = filtered_by_layer.get(layer_name, {}) - normalized = normalize_per_component(all_acts, filtered_acts) + normalized = _normalize_per_component(all_acts, filtered_acts) if not normalized: continue abs_from_midpoint = {key: np.abs(acts - 0.5) for key, acts in normalized.items()} - ordered_keys = order_by_frequency(abs_from_midpoint, firing_counts) + ordered_keys = _order_by_frequency(abs_from_midpoint, firing_counts) safe_name = layer_name.replace(".", "_") - output_path = output_dir_freq / f"{safe_name}.png" - create_layer_scatter_plot( + output_path = out_dir_freq / f"{safe_name}.png" + _create_layer_scatter_plot( abs_from_midpoint, ordered_keys, layer_name, - args.run_id, + run_id, output_path, x_label="Component Rank (by firing frequency)", y_label="|Normalized Component Activation - 0.5|", ) - print(f" {output_path}") + logger.info(f" Saved {output_path}") - print("Done!") + logger.info(f"All plots saved to {out_dir}") if __name__ == "__main__": - main() + fire.Fire(plot_component_activations)