From 3d1a772cdc2dc2ac63bf6c4be4124088922a896f Mon Sep 17 00:00:00 2001 From: Antovigo Date: Mon, 23 Feb 2026 15:16:40 -0800 Subject: [PATCH 01/13] Save metadata about the run --- spd/configs.py | 8 ++++ spd/experiments/tms/tms_5-2_config.yaml | 2 +- spd/run_spd.py | 49 ++++++++++++++++++++++++- 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/spd/configs.py b/spd/configs.py index efbfb3bb4..f93e5d01d 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -744,6 +744,14 @@ class Config(BaseConfig): default="", description="Prefix prepended to an auto-generated WandB run name", ) + label: str = Field( + default="", + description="Short human-readable label for this run or series of runs", + ) + notes: str = Field( + default="", + description="Free-form notes about this run's purpose or configuration choices", + ) # --- General --- seed: int = Field( diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 07bc9056a..53741bdbc 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -79,4 +79,4 @@ pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/0hsp07o4" task_config: task_name: tms feature_probability: 0.05 - data_generation_type: "at_least_zero_active" \ No newline at end of file + data_generation_type: "at_least_zero_active" diff --git a/spd/run_spd.py b/spd/run_spd.py index d303a24c5..0592aa1f8 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -1,9 +1,11 @@ """Run SPD on a model.""" import gc +import json import os from collections import defaultdict from collections.abc import Iterator +from datetime import UTC, datetime from pathlib import Path from typing import Any, cast @@ -54,6 +56,7 @@ get_scheduled_value, save_pre_run_info, ) +from spd.utils.git_utils import repo_current_commit_hash, repo_is_clean from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns from spd.utils.run_utils 
import generate_run_id, save_file @@ -427,6 +430,43 @@ def create_pgd_data_iter() -> ( logger.info("Finished training loop.") +RUN_METADATA_FILENAME = "run_metadata.json" + + +def _write_run_metadata(out_dir: Path, run_id: str, config: Config, save_to_wandb: bool) -> None: + """Write run_metadata.json with git state, timestamp, and user annotations.""" + metadata = { + "run_id": run_id, + "git_commit": repo_current_commit_hash(), + "uncommitted_changes": not repo_is_clean(), + "date": datetime.now(UTC).strftime("%Y-%m-%d %H:%M"), + "label": config.label, + "notes": config.notes, + "completed": False, + } + + metadata_path = out_dir / RUN_METADATA_FILENAME + save_file(metadata, metadata_path, indent=2) + + if save_to_wandb: + try_wandb(wandb.save, str(metadata_path), base_path=str(out_dir), policy="now") + + +def _mark_run_completed(out_dir: Path, save_to_wandb: bool) -> None: + """Set completed=True in an existing run_metadata.json.""" + metadata_path = out_dir / RUN_METADATA_FILENAME + assert metadata_path.exists(), f"run_metadata.json not found at {metadata_path}" + + with open(metadata_path) as f: + metadata = json.load(f) + + metadata["completed"] = True + save_file(metadata, metadata_path, indent=2) + + if save_to_wandb: + try_wandb(wandb.save, str(metadata_path), base_path=str(out_dir), policy="now") + + def run_experiment( target_model: nn.Module, config: Config, @@ -474,6 +514,8 @@ def run_experiment( train_config=target_model_train_config, task_name=getattr(config.task_config, "task_name", None), ) + + _write_run_metadata(out_dir, run_id, config, save_to_wandb=config.wandb_project is not None) else: out_dir = None @@ -488,5 +530,8 @@ def run_experiment( tied_weights=tied_weights, ) - if is_main_process() and config.wandb_project: - wandb.finish() + if is_main_process(): + assert out_dir is not None + _mark_run_completed(out_dir, save_to_wandb=config.wandb_project is not None) + if config.wandb_project: + wandb.finish() From 
cd76c2c66a9eff31f4ba09d5222a8b9e983d3cc9 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Mon, 23 Feb 2026 15:57:00 -0800 Subject: [PATCH 02/13] Script to collect metadata about all the runs in a folder, and create a index --- pyproject.toml | 1 + spd/scripts/index_spd_runs.py | 287 ++++++++++++++++++++++++++++++++++ 2 files changed, 288 insertions(+) create mode 100644 spd/scripts/index_spd_runs.py diff --git a/pyproject.toml b/pyproject.toml index 88c3405a8..d71232a94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" +spd-index = "spd.scripts.index_spd_runs:main" [build-system] requires = ["setuptools", "wheel"] diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py new file mode 100644 index 000000000..2164844fb --- /dev/null +++ b/spd/scripts/index_spd_runs.py @@ -0,0 +1,287 @@ +"""Generate a TSV index of all SPD runs. + +Scans SPD_OUT_DIR/spd for run directories and produces a runs_index.tsv with metadata columns: +run_id, git_commit, uncommitted_changes, label, notes, hyperparameters, date, completed. + +The hyperparameters column shows only config values that differ between runs sharing the same label. 
+ +Usage: + uv run spd/scripts/index_spd_runs.py # default paths + uv run spd/scripts/index_spd_runs.py -i /path/to/runs # override input dir + uv run spd/scripts/index_spd_runs.py -o /path/to/out.tsv # override output path +""" + +import argparse +import csv +import json +from pathlib import Path +from typing import Any + +import yaml + +from spd.settings import SPD_OUT_DIR +from spd.utils.run_utils import _DISCRIMINATED_LIST_FIELDS + +COLUMNS = [ + "run_id", + "date", + "git_commit", + "uncommitted_changes", + "label", + "completed", + "hyperparameters", + "notes", +] + +NA = "NA" + + +def _flatten_dict(d: dict[str, Any], prefix: str = "") -> dict[str, str]: + """Recursively flatten a nested dict with dot-separated keys. + + For discriminated list fields, uses the discriminator value as key instead of index. + For other lists, uses index-based keys. + """ + flat: dict[str, str] = {} + for key, value in d.items(): + full_key = f"{prefix}{key}" if not prefix else f"{prefix}.{key}" + if isinstance(value, dict): + flat.update(_flatten_dict(value, full_key)) + elif isinstance(value, list) and key in _DISCRIMINATED_LIST_FIELDS: + disc_field = _DISCRIMINATED_LIST_FIELDS[key] + for item in value: + assert isinstance(item, dict) + disc_value = item[disc_field] + sub = {k: v for k, v in item.items() if k != disc_field} + flat.update(_flatten_dict(sub, f"{full_key}.{disc_value}")) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + flat.update(_flatten_dict(item, f"{full_key}.{i}")) + else: + flat[f"{full_key}.{i}"] = str(item) + else: + flat[full_key] = str(value) + return flat + + +def _read_metadata(run_dir: Path) -> dict[str, str]: + """Read metadata from run_metadata.json, or return NAs for legacy runs.""" + metadata_path = run_dir / "run_metadata.json" + if metadata_path.exists(): + with open(metadata_path) as f: + meta = json.load(f) + # Handle both old field name (git_dirty) and new (uncommitted_changes) + uncommitted = 
meta.get("uncommitted_changes", meta.get("git_dirty", NA)) + return { + "run_id": str(meta.get("run_id", run_dir.name)), + "git_commit": str(meta.get("git_commit", NA)), + "uncommitted_changes": str(uncommitted), + "label": str(meta.get("label", "")), + "notes": str(meta.get("notes", "")), + "date": str(meta.get("date", NA)), + "completed": str(meta.get("completed", NA)), + } + + # Legacy run without metadata — try to determine completion from config + metrics + completed = _check_legacy_completion(run_dir) + return { + "run_id": run_dir.name, + "git_commit": NA, + "uncommitted_changes": NA, + "label": "", + "notes": "", + "date": NA, + "completed": str(completed), + } + + +def _check_legacy_completion(run_dir: Path) -> bool: + """Determine if a legacy run (no run_metadata.json) completed.""" + config_path = run_dir / "final_config.yaml" + metrics_path = run_dir / "metrics.jsonl" + + if not config_path.exists() or not metrics_path.exists(): + return False + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + total_steps = config_dict.get("steps") + if total_steps is None: + return False + + # Read last line of metrics.jsonl + last_line = "" + with open(metrics_path, "rb") as f: + # Seek backwards to find last newline + f.seek(0, 2) + size = f.tell() + if size == 0: + return False + pos = size - 1 + while pos > 0: + f.seek(pos) + char = f.read(1) + if char == b"\n" and pos < size - 1: + break + pos -= 1 + last_line = f.readline().decode().strip() + + if not last_line: + return False + + last_step = json.loads(last_line).get("step", -1) + return last_step >= total_steps + + +def _load_existing_index(index_path: Path) -> dict[str, dict[str, str]]: + """Load existing TSV index into dict[run_id → row].""" + if not index_path.exists(): + return {} + rows: dict[str, dict[str, str]] = {} + with open(index_path, newline="") as f: + reader = csv.DictReader(f, delimiter="\t") + for row in reader: + run_id = row["run_id"] + rows[run_id] = dict(row) + return rows 
+ + +def _compute_hyperparameters( + label_groups: dict[str, list[str]], + run_dirs: dict[str, Path], +) -> dict[str, str]: + """Compute hyperparameters column for all runs. + + For each label group with >=2 runs, loads final_config.yaml, flattens, and diffs. + """ + hyperparams: dict[str, str] = {} + + for _label, run_ids in label_groups.items(): + if len(run_ids) < 2: + for rid in run_ids: + hyperparams[rid] = "" + continue + + # Load and flatten configs for each run in the group + flattened_configs: dict[str, dict[str, str]] = {} + for rid in run_ids: + config_path = run_dirs[rid] / "final_config.yaml" + if not config_path.exists(): + flattened_configs[rid] = {} + continue + with open(config_path) as f: + config_dict = yaml.safe_load(f) + flattened_configs[rid] = _flatten_dict(config_dict) + + # Find keys that differ across runs (ignore notes and label) + all_keys = set() + for fc in flattened_configs.values(): + all_keys.update(fc.keys()) + + ignore_keys = {"notes", "label"} + differing_keys: list[str] = [] + for key in sorted(all_keys): + if key in ignore_keys: + continue + values = {fc.get(key) for fc in flattened_configs.values()} + if len(values) > 1: + differing_keys.append(key) + + # Format hyperparameters string for each run + for rid in run_ids: + fc = flattened_configs.get(rid, {}) + parts = [f"{k}={fc.get(k, NA)}" for k in differing_keys] + hyperparams[rid] = " ".join(parts) + + return hyperparams + + +def build_index(runs_dir: Path, index_path: Path) -> None: + existing = _load_existing_index(index_path) + + # Discover all run directories + run_dirs: dict[str, Path] = {} + for entry in sorted(runs_dir.iterdir()): + if entry.is_dir(): + run_dirs[entry.name] = entry + + # Phase 1: collect per-run metadata (using cache where possible) + rows: dict[str, dict[str, str]] = {} + new_run_ids: set[str] = set() + for run_id, run_dir in run_dirs.items(): + if run_id in existing: + rows[run_id] = existing[run_id] + else: + rows[run_id] = _read_metadata(run_dir) 
+ new_run_ids.add(run_id) + + # Phase 2: compute hyperparameters + # Group runs by label + label_groups: dict[str, list[str]] = {} + for run_id, row in rows.items(): + label = row.get("label", "") + if label and label != NA: + label_groups.setdefault(label, []).append(run_id) + + # Determine which label groups need recomputation + groups_to_recompute: set[str] = set() + groups_cached: set[str] = set() + for label, run_ids in label_groups.items(): + if any(rid in new_run_ids for rid in run_ids): + groups_to_recompute.add(label) + else: + groups_cached.add(label) + + # Recompute hyperparameters for groups with new runs + recompute_groups = { + lbl: rids for lbl, rids in label_groups.items() if lbl in groups_to_recompute + } + fresh_hyperparams = _compute_hyperparameters(recompute_groups, run_dirs) + + # Assign hyperparameters to all runs + for run_id, row in rows.items(): + if run_id in fresh_hyperparams: + row["hyperparameters"] = fresh_hyperparams[run_id] + elif run_id not in existing: + # New run not in any label group (or solo label) + row.setdefault("hyperparameters", "") + + # Phase 3: write TSV + with open(index_path, "w", newline="") as f: + writer = csv.DictWriter(f, fieldnames=COLUMNS, delimiter="\t", extrasaction="ignore") + writer.writeheader() + for run_id in sorted(rows, key=lambda rid: rows[rid].get("date", ""), reverse=True): + writer.writerow(rows[run_id]) + + print(f"Wrote {len(rows)} runs to {index_path}") + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate TSV index of SPD runs") + parser.add_argument( + "-i", + "--input-dir", + type=Path, + default=SPD_OUT_DIR / "spd", + help="Directory containing run subdirectories", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=None, + help="Output TSV path (default: /runs_index.tsv)", + ) + args = parser.parse_args() + + runs_dir: Path = args.input_dir + assert runs_dir.is_dir(), f"Runs directory not found: {runs_dir}" + + index_path: Path = args.output 
if args.output else runs_dir / "runs_index.tsv" + + build_index(runs_dir, index_path) + + +if __name__ == "__main__": + main() From a31a3e3e18b34da0d52b983e21398b8c2c78e867 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Tue, 24 Feb 2026 16:00:16 -0800 Subject: [PATCH 03/13] Don't mention index in pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index d71232a94..88c3405a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,6 @@ spd-harvest = "spd.harvest.scripts.run_slurm_cli:cli" spd-autointerp = "spd.autointerp.scripts.run_slurm_cli:cli" spd-attributions = "spd.dataset_attributions.scripts.run_slurm_cli:cli" spd-postprocess = "spd.postprocess.cli:cli" -spd-index = "spd.scripts.index_spd_runs:main" [build-system] requires = ["setuptools", "wheel"] From ed2064633e01f590557937f0d61574192479336c Mon Sep 17 00:00:00 2001 From: Antovigo Date: Tue, 24 Feb 2026 16:25:26 -0800 Subject: [PATCH 04/13] Reprocess recent unfinished run in case they are now finished --- spd/scripts/index_spd_runs.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index 2164844fb..e1e11ed8b 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -14,6 +14,7 @@ import argparse import csv import json +from datetime import UTC, datetime, timedelta from pathlib import Path from typing import Any @@ -35,6 +36,20 @@ NA = "NA" +_RECHECK_WINDOW = timedelta(days=7) + + +def _is_recent(row: dict[str, str]) -> bool: + """Return True if the run's date is within the recheck window.""" + date_str = row.get("date", NA) + if date_str == NA: + return False + try: + run_date = datetime.strptime(date_str, "%Y-%m-%d %H:%M").replace(tzinfo=UTC) + except ValueError: + return False + return datetime.now(UTC) - run_date < _RECHECK_WINDOW + def _flatten_dict(d: dict[str, Any], prefix: str = "") -> dict[str, str]: 
"""Recursively flatten a nested dict with dot-separated keys. @@ -210,11 +225,13 @@ def build_index(runs_dir: Path, index_path: Path) -> None: rows: dict[str, dict[str, str]] = {} new_run_ids: set[str] = set() for run_id, run_dir in run_dirs.items(): - if run_id in existing: - rows[run_id] = existing[run_id] + cached = existing.get(run_id) + if cached and (cached.get("completed") == "True" or not _is_recent(cached)): + rows[run_id] = cached else: rows[run_id] = _read_metadata(run_dir) - new_run_ids.add(run_id) + if not cached: + new_run_ids.add(run_id) # Phase 2: compute hyperparameters # Group runs by label From 1d249f9fecb3ba5709ab48e2e6ab10849d56ce27 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Tue, 24 Feb 2026 16:29:07 -0800 Subject: [PATCH 05/13] Show progress --- spd/scripts/index_spd_runs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index e1e11ed8b..b4a6e9267 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -19,6 +19,7 @@ from typing import Any import yaml +from tqdm import tqdm from spd.settings import SPD_OUT_DIR from spd.utils.run_utils import _DISCRIMINATED_LIST_FIELDS @@ -224,7 +225,7 @@ def build_index(runs_dir: Path, index_path: Path) -> None: # Phase 1: collect per-run metadata (using cache where possible) rows: dict[str, dict[str, str]] = {} new_run_ids: set[str] = set() - for run_id, run_dir in run_dirs.items(): + for run_id, run_dir in tqdm(run_dirs.items(), desc="Reading runs"): cached = existing.get(run_id) if cached and (cached.get("completed") == "True" or not _is_recent(cached)): rows[run_id] = cached From bfdc1a03a56af5f499b28003f2a8ee43d7740f54 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Wed, 25 Feb 2026 15:16:39 -0800 Subject: [PATCH 06/13] sort by date, put NAs at the end --- spd/scripts/index_spd_runs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/spd/scripts/index_spd_runs.py 
b/spd/scripts/index_spd_runs.py index b4a6e9267..5132287b8 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -269,7 +269,11 @@ def build_index(runs_dir: Path, index_path: Path) -> None: with open(index_path, "w", newline="") as f: writer = csv.DictWriter(f, fieldnames=COLUMNS, delimiter="\t", extrasaction="ignore") writer.writeheader() - for run_id in sorted(rows, key=lambda rid: rows[rid].get("date", ""), reverse=True): + for run_id in sorted( + rows, + key=lambda rid: rows[rid].get("date", "") if rows[rid].get("date", "") != NA else "", + reverse=True, + ): writer.writerow(rows[run_id]) print(f"Wrote {len(rows)} runs to {index_path}") From 39626468b0fba321ba119dc91161fcaaae5b72da Mon Sep 17 00:00:00 2001 From: Antovigo Date: Wed, 25 Feb 2026 16:09:09 -0800 Subject: [PATCH 07/13] support labels with mixed pretrained models --- spd/scripts/index_spd_runs.py | 49 +++++++++++++++++++++++------------ 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index 5132287b8..1c52064ed 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -190,25 +190,40 @@ def _compute_hyperparameters( config_dict = yaml.safe_load(f) flattened_configs[rid] = _flatten_dict(config_dict) - # Find keys that differ across runs (ignore notes and label) - all_keys = set() - for fc in flattened_configs.values(): - all_keys.update(fc.keys()) + # Sub-group by pretrained_model_path + pretrained_key = "pretrained_model_path" + subgroups: dict[str | None, list[str]] = {} + for rid in run_ids: + val = flattened_configs.get(rid, {}).get(pretrained_key) + subgroups.setdefault(val, []).append(rid) + multiple_pretrained = len(subgroups) > 1 ignore_keys = {"notes", "label"} - differing_keys: list[str] = [] - for key in sorted(all_keys): - if key in ignore_keys: - continue - values = {fc.get(key) for fc in flattened_configs.values()} - if len(values) > 1: - 
differing_keys.append(key) - - # Format hyperparameters string for each run - for rid in run_ids: - fc = flattened_configs.get(rid, {}) - parts = [f"{k}={fc.get(k, NA)}" for k in differing_keys] - hyperparams[rid] = " ".join(parts) + if multiple_pretrained: + ignore_keys.add(pretrained_key) + + for pretrained_val, sub_rids in subgroups.items(): + # Find differing keys within this sub-group + differing_keys: list[str] = [] + if len(sub_rids) >= 2: + sub_all_keys: set[str] = set() + for rid in sub_rids: + sub_all_keys.update(flattened_configs.get(rid, {}).keys()) + for key in sorted(sub_all_keys): + if key in ignore_keys: + continue + values = {flattened_configs.get(rid, {}).get(key) for rid in sub_rids} + if len(values) > 1: + differing_keys.append(key) + + # Format hyperparameters string for each run + for rid in sub_rids: + fc = flattened_configs.get(rid, {}) + parts: list[str] = [] + if multiple_pretrained: + parts.append(f"{pretrained_key}={pretrained_val or NA}") + parts.extend(f"{k}={fc.get(k, NA)}" for k in differing_keys) + hyperparams[rid] = " ".join(parts) return hyperparams From cd9e5258ab759d1d80bb230f0ffcb3a6207c0467 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Wed, 25 Feb 2026 16:19:06 -0800 Subject: [PATCH 08/13] option to force reprocessing all runs --- spd/scripts/index_spd_runs.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index 1c52064ed..140da29f4 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -228,8 +228,8 @@ def _compute_hyperparameters( return hyperparams -def build_index(runs_dir: Path, index_path: Path) -> None: - existing = _load_existing_index(index_path) +def build_index(runs_dir: Path, index_path: Path, *, force: bool = False) -> None: + existing = _load_existing_index(index_path) if not force else {} # Discover all run directories run_dirs: dict[str, Path] = {} @@ -240,12 +240,16 @@ def 
build_index(runs_dir: Path, index_path: Path) -> None: # Phase 1: collect per-run metadata (using cache where possible) rows: dict[str, dict[str, str]] = {} new_run_ids: set[str] = set() + n_reprocessed = 0 + n_cached = 0 for run_id, run_dir in tqdm(run_dirs.items(), desc="Reading runs"): cached = existing.get(run_id) if cached and (cached.get("completed") == "True" or not _is_recent(cached)): rows[run_id] = cached + n_cached += 1 else: rows[run_id] = _read_metadata(run_dir) + n_reprocessed += 1 if not cached: new_run_ids.add(run_id) @@ -291,7 +295,10 @@ def build_index(runs_dir: Path, index_path: Path) -> None: ): writer.writerow(rows[run_id]) - print(f"Wrote {len(rows)} runs to {index_path}") + print( + f"Wrote {len(rows)} runs to {index_path}" + f" ({n_reprocessed} reprocessed, {n_cached} cached)" + ) def main() -> None: @@ -310,6 +317,11 @@ def main() -> None: default=None, help="Output TSV path (default: /runs_index.tsv)", ) + parser.add_argument( + "--force", + action="store_true", + help="Bypass cache and reprocess all runs", + ) args = parser.parse_args() runs_dir: Path = args.input_dir @@ -317,7 +329,7 @@ def main() -> None: index_path: Path = args.output if args.output else runs_dir / "runs_index.tsv" - build_index(runs_dir, index_path) + build_index(runs_dir, index_path, force=args.force) if __name__ == "__main__": From 48092fbbe9ae0dbb8c26cabded80b9d78667fc6a Mon Sep 17 00:00:00 2001 From: Antovigo Date: Sat, 28 Feb 2026 12:51:18 -0800 Subject: [PATCH 09/13] fix hyperparam cachig bug --- spd/scripts/index_spd_runs.py | 61 +++++++++++++++++++++++++++++++---- 1 file changed, 54 insertions(+), 7 deletions(-) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index 140da29f4..f8af19329 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -228,7 +228,37 @@ def _compute_hyperparameters( return hyperparams -def build_index(runs_dir: Path, index_path: Path, *, force: bool = False) -> None: +def 
_read_final_metrics(run_dir: Path, metric_names: list[str]) -> dict[str, str]: + """Read the last line of metrics.jsonl and extract requested metric values.""" + metrics_path = run_dir / "metrics.jsonl" + if not metrics_path.exists(): + return {name: NA for name in metric_names} + + # Read last line using backwards seek + with open(metrics_path, "rb") as f: + f.seek(0, 2) + size = f.tell() + if size == 0: + return {name: NA for name in metric_names} + pos = size - 1 + while pos > 0: + f.seek(pos) + char = f.read(1) + if char == b"\n" and pos < size - 1: + break + pos -= 1 + last_line = f.readline().decode().strip() + + if not last_line: + return {name: NA for name in metric_names} + + data = json.loads(last_line) + return {name: str(data.get(name, NA)) for name in metric_names} + + +def build_index( + runs_dir: Path, index_path: Path, *, force: bool = False, metrics: list[str] | None = None +) -> None: existing = _load_existing_index(index_path) if not force else {} # Discover all run directories @@ -244,11 +274,22 @@ def build_index(runs_dir: Path, index_path: Path, *, force: bool = False) -> Non n_cached = 0 for run_id, run_dir in tqdm(run_dirs.items(), desc="Reading runs"): cached = existing.get(run_id) - if cached and (cached.get("completed") == "True" or not _is_recent(cached)): + # Cache miss if metrics were requested but aren't in the cached row + metrics_missing = metrics and cached and metrics[0] not in cached + if ( + cached + and not metrics_missing + and (cached.get("completed") == "True" or not _is_recent(cached)) + ): rows[run_id] = cached n_cached += 1 else: - rows[run_id] = _read_metadata(run_dir) + row = _read_metadata(run_dir) + if metrics: + row.update(_read_final_metrics(run_dir, metrics)) + if cached: + row.setdefault("hyperparameters", cached.get("hyperparameters", "")) + rows[run_id] = row n_reprocessed += 1 if not cached: new_run_ids.add(run_id) @@ -285,8 +326,9 @@ def build_index(runs_dir: Path, index_path: Path, *, force: bool = False) -> 
Non row.setdefault("hyperparameters", "") # Phase 3: write TSV + fieldnames = COLUMNS + (metrics or []) with open(index_path, "w", newline="") as f: - writer = csv.DictWriter(f, fieldnames=COLUMNS, delimiter="\t", extrasaction="ignore") + writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t", extrasaction="ignore") writer.writeheader() for run_id in sorted( rows, @@ -296,8 +338,7 @@ def build_index(runs_dir: Path, index_path: Path, *, force: bool = False) -> Non writer.writerow(rows[run_id]) print( - f"Wrote {len(rows)} runs to {index_path}" - f" ({n_reprocessed} reprocessed, {n_cached} cached)" + f"Wrote {len(rows)} runs to {index_path} ({n_reprocessed} reprocessed, {n_cached} cached)" ) @@ -322,14 +363,20 @@ def main() -> None: action="store_true", help="Bypass cache and reprocess all runs", ) + parser.add_argument( + "--metrics", + type=str, + help="Comma-separated metric names to include (e.g. 'train/loss/total,train/l0/total')", + ) args = parser.parse_args() runs_dir: Path = args.input_dir assert runs_dir.is_dir(), f"Runs directory not found: {runs_dir}" index_path: Path = args.output if args.output else runs_dir / "runs_index.tsv" + metric_names = [m.strip() for m in args.metrics.split(",")] if args.metrics else None - build_index(runs_dir, index_path, force=args.force) + build_index(runs_dir, index_path, force=args.force, metrics=metric_names) if __name__ == "__main__": From 474513070bce31797ce90b6386fa586ed8b70471 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Sat, 28 Feb 2026 15:48:40 -0800 Subject: [PATCH 10/13] clean up index_spd_runs: extract helper, remove legacy shims - Extract _read_last_jsonl_line() to deduplicate backwards-seek JSONL pattern - Remove git_dirty fallback in _read_metadata() - Delete unused groups_cached variable - Skip reading metrics for uncompleted runs --- spd/scripts/index_spd_runs.py | 78 +++++++++++++++-------------------- 1 file changed, 34 insertions(+), 44 deletions(-) diff --git a/spd/scripts/index_spd_runs.py 
b/spd/scripts/index_spd_runs.py index f8af19329..14060fbc5 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -87,12 +87,10 @@ def _read_metadata(run_dir: Path) -> dict[str, str]: if metadata_path.exists(): with open(metadata_path) as f: meta = json.load(f) - # Handle both old field name (git_dirty) and new (uncommitted_changes) - uncommitted = meta.get("uncommitted_changes", meta.get("git_dirty", NA)) return { "run_id": str(meta.get("run_id", run_dir.name)), "git_commit": str(meta.get("git_commit", NA)), - "uncommitted_changes": str(uncommitted), + "uncommitted_changes": str(meta.get("uncommitted_changes", NA)), "label": str(meta.get("label", "")), "notes": str(meta.get("notes", "")), "date": str(meta.get("date", NA)), @@ -112,28 +110,16 @@ def _read_metadata(run_dir: Path) -> dict[str, str]: } -def _check_legacy_completion(run_dir: Path) -> bool: - """Determine if a legacy run (no run_metadata.json) completed.""" - config_path = run_dir / "final_config.yaml" - metrics_path = run_dir / "metrics.jsonl" - - if not config_path.exists() or not metrics_path.exists(): - return False +def _read_last_jsonl_line(path: Path) -> dict[str, Any] | None: + """Read and parse the last line of a JSONL file using backwards seek. - with open(config_path) as f: - config_dict = yaml.safe_load(f) - total_steps = config_dict.get("steps") - if total_steps is None: - return False - - # Read last line of metrics.jsonl - last_line = "" - with open(metrics_path, "rb") as f: - # Seek backwards to find last newline + Returns the parsed JSON dict, or None if the file is empty. 
+ """ + with open(path, "rb") as f: f.seek(0, 2) size = f.tell() if size == 0: - return False + return None pos = size - 1 while pos > 0: f.seek(pos) @@ -144,10 +130,29 @@ def _check_legacy_completion(run_dir: Path) -> bool: last_line = f.readline().decode().strip() if not last_line: + return None + return json.loads(last_line) + + +def _check_legacy_completion(run_dir: Path) -> bool: + """Determine if a legacy run (no run_metadata.json) completed.""" + config_path = run_dir / "final_config.yaml" + metrics_path = run_dir / "metrics.jsonl" + + if not config_path.exists() or not metrics_path.exists(): return False - last_step = json.loads(last_line).get("step", -1) - return last_step >= total_steps + with open(config_path) as f: + config_dict = yaml.safe_load(f) + total_steps = config_dict.get("steps") + if total_steps is None: + return False + + last = _read_last_jsonl_line(metrics_path) + if last is None: + return False + + return last.get("step", -1) >= total_steps def _load_existing_index(index_path: Path) -> dict[str, dict[str, str]]: @@ -234,25 +239,10 @@ def _read_final_metrics(run_dir: Path, metric_names: list[str]) -> dict[str, str if not metrics_path.exists(): return {name: NA for name in metric_names} - # Read last line using backwards seek - with open(metrics_path, "rb") as f: - f.seek(0, 2) - size = f.tell() - if size == 0: - return {name: NA for name in metric_names} - pos = size - 1 - while pos > 0: - f.seek(pos) - char = f.read(1) - if char == b"\n" and pos < size - 1: - break - pos -= 1 - last_line = f.readline().decode().strip() - - if not last_line: + data = _read_last_jsonl_line(metrics_path) + if data is None: return {name: NA for name in metric_names} - data = json.loads(last_line) return {name: str(data.get(name, NA)) for name in metric_names} @@ -286,7 +276,10 @@ def build_index( else: row = _read_metadata(run_dir) if metrics: - row.update(_read_final_metrics(run_dir, metrics)) + if row.get("completed") == "True": + 
row.update(_read_final_metrics(run_dir, metrics)) + else: + row.update({name: NA for name in metrics}) if cached: row.setdefault("hyperparameters", cached.get("hyperparameters", "")) rows[run_id] = row @@ -304,12 +297,9 @@ def build_index( # Determine which label groups need recomputation groups_to_recompute: set[str] = set() - groups_cached: set[str] = set() for label, run_ids in label_groups.items(): if any(rid in new_run_ids for rid in run_ids): groups_to_recompute.add(label) - else: - groups_cached.add(label) # Recompute hyperparameters for groups with new runs recompute_groups = { From 40845315af06e2a6296f585ba4a78a732e093089 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Mon, 2 Mar 2026 13:55:44 -0800 Subject: [PATCH 11/13] store run duration in metadata --- spd/run_spd.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spd/run_spd.py b/spd/run_spd.py index 0592aa1f8..15599f514 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -460,6 +460,8 @@ def _mark_run_completed(out_dir: Path, save_to_wandb: bool) -> None: with open(metadata_path) as f: metadata = json.load(f) + start_time = datetime.strptime(metadata["date"], "%Y-%m-%d %H:%M").replace(tzinfo=UTC) + metadata["duration"] = round((datetime.now(UTC) - start_time).total_seconds() / 3600, 1) metadata["completed"] = True save_file(metadata, metadata_path, indent=2) From ac236be92dcfcb80bb4aa9dbcb24b3fb698d9447 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Tue, 3 Mar 2026 17:31:30 -0800 Subject: [PATCH 12/13] add duration_hours column to runs index Read duration from run_metadata.json and display it in the TSV index. Also round duration to 2 decimal places instead of 1. 
--- spd/run_spd.py | 2 +- spd/scripts/index_spd_runs.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/spd/run_spd.py b/spd/run_spd.py index 15599f514..6053363ff 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -461,7 +461,7 @@ def _mark_run_completed(out_dir: Path, save_to_wandb: bool) -> None: metadata = json.load(f) start_time = datetime.strptime(metadata["date"], "%Y-%m-%d %H:%M").replace(tzinfo=UTC) - metadata["duration"] = round((datetime.now(UTC) - start_time).total_seconds() / 3600, 1) + metadata["duration"] = round((datetime.now(UTC) - start_time).total_seconds() / 3600, 2) metadata["completed"] = True save_file(metadata, metadata_path, indent=2) diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py index 14060fbc5..de29f5159 100644 --- a/spd/scripts/index_spd_runs.py +++ b/spd/scripts/index_spd_runs.py @@ -1,7 +1,7 @@ """Generate a TSV index of all SPD runs. Scans SPD_OUT_DIR/spd for run directories and produces a runs_index.tsv with metadata columns: -run_id, git_commit, uncommitted_changes, label, notes, hyperparameters, date, completed. +run_id, git_commit, uncommitted_changes, label, notes, hyperparameters, date, completed, duration. The hyperparameters column shows only config values that differ between runs sharing the same label. 
@@ -31,6 +31,7 @@ "uncommitted_changes", "label", "completed", + "duration_hours", "hyperparameters", "notes", ] @@ -95,6 +96,7 @@ def _read_metadata(run_dir: Path) -> dict[str, str]: "notes": str(meta.get("notes", "")), "date": str(meta.get("date", NA)), "completed": str(meta.get("completed", NA)), + "duration_hours": str(meta.get("duration", NA)), } # Legacy run without metadata — try to determine completion from config + metrics @@ -107,6 +109,7 @@ def _read_metadata(run_dir: Path) -> dict[str, str]: "notes": "", "date": NA, "completed": str(completed), + "duration_hours": NA, } From 3b086a15ff2563afb5a1f37ef76a8e6d010cde43 Mon Sep 17 00:00:00 2001 From: Antovigo Date: Tue, 3 Mar 2026 17:54:06 -0800 Subject: [PATCH 13/13] revert unintended whitespace change in tms_5-2_config.yaml --- spd/experiments/tms/tms_5-2_config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spd/experiments/tms/tms_5-2_config.yaml b/spd/experiments/tms/tms_5-2_config.yaml index 53741bdbc..07bc9056a 100644 --- a/spd/experiments/tms/tms_5-2_config.yaml +++ b/spd/experiments/tms/tms_5-2_config.yaml @@ -79,4 +79,4 @@ pretrained_model_path: "wandb:goodfire/spd-pre-Sep-2025/runs/0hsp07o4" task_config: task_name: tms feature_probability: 0.05 - data_generation_type: "at_least_zero_active" + data_generation_type: "at_least_zero_active" \ No newline at end of file