diff --git a/spd/configs.py b/spd/configs.py index efbfb3bb4..f93e5d01d 100644 --- a/spd/configs.py +++ b/spd/configs.py @@ -744,6 +744,14 @@ class Config(BaseConfig): default="", description="Prefix prepended to an auto-generated WandB run name", ) + label: str = Field( + default="", + description="Short human-readable label for this run or series of runs", + ) + notes: str = Field( + default="", + description="Free-form notes about this run's purpose or configuration choices", + ) # --- General --- seed: int = Field( diff --git a/spd/run_spd.py b/spd/run_spd.py index d303a24c5..6053363ff 100644 --- a/spd/run_spd.py +++ b/spd/run_spd.py @@ -1,9 +1,11 @@ """Run SPD on a model.""" import gc +import json import os from collections import defaultdict from collections.abc import Iterator +from datetime import UTC, datetime from pathlib import Path from typing import Any, cast @@ -54,6 +56,7 @@ get_scheduled_value, save_pre_run_info, ) +from spd.utils.git_utils import repo_current_commit_hash, repo_is_clean from spd.utils.logging_utils import get_grad_norms_dict, local_log from spd.utils.module_utils import expand_module_patterns from spd.utils.run_utils import generate_run_id, save_file @@ -427,6 +430,45 @@ def create_pgd_data_iter() -> ( logger.info("Finished training loop.") +RUN_METADATA_FILENAME = "run_metadata.json" + + +def _write_run_metadata(out_dir: Path, run_id: str, config: Config, save_to_wandb: bool) -> None: + """Write run_metadata.json with git state, timestamp, and user annotations.""" + metadata = { + "run_id": run_id, + "git_commit": repo_current_commit_hash(), + "uncommitted_changes": not repo_is_clean(), + "date": datetime.now(UTC).strftime("%Y-%m-%d %H:%M"), + "label": config.label, + "notes": config.notes, + "completed": False, + } + + metadata_path = out_dir / RUN_METADATA_FILENAME + save_file(metadata, metadata_path, indent=2) + + if save_to_wandb: + try_wandb(wandb.save, str(metadata_path), base_path=str(out_dir), policy="now") + + +def 
_mark_run_completed(out_dir: Path, save_to_wandb: bool) -> None: + """Set completed=True in an existing run_metadata.json.""" + metadata_path = out_dir / RUN_METADATA_FILENAME + assert metadata_path.exists(), f"run_metadata.json not found at {metadata_path}" + + with open(metadata_path) as f: + metadata = json.load(f) + + start_time = datetime.strptime(metadata["date"], "%Y-%m-%d %H:%M").replace(tzinfo=UTC) + metadata["duration"] = round((datetime.now(UTC) - start_time).total_seconds() / 3600, 2) + metadata["completed"] = True + save_file(metadata, metadata_path, indent=2) + + if save_to_wandb: + try_wandb(wandb.save, str(metadata_path), base_path=str(out_dir), policy="now") + + def run_experiment( target_model: nn.Module, config: Config, @@ -474,6 +516,8 @@ def run_experiment( train_config=target_model_train_config, task_name=getattr(config.task_config, "task_name", None), ) + + _write_run_metadata(out_dir, run_id, config, save_to_wandb=config.wandb_project is not None) else: out_dir = None @@ -488,5 +532,8 @@ def run_experiment( tied_weights=tied_weights, ) - if is_main_process() and config.wandb_project: - wandb.finish() + if is_main_process(): + assert out_dir is not None + _mark_run_completed(out_dir, save_to_wandb=config.wandb_project is not None) + if config.wandb_project: + wandb.finish() diff --git a/spd/scripts/index_spd_runs.py b/spd/scripts/index_spd_runs.py new file mode 100644 index 000000000..de29f5159 --- /dev/null +++ b/spd/scripts/index_spd_runs.py @@ -0,0 +1,376 @@ +"""Generate a TSV index of all SPD runs. + +Scans SPD_OUT_DIR/spd for run directories and produces a runs_index.tsv with metadata columns: +run_id, git_commit, uncommitted_changes, label, notes, hyperparameters, date, completed, duration. + +The hyperparameters column shows only config values that differ between runs sharing the same label. 
+ +Usage: + uv run spd/scripts/index_spd_runs.py # default paths + uv run spd/scripts/index_spd_runs.py -i /path/to/runs # override input dir + uv run spd/scripts/index_spd_runs.py -o /path/to/out.tsv # override output path +""" + +import argparse +import csv +import json +from datetime import UTC, datetime, timedelta +from pathlib import Path +from typing import Any + +import yaml +from tqdm import tqdm + +from spd.settings import SPD_OUT_DIR +from spd.utils.run_utils import _DISCRIMINATED_LIST_FIELDS + +COLUMNS = [ + "run_id", + "date", + "git_commit", + "uncommitted_changes", + "label", + "completed", + "duration_hours", + "hyperparameters", + "notes", +] + +NA = "NA" + +_RECHECK_WINDOW = timedelta(days=7) + + +def _is_recent(row: dict[str, str]) -> bool: + """Return True if the run's date is within the recheck window.""" + date_str = row.get("date", NA) + if date_str == NA: + return False + try: + run_date = datetime.strptime(date_str, "%Y-%m-%d %H:%M").replace(tzinfo=UTC) + except ValueError: + return False + return datetime.now(UTC) - run_date < _RECHECK_WINDOW + + +def _flatten_dict(d: dict[str, Any], prefix: str = "") -> dict[str, str]: + """Recursively flatten a nested dict with dot-separated keys. + + For discriminated list fields, uses the discriminator value as key instead of index. + For other lists, uses index-based keys. 
+ """ + flat: dict[str, str] = {} + for key, value in d.items(): + full_key = f"{prefix}{key}" if not prefix else f"{prefix}.{key}" + if isinstance(value, dict): + flat.update(_flatten_dict(value, full_key)) + elif isinstance(value, list) and key in _DISCRIMINATED_LIST_FIELDS: + disc_field = _DISCRIMINATED_LIST_FIELDS[key] + for item in value: + assert isinstance(item, dict) + disc_value = item[disc_field] + sub = {k: v for k, v in item.items() if k != disc_field} + flat.update(_flatten_dict(sub, f"{full_key}.{disc_value}")) + elif isinstance(value, list): + for i, item in enumerate(value): + if isinstance(item, dict): + flat.update(_flatten_dict(item, f"{full_key}.{i}")) + else: + flat[f"{full_key}.{i}"] = str(item) + else: + flat[full_key] = str(value) + return flat + + +def _read_metadata(run_dir: Path) -> dict[str, str]: + """Read metadata from run_metadata.json, or return NAs for legacy runs.""" + metadata_path = run_dir / "run_metadata.json" + if metadata_path.exists(): + with open(metadata_path) as f: + meta = json.load(f) + return { + "run_id": str(meta.get("run_id", run_dir.name)), + "git_commit": str(meta.get("git_commit", NA)), + "uncommitted_changes": str(meta.get("uncommitted_changes", NA)), + "label": str(meta.get("label", "")), + "notes": str(meta.get("notes", "")), + "date": str(meta.get("date", NA)), + "completed": str(meta.get("completed", NA)), + "duration_hours": str(meta.get("duration", NA)), + } + + # Legacy run without metadata — try to determine completion from config + metrics + completed = _check_legacy_completion(run_dir) + return { + "run_id": run_dir.name, + "git_commit": NA, + "uncommitted_changes": NA, + "label": "", + "notes": "", + "date": NA, + "completed": str(completed), + "duration_hours": NA, + } + + +def _read_last_jsonl_line(path: Path) -> dict[str, Any] | None: + """Read and parse the last line of a JSONL file using backwards seek. + + Returns the parsed JSON dict, or None if the file is empty. 
+ """ + with open(path, "rb") as f: + f.seek(0, 2) + size = f.tell() + if size == 0: + return None + pos = size - 1 + while pos > 0: + f.seek(pos) + char = f.read(1) + if char == b"\n" and pos < size - 1: + break + pos -= 1 + last_line = f.readline().decode().strip() + + if not last_line: + return None + return json.loads(last_line) + + +def _check_legacy_completion(run_dir: Path) -> bool: + """Determine if a legacy run (no run_metadata.json) completed.""" + config_path = run_dir / "final_config.yaml" + metrics_path = run_dir / "metrics.jsonl" + + if not config_path.exists() or not metrics_path.exists(): + return False + + with open(config_path) as f: + config_dict = yaml.safe_load(f) + total_steps = config_dict.get("steps") + if total_steps is None: + return False + + last = _read_last_jsonl_line(metrics_path) + if last is None: + return False + + return last.get("step", -1) >= total_steps + + +def _load_existing_index(index_path: Path) -> dict[str, dict[str, str]]: + """Load existing TSV index into dict[run_id → row].""" + if not index_path.exists(): + return {} + rows: dict[str, dict[str, str]] = {} + with open(index_path, newline="") as f: + reader = csv.DictReader(f, delimiter="\t") + for row in reader: + run_id = row["run_id"] + rows[run_id] = dict(row) + return rows + + +def _compute_hyperparameters( + label_groups: dict[str, list[str]], + run_dirs: dict[str, Path], +) -> dict[str, str]: + """Compute hyperparameters column for all runs. + + For each label group with >=2 runs, loads final_config.yaml, flattens, and diffs. 
+ """ + hyperparams: dict[str, str] = {} + + for _label, run_ids in label_groups.items(): + if len(run_ids) < 2: + for rid in run_ids: + hyperparams[rid] = "" + continue + + # Load and flatten configs for each run in the group + flattened_configs: dict[str, dict[str, str]] = {} + for rid in run_ids: + config_path = run_dirs[rid] / "final_config.yaml" + if not config_path.exists(): + flattened_configs[rid] = {} + continue + with open(config_path) as f: + config_dict = yaml.safe_load(f) + flattened_configs[rid] = _flatten_dict(config_dict) + + # Sub-group by pretrained_model_path + pretrained_key = "pretrained_model_path" + subgroups: dict[str | None, list[str]] = {} + for rid in run_ids: + val = flattened_configs.get(rid, {}).get(pretrained_key) + subgroups.setdefault(val, []).append(rid) + + multiple_pretrained = len(subgroups) > 1 + ignore_keys = {"notes", "label"} + if multiple_pretrained: + ignore_keys.add(pretrained_key) + + for pretrained_val, sub_rids in subgroups.items(): + # Find differing keys within this sub-group + differing_keys: list[str] = [] + if len(sub_rids) >= 2: + sub_all_keys: set[str] = set() + for rid in sub_rids: + sub_all_keys.update(flattened_configs.get(rid, {}).keys()) + for key in sorted(sub_all_keys): + if key in ignore_keys: + continue + values = {flattened_configs.get(rid, {}).get(key) for rid in sub_rids} + if len(values) > 1: + differing_keys.append(key) + + # Format hyperparameters string for each run + for rid in sub_rids: + fc = flattened_configs.get(rid, {}) + parts: list[str] = [] + if multiple_pretrained: + parts.append(f"{pretrained_key}={pretrained_val or NA}") + parts.extend(f"{k}={fc.get(k, NA)}" for k in differing_keys) + hyperparams[rid] = " ".join(parts) + + return hyperparams + + +def _read_final_metrics(run_dir: Path, metric_names: list[str]) -> dict[str, str]: + """Read the last line of metrics.jsonl and extract requested metric values.""" + metrics_path = run_dir / "metrics.jsonl" + if not metrics_path.exists(): + 
return {name: NA for name in metric_names}
+
+    data = _read_last_jsonl_line(metrics_path)
+    if data is None:
+        return {name: NA for name in metric_names}
+
+    return {name: str(data.get(name, NA)) for name in metric_names}
+
+
+def build_index(
+    runs_dir: Path, index_path: Path, *, force: bool = False, metrics: list[str] | None = None
+) -> None:
+    existing = _load_existing_index(index_path) if not force else {}
+
+    # Discover all run directories
+    run_dirs: dict[str, Path] = {}
+    for entry in sorted(runs_dir.iterdir()):
+        if entry.is_dir():
+            run_dirs[entry.name] = entry
+
+    # Phase 1: collect per-run metadata (using cache where possible)
+    rows: dict[str, dict[str, str]] = {}
+    new_run_ids: set[str] = set()
+    n_reprocessed = 0
+    n_cached = 0
+    for run_id, run_dir in tqdm(run_dirs.items(), desc="Reading runs"):
+        cached = existing.get(run_id)
+        # Cache miss if any requested metric isn't in the cached row
+        metrics_missing = metrics and cached and any(m not in cached for m in metrics)
+        if (
+            cached
+            and not metrics_missing
+            and (cached.get("completed") == "True" or not _is_recent(cached))
+        ):
+            rows[run_id] = cached
+            n_cached += 1
+        else:
+            row = _read_metadata(run_dir)
+            if metrics:
+                if row.get("completed") == "True":
+                    row.update(_read_final_metrics(run_dir, metrics))
+                else:
+                    row.update({name: NA for name in metrics})
+            if cached:
+                row.setdefault("hyperparameters", cached.get("hyperparameters", ""))
+            rows[run_id] = row
+            n_reprocessed += 1
+            if not cached:
+                new_run_ids.add(run_id)
+
+    # Phase 2: compute hyperparameters
+    # Group runs by label
+    label_groups: dict[str, list[str]] = {}
+    for run_id, row in rows.items():
+        label = row.get("label", "")
+        if label and label != NA:
+            label_groups.setdefault(label, []).append(run_id)
+
+    # Determine which label groups need recomputation
+    groups_to_recompute: set[str] = set()
+    for label, run_ids in label_groups.items():
+        if any(rid in new_run_ids for rid in run_ids):
+            groups_to_recompute.add(label)
+
+    # 
Recompute hyperparameters for groups with new runs
+    recompute_groups = {
+        lbl: rids for lbl, rids in label_groups.items() if lbl in groups_to_recompute
+    }
+    fresh_hyperparams = _compute_hyperparameters(recompute_groups, run_dirs)
+
+    # Assign hyperparameters to all runs
+    for run_id, row in rows.items():
+        if run_id in fresh_hyperparams:
+            row["hyperparameters"] = fresh_hyperparams[run_id]
+        elif run_id not in existing:
+            # New run not in any label group (or solo label)
+            row.setdefault("hyperparameters", "")
+
+    # Phase 3: write TSV
+    fieldnames = COLUMNS + (metrics or [])
+    with open(index_path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames, delimiter="\t", extrasaction="ignore")
+        writer.writeheader()
+        for run_id in sorted(
+            rows,
+            key=lambda rid: rows[rid].get("date", "") if rows[rid].get("date", "") != NA else "",
+            reverse=True,
+        ):
+            writer.writerow(rows[run_id])
+
+    print(
+        f"Wrote {len(rows)} runs to {index_path} ({n_reprocessed} reprocessed, {n_cached} cached)"
+    )
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Generate TSV index of SPD runs")
+    parser.add_argument(
+        "-i",
+        "--input-dir",
+        type=Path,
+        default=SPD_OUT_DIR / "spd",
+        help="Directory containing run subdirectories",
+    )
+    parser.add_argument(
+        "-o",
+        "--output",
+        type=Path,
+        default=None,
+        help="Output TSV path (default: <input-dir>/runs_index.tsv)",
+    )
+    parser.add_argument(
+        "--force",
+        action="store_true",
+        help="Bypass cache and reprocess all runs",
+    )
+    parser.add_argument(
+        "--metrics",
+        type=str,
+        help="Comma-separated metric names to include (e.g. 
'train/loss/total,train/l0/total')", + ) + args = parser.parse_args() + + runs_dir: Path = args.input_dir + assert runs_dir.is_dir(), f"Runs directory not found: {runs_dir}" + + index_path: Path = args.output if args.output else runs_dir / "runs_index.tsv" + metric_names = [m.strip() for m in args.metrics.split(",")] if args.metrics else None + + build_index(runs_dir, index_path, force=args.force, metrics=metric_names) + + +if __name__ == "__main__": + main()