Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions spd/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,14 @@ class Config(BaseConfig):
default="",
description="Prefix prepended to an auto-generated WandB run name",
)
label: str = Field(
default="",
description="Short human-readable label for this run or series of runs",
)
notes: str = Field(
default="",
description="Free-form notes about this run's purpose or configuration choices",
)

# --- General ---
seed: int = Field(
Expand Down
51 changes: 49 additions & 2 deletions spd/run_spd.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Run SPD on a model."""

import gc
import json
import os
from collections import defaultdict
from collections.abc import Iterator
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, cast

Expand Down Expand Up @@ -54,6 +56,7 @@
get_scheduled_value,
save_pre_run_info,
)
from spd.utils.git_utils import repo_current_commit_hash, repo_is_clean
from spd.utils.logging_utils import get_grad_norms_dict, local_log
from spd.utils.module_utils import expand_module_patterns
from spd.utils.run_utils import generate_run_id, save_file
Expand Down Expand Up @@ -427,6 +430,45 @@ def create_pgd_data_iter() -> (
logger.info("Finished training loop.")


# File name of the per-run metadata JSON. It is joined onto the run's output
# directory both when the file is first written and when it is re-read to mark
# the run as completed.
RUN_METADATA_FILENAME = "run_metadata.json"


def _write_run_metadata(out_dir: Path, run_id: str, config: Config, save_to_wandb: bool) -> None:
    """Create the run's metadata JSON file inside ``out_dir``.

    The file records the run id, the results of ``repo_current_commit_hash()``
    and ``repo_is_clean()`` (the latter inverted into ``uncommitted_changes``),
    the current UTC time formatted to minute resolution, and the ``label`` and
    ``notes`` fields from ``config``. ``completed`` is initialised to ``False``.

    Args:
        out_dir: Directory the metadata file is written into.
        run_id: Identifier stored under the ``run_id`` key.
        config: Run configuration; only ``label`` and ``notes`` are read here.
        save_to_wandb: When true, the written file is also passed to
            ``wandb.save`` (via ``try_wandb``) with ``policy="now"``.
    """
    destination = out_dir / RUN_METADATA_FILENAME

    # Key order is kept stable because it determines the layout of the JSON file.
    record = dict(
        run_id=run_id,
        git_commit=repo_current_commit_hash(),
        uncommitted_changes=not repo_is_clean(),
        # Stored as a plain string; the same format is parsed again when the
        # run is marked as completed.
        date=datetime.now(UTC).strftime("%Y-%m-%d %H:%M"),
        label=config.label,
        notes=config.notes,
        completed=False,
    )
    save_file(record, destination, indent=2)

    if not save_to_wandb:
        return
    try_wandb(wandb.save, str(destination), base_path=str(out_dir), policy="now")


def _mark_run_completed(out_dir: Path, save_to_wandb: bool) -> None:
    """Flag the run's metadata file as completed and record its duration.

    Loads the existing metadata JSON from ``out_dir``, parses its ``date``
    entry as a UTC start time, stores the elapsed time in hours (rounded to two
    decimals) under ``duration``, sets ``completed`` to ``True`` and writes the
    file back.

    Args:
        out_dir: Directory expected to already contain the metadata file.
        save_to_wandb: When true, the rewritten file is also passed to
            ``wandb.save`` (via ``try_wandb``) with ``policy="now"``.

    Raises:
        AssertionError: If the metadata file does not exist in ``out_dir``.
    """
    metadata_path = out_dir / RUN_METADATA_FILENAME
    assert metadata_path.exists(), f"run_metadata.json not found at {metadata_path}"

    record = json.loads(metadata_path.read_text())

    # The stored timestamp carries no timezone, so it is tagged as UTC before
    # being compared with the current UTC time.
    started_at = datetime.strptime(record["date"], "%Y-%m-%d %H:%M").replace(tzinfo=UTC)
    elapsed = datetime.now(UTC) - started_at
    seconds_per_hour = 3600
    record["duration"] = round(elapsed.total_seconds() / seconds_per_hour, 2)
    record["completed"] = True
    save_file(record, metadata_path, indent=2)

    if not save_to_wandb:
        return
    try_wandb(wandb.save, str(metadata_path), base_path=str(out_dir), policy="now")


def run_experiment(
target_model: nn.Module,
config: Config,
Expand Down Expand Up @@ -474,6 +516,8 @@ def run_experiment(
train_config=target_model_train_config,
task_name=getattr(config.task_config, "task_name", None),
)

_write_run_metadata(out_dir, run_id, config, save_to_wandb=config.wandb_project is not None)
else:
out_dir = None

Expand All @@ -488,5 +532,8 @@ def run_experiment(
tied_weights=tied_weights,
)

if is_main_process() and config.wandb_project:
wandb.finish()
if is_main_process():
assert out_dir is not None
_mark_run_completed(out_dir, save_to_wandb=config.wandb_project is not None)
if config.wandb_project:
wandb.finish()
Loading
Loading