diff --git a/checkpoint_eval_monitor.py b/checkpoint_eval_monitor.py new file mode 100644 index 0000000..a26ed9b --- /dev/null +++ b/checkpoint_eval_monitor.py @@ -0,0 +1,216 @@ +# Filter out specific FutureWarnings from flash_attn +import warnings +import re + +# Define the warning patterns to filter +warning_patterns = [ + r"torch\.cuda\.amp\.custom_fwd.*is deprecated", + r"torch\.cuda\.amp\.custom_bwd.*is deprecated" +] + +# Create a filter function +def filter_flash_attn_warnings(message, category, filename, lineno, file=None, line=None): + # Check if it's a FutureWarning from flash_attn + if category == FutureWarning and "flash_attn" in filename: + # Check if the message matches any of our patterns + for pattern in warning_patterns: + if re.search(pattern, str(message)): + return None # Suppress the warning + # Return anything else + return True # Show other warnings + +# Apply the filter +warnings.filterwarnings("ignore", category=FutureWarning, module="flash_attn") + +import json +import logging +import os +import sys +import time +from pathlib import Path +from typing import Annotated, List, Optional, Set + +import typer +import yaml +from huggingface_hub import HfApi, list_repo_files +from typer import Option + +from run_evals import main as eval_main +from run_evals import TaskName + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger("poller") + +app = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]}, pretty_exceptions_show_locals=False) + + +# from maxb2: https://github.com/tiangolo/typer/issues/86#issuecomment-996374166 +def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): + if config is not None: + typer.echo(f"Loading config file: {config}\n") + try: + with open(config, "r") as f: # Load config file + conf = yaml.safe_load(f) + ctx.default_map = ctx.default_map or {} # Initialize the default map + ctx.default_map.update(conf) # Merge the config dict into default_map + except Exception as ex: + raise typer.BadParameter(str(ex)) + return config + + +def load_processed(file_path: str = "processed_checkpoints.json") -> Set[str]: + """ + Load a set of checkpoint filenames we've already processed, so we don't re‐process them. + """ + if os.path.exists(file_path): + try: + with open(file_path, "r") as f: + return set(json.load(f)) + except Exception as e: + logger.warning(f"Could not parse {file_path}: {e}") + return set() + + +def save_processed(processed: Set[str], file_path: str = "processed_checkpoints.json"): + """ + Save a set of checkpoint filenames, so next time we skip them. + """ + try: + with open(file_path, "w") as f: + json.dump(list(processed), f) + except Exception as e: + logger.warning(f"Could not write to {file_path}: {e}") + + +def find_new_checkpoints(files_in_repo: list[str], processed: Set[str]) -> Set[str]: + """ + Return any .pt filenames containing 'rank' that are not yet in 'processed'. + E.g. 'my_run/epoch3-rank0.pt' + """ + new_ckpts = set() + for f in files_in_repo: + if f.endswith(".pt") and "rank" in f and f not in processed and 'latest' not in f: + new_ckpts.add(f) + return new_ckpts + + +def poll_loop( + repo_id: str, + token: Optional[str], + checkpoint_dir: str, + poll_interval: int, + wandb_project: Optional[str], + wandb_entity: Optional[str], + tasks: List[str], + seeds: List[int], + gpu_ids: List[int], + skip_generation: bool, + train_config: Optional[Path], + wandb_run: Optional[str] = None, + track_run: bool = True, + track_run_project: Optional[str] = None, +): + """ + Main polling loop: + - check the HF repo for new .pt files + - pass them to run_evals.programmatic_main + - record them in JSON + - sleep + """ + hf_api = HfApi(token=token) + processed = load_processed() + + logger.info(f"Starting poller for {repo_id}") + logger.info(f"Polling every {poll_interval} seconds.\n") + + while True: + try: + logger.info(f"Checking for new checkpoints in {repo_id}...") + repo_files = list_repo_files(repo_id, token=token) + new_ckpts = find_new_checkpoints(repo_files, processed) + + if not new_ckpts: + logger.info("No new checkpoints found.") + else: + for ckpt in new_ckpts: + eval_batch_count = str(ckpt).split('ba')[1].split('-rank0.pt')[0] + logger.info(f"Found new checkpoint: {ckpt} with eval_batch_count: {eval_batch_count}") + logger.info("Calling run_evals.programmatic_main(...) on that checkpoint...") + + try: + eval_main( + checkpoints=checkpoint_dir, + hub_repo=repo_id, + hub_files=[ckpt], + hub_token=token, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + tasks=tasks, + seeds=seeds, + skip_generation=skip_generation, + gpu_ids=gpu_ids, + eval_batch_count=eval_batch_count, + train_config=train_config, + verbose=True, + parallel=True, + track_run=track_run, + track_run_project=track_run_project, + ) + # Mark it processed + processed.add(ckpt) + save_processed(processed) + except Exception as e: + logger.error(f"Error running eval on {ckpt}: {e}", exc_info=True) + + except Exception as e: + logger.error(f"Error in poll loop: {e}", exc_info=True) + + logger.info(f"Sleeping {poll_interval} seconds...\n") + time.sleep(poll_interval) + + +@app.command() +def main( + repo_id: Annotated[str, Option(help="Hugging Face repo ID to monitor for new checkpoints", show_default=False)], + token: Annotated[Optional[str], Option(help="Optional HF API token for private repos")] = None, + checkpoint_dir: Annotated[Path, Option(help="Local directory to store or download checkpoints")] = "./checkpoints", + poll_interval: Annotated[int, Option(help="How many seconds to wait between polls")] = 60, + wandb_run: Annotated[Optional[str], Option(help="Optional W&B run to pass to eval script")] = None, + wandb_project: Annotated[Optional[str], Option(help="Optional W&B project to pass to eval script")] = None, + wandb_entity: Annotated[Optional[str], Option(help="Optional W&B entity to pass to eval script")] = None, + tasks: Annotated[List[TaskName], Option(help="Which tasks to evaluate")] = [TaskName.mnli], # type: ignore + seeds: Annotated[List[int], Option(help="Random seeds to pass to _main")] = [42, 314, 1234], + gpu_ids: Annotated[Optional[List[int]], Option(help="Optional list of GPU IDs to use for evaluation")] = None, + skip_generation: Annotated[bool, Option(help="If set, pass skip_generation=True to eval script")] = False, + track_run: Annotated[bool, Option(help="Track the eval run with wandb", rich_help_panel="Weights & Biases")] = True, + track_run_project: Annotated[Optional[str], Option(help="wandb project for tracking the run", rich_help_panel="Weights & Biases")] = None, + train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, + config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, +): # fmt: skip + """ + Poll a Hugging Face repo for new .pt checkpoints (with 'rank' in filename); call run_evals. + """ + poll_loop( + repo_id=repo_id, + token=token, + checkpoint_dir=checkpoint_dir, + poll_interval=poll_interval, + wandb_run=wandb_run, + wandb_project=wandb_project, + wandb_entity=wandb_entity, + tasks=tasks, + seeds=seeds, + track_run=track_run, + track_run_project=track_run_project, + gpu_ids=gpu_ids, + skip_generation=skip_generation, + train_config=train_config, + ) + + +if __name__ == "__main__": + app() diff --git a/create_random_init_model.py b/create_random_init_model.py new file mode 100755 index 0000000..ab709d2 --- /dev/null +++ b/create_random_init_model.py @@ -0,0 +1,127 @@ +import os +import torch +import yaml +import argparse +from pathlib import Path +from huggingface_hub import HfApi +from composer import Trainer +from composer.models import HuggingFaceModel +from src.flex_bert import create_flex_bert_mlm + +def parse_args(): + parser = argparse.ArgumentParser(description='Create a random init Composer model and upload to HF') + parser.add_argument('--config_path', type=str, required=True, + help='Path to the training config YAML file') + parser.add_argument('--output_dir', type=str, default='./checkpoints/random_init', + help='Directory to save the model checkpoints') + parser.add_argument('--repo_id', type=str, default='PLACEHOLDER', + help='HuggingFace repository ID to upload the model') + parser.add_argument('--token', type=str, default=None, + help='HuggingFace API token for private repos') + return parser.parse_args() + +def main(): + args = parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + + with open(args.config_path, 'r') as f: + config = yaml.safe_load(f) + + print(f"Creating model with config from {args.config_path}") + + model_config = config['model']['model_config'] + + valid_attention_types = ['base', 'parallel', 'rope', 'rope_parallel'] + if 'attention_layer' in model_config and model_config['attention_layer'] not in valid_attention_types: + print(f"Warning: Invalid attention_layer '{model_config['attention_layer']}', falling back to 'rope'") + model_config['attention_layer'] = 'rope' + + try: + model = create_flex_bert_mlm( + pretrained_model_name=config['model']['pretrained_model_name'], + tokenizer_name=config['tokenizer_name'], + model_config=model_config + ) + print("HF model created successfully.") + except Exception as e: + print(f"Error creating model: {e}") + print("Attempting with simplified config...") + + for key in list(model_config.keys()): + if key not in ['vocab_size', 'hidden_size', 'num_hidden_layers', + 'num_attention_heads', 'attention_layer', 'padding']: + model_config.pop(key, None) + + model_config['attention_layer'] = 'rope' + model_config['padding'] = 'unpadded' + + model = create_flex_bert_mlm( + pretrained_model_name=config['model']['pretrained_model_name'], + tokenizer_name=config['tokenizer_name'], + model_config=model_config + ) + print("HF model created with simplified config.") + + + composer_model = HuggingFaceModel( + model=model, + tokenizer=None, + use_logits=True + ) + print("Composer model created.") + + checkpoint_path = os.path.join(args.output_dir, "latest-rank0.pt") + + trainer = Trainer( + model=composer_model, + max_duration="1ba", + device="cpu" + ) + + print(f"Saving Composer checkpoint to {checkpoint_path}...") + trainer.save_checkpoint(checkpoint_path) + + config_path = os.path.join(args.output_dir, f"{Path(args.output_dir).name}.yaml") + with open(config_path, 'w') as f: + yaml.dump(config, f) + + print(f"Config saved at: {config_path}") + + if args.token: + print(f"Uploading to HuggingFace repo: {args.repo_id}") + api = HfApi(token=args.token) + + try: + api.repo_info(repo_id=args.repo_id) + print(f"Repository {args.repo_id} already exists") + except Exception: + print(f"Creating new repository: {args.repo_id}") + api.create_repo( + repo_id=args.repo_id, + private=True, + repo_type="model", + exist_ok=True + ) + print(f"Repository {args.repo_id} created successfully") + + api.upload_file( + path_or_fileobj=checkpoint_path, + path_in_repo=f"{Path(args.output_dir).name}/latest-rank0.pt", + repo_id=args.repo_id, + token=args.token + ) + + api.upload_file( + path_or_fileobj=config_path, + path_in_repo=f"{Path(args.output_dir).name}/{Path(args.output_dir).name}.yaml", + repo_id=args.repo_id, + token=args.token + ) + + print("Upload complete!") + else: + print("No HuggingFace token provided. Skipping upload.") + +if __name__ == "__main__": + main() diff --git a/generate_eval_config.py b/generate_eval_config.py index 81e480f..a07d8a1 100644 --- a/generate_eval_config.py +++ b/generate_eval_config.py @@ -41,6 +41,7 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option class ModelSize(str, Enum): BASE = "base" LARGE = "large" + HUGE = "huge" def get_model_defaults(model_size: ModelSize): @@ -58,6 +59,12 @@ def get_model_defaults(model_size: ModelSize): "intermediate_size": 2624, "num_attention_heads": 16, }, + "huge": { + "num_hidden_layers": 32, + "hidden_size": 1536, + "intermediate_size": 4096, + "num_attention_heads": 24, + }, } # Select the default model config based on the model_size argument @@ -225,7 +232,9 @@ def main( head_class_dropout: Annotated[float, Option(help="Classification head dropout rate", rich_help_panel="Model Options")] = 0.0, fast_ultrafeedback: Annotated[bool, Option("--fast-ultrafeedback", help="Use a shorter sequence length (1536) for the UltraFeedback eval", rich_help_panel="Task Settings")] = False, seeds: Annotated[List[int], Option(help="List of seeds to use for the eval", rich_help_panel="Task Settings")] = [1618, 42, 6033, 3145], + gpu_ids: Annotated[List[int], Option(help="List of GPU IDs to use for the eval", rich_help_panel="Task Settings")] = [0], parallel: Annotated[bool, Option("--parallel/--single", help="Run the evals in parallel on multiple GPUs or one GPU. Only use if evaluating a single checkpoint on multiple GPUs.", rich_help_panel="Task Settings")] = False, + eval_batch_count: Annotated[Optional[int], Option(help="Number of batches to evaluate", rich_help_panel="Task Settings")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip # Read the input YAML file @@ -236,14 +245,25 @@ def main( ckpt = checkpoint.name # checkpoint ckpt_path = checkpoint.parent elif checkpoint.is_dir(): - ckpts = list(checkpoint.glob("*.pt")) + # Search recursively for checkpoint files + ckpts = list(checkpoint.glob("**/*.pt")) if len(ckpts) == 1: ckpt = ckpts[0].name + ckpt_path = ckpts[0].parent elif len(ckpts) > 1: - ckpt = "latest-rank0.pt" + # Look for latest-rank0.pt in any subfolder + latest_ckpts = list(checkpoint.glob("**/latest-rank0.pt")) + if latest_ckpts: + ckpt = latest_ckpts[0].name + ckpt_path = latest_ckpts[0].parent + else: + # Default to first checkpoint found + ckpt = ckpts[0].name + ckpt_path = ckpts[0].parent elif len(ckpts) == 0: - raise ValueError(f"No checkpoint found in the provided directory: {checkpoint}") - ckpt_path = checkpoint + raise ValueError(f"No checkpoint found in the provided directory or its subdirectories: {checkpoint}") + else: + ckpt_path = checkpoint else: raise ValueError(f"Invalid checkpoint path provided: {checkpoint}") @@ -295,6 +315,8 @@ def main( else: base_run_name = safe_get(input_config, "run_name", ckpt_path.name) new_config["base_run_name"] = base_run_name + if eval_batch_count is not None: + new_config["base_run_name"] = f"{base_run_name}-{eval_batch_count}" new_config["default_seed"] = 19 new_config["precision"] = safe_get(input_config, "precision") @@ -388,7 +410,12 @@ def main( elif task_name == "mnli": task_config["seeds"] = seeds[:3] - task_config["trainer_kwargs"] = {"save_num_checkpoints_to_keep": 1, "max_duration": "2ep"} + task_config["trainer_kwargs"] = { + "save_num_checkpoints_to_keep": 1, + "max_duration": "2ep", + "batch_size": 64, + "device_train_microbatch_size": "auto", + } elif task_name == "boolq": task_config["seeds"] = seeds[:3] @@ -417,6 +444,8 @@ def main( task_config["seeds"] = seeds[:3] tasks_dict[task_name] = task_config + task_config["gpu_ids"] = gpu_ids + new_config["tasks"] = tasks_dict # Write the new configuration to a YAML file diff --git a/hf_checkpoints_uploader.py b/hf_checkpoints_uploader.py new file mode 100644 index 0000000..a096c99 --- /dev/null +++ b/hf_checkpoints_uploader.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import re +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Set + +import typer +import yaml +from huggingface_hub import HfApi, hf_hub_download, list_repo_files +from huggingface_hub.utils import HfHubHTTPError + +CHECKPOINT_RE = re.compile(r"ep\d+-ba(\d+)-rank0\.pt$") + +app = typer.Typer(pretty_exceptions_show_locals=False, context_settings={"help_option_names": ["-h", "--help"]}) + + +def find_local_checkpoints(base_dir: Path, model_dirs: List[str]) -> Dict[str, Path]: + """Return {repo_path: local_path} for every `ep*-ba*-rank0.pt` under each model dir.""" + ckpts: Dict[str, Path] = {} + for mdir in model_dirs: + for path in (base_dir / mdir).glob("ep*-ba*-rank0.pt"): + if path.name.startswith("latest-"): + continue # Skip alias + if CHECKPOINT_RE.match(path.name): + ckpts[f"{mdir}/{path.name}"] = path + return ckpts + + +def find_remote_checkpoints(api: HfApi, repo_id: str, token: Optional[str]) -> Set[str]: + """Return the set of path strings already present in the HF repo.""" + try: + return set(list_repo_files(repo_id, token=token)) + except HfHubHTTPError as e: + typer.secho(f"❌ Cannot list repo files: {e}", fg=typer.colors.RED, err=True) + raise typer.Exit(1) + + +def upload_file( + api: HfApi, + repo_id: str, + local_path: Path, + path_in_repo: str, + token: Optional[str], + commit_msg: str = "Add checkpoint", +): + print('Uploading ', local_path) + api.upload_file( + path_or_fileobj=str(local_path), + path_in_repo=path_in_repo, + repo_id=repo_id, + token=token, + repo_type="model", + commit_message=commit_msg, + ) + typer.echo(f"✅ Uploaded {path_in_repo}") + + +def catchup_upload( + api: HfApi, + repo_id: str, + base_dir: Path, + model_dirs: List[str], + token: Optional[str], +): + typer.echo("🔍 Running one-off catch-up scan …") + + local_ckpts = find_local_checkpoints(base_dir, model_dirs) + remote_ckpts = find_remote_checkpoints(api, repo_id, token) + + to_upload = [ + (repo_path, local_path) + for repo_path, local_path in local_ckpts.items() + if repo_path not in remote_ckpts + ] + + # Order by batch number (int) to keep history tidy + to_upload.sort(key=lambda x: int(CHECKPOINT_RE.search(x[0]).group(1))) # type: ignore + + for repo_path, local_path in to_upload: + upload_file(api, repo_id, local_path, repo_path, token, commit_msg="Catch-up upload") + + if not to_upload: + typer.echo("✨ Repo already up to date.") + + +def poll_loop( + api: HfApi, + repo_id: str, + base_dir: Path, + model_dirs: List[str], + poll_interval: int, + token: Optional[str], +): + typer.echo(f"🔄 Entering polling loop (every {poll_interval}s) …\n") + + while True: + try: + local_ckpts = find_local_checkpoints(base_dir, model_dirs) + remote_ckpts = find_remote_checkpoints(api, repo_id, token) + + new_items = [ + (rp, lp) for rp, lp in local_ckpts.items() if rp not in remote_ckpts + ] + new_items.sort(key=lambda x: int(CHECKPOINT_RE.search(x[0]).group(1))) # type: ignore + + for repo_path, local_path in new_items: + upload_file(api, repo_id, local_path, repo_path, token, commit_msg="Add checkpoint") + + except Exception as e: + # Log and continue; do not kill the loop + typer.secho(f"⚠️ Error in poll loop: {e}", fg=typer.colors.YELLOW, err=True) + + time.sleep(poll_interval) + + +# --------------------------- CLI ------------------------------------------------- + + +def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Optional[str] = None): + """Merge YAML config into Typer defaults (same helper you already use).""" + if config: + with open(config, "r") as f: + cfg = yaml.safe_load(f) + ctx.default_map = ctx.default_map or {} + ctx.default_map.update(cfg) + return config + + +@app.command() +def main( + repo_id: str = typer.Option(..., help="HF repo to push to, e.g. answerdotai/huge-in-run-checkpoints"), + base_dir: Path = typer.Option( + ..., help="Root with model_dir/checkpoints" + ), + model_dirs: List[str] = typer.Option( + ..., help="One or more sub-dirs to watch" + ), + token: Optional[str] = typer.Option(None, help="HF token (or set HF_TOKEN env var)"), + poll_interval: int = typer.Option(60, help="Seconds between scans after catch-up"), + once: bool = typer.Option( + False, "--once", help="Exit after the catch-up pass (no polling)" + ), + config: Optional[Path] = typer.Option( + None, + "--config", + callback=conf_callback, + is_eager=True, + help="YAML file with default values (CLI overrides)", + ), +): # fmt: skip + """ + Upload all `ep*-ba*-rank0.pt` checkpoints found under *base_dir/model_dir/* to Hugging Face Hub. + + 1. Performs an initial catch-up (only missing files are pushed). + 2. Unless `--once` is given, keeps polling local dirs for fresh checkpoints. + """ + + api = HfApi(token=token) + + catchup_upload(api, repo_id, base_dir, model_dirs, token) + + if once: + typer.echo("🏁 Done (catch-up only).") + return + + poll_loop(api, repo_id, base_dir, model_dirs, poll_interval, token) + + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/processed_checkpoints.json b/processed_checkpoints.json new file mode 100644 index 0000000..77679bf --- /dev/null +++ b/processed_checkpoints.json @@ -0,0 +1 @@ +["modernbert-huge-pretrain-v1/ep0-ba3000-rank0.pt"] \ No newline at end of file diff --git a/run_evals.py b/run_evals.py index 0585baf..d9918f9 100644 --- a/run_evals.py +++ b/run_evals.py @@ -4,17 +4,20 @@ import os import random import re +import shutil import signal import subprocess import tempfile import time import warnings + from collections import deque from enum import Enum from multiprocessing import Process, Queue from pathlib import Path from typing import Annotated, List, Optional + import datasets import psutil import typer @@ -41,6 +44,7 @@ class ModelSize(str, Enum): BASE = "base" LARGE = "large" + HUGE = "huge" # from maxb2: https://github.com/tiangolo/typer/issues/86#issuecomment-996374166 @@ -68,8 +72,10 @@ def conf_callback(ctx: typer.Context, param: typer.CallbackParam, config: Option # Global list to specify which GPUs to use allowed_gpus = None # Will be set to list of GPU IDs or None +console = Console() + -def kill_process_tree(pid): +def kill_process_tree(pid: int): try: parent = psutil.Process(pid) children = parent.children(recursive=True) @@ -213,7 +219,7 @@ def run_job( # Store process info for GPU management gpus_in_use[gpu_id] = {"process": process, "stderr_file": stderr_file, "config": config_path} - if gpu_id is None: + else: process.wait() handle_process_completion(process, stderr_file, config_path, verbose, gpu_id=None) if delete_eval_yamls: @@ -334,7 +340,11 @@ def create_symlink_for_newest_checkpoint(folder: Path, override_existing: bool = if folder.is_dir(): pt_files = list(folder.glob("*.pt")) if not pt_files: - print(f" Warning: No .pt file found in {folder}.") + print(f" Warning: No .pt file found in {folder}, skipping symlink creation.") + return + + if len(pt_files) == 1 and pt_files[0].name == "latest-rank0.pt" and not pt_files[0].is_symlink(): + print(f" Only found one .pt in {folder.name}, named 'latest-rank0.pt' (real file). Skipping symlink creation.") # fmt: skip return # Sort files based on epoch and batch numbers extracted from filenames @@ -359,25 +369,29 @@ def extract_numbers(filename: Path): newest_file = max(pt_files, key=extract_numbers) symlink_path = folder / "latest-rank0.pt" - if symlink_path.exists() and symlink_path.is_symlink(): + if symlink_path.is_symlink(): if symlink_path.resolve() == newest_file.resolve(): - print(f" Existing symlink points to latest checkpoint: {newest_file.parent.name}/{newest_file.name}") + print(f" Existing symlink in {folder.name} already points to {newest_file.name}") return else: print( - f" Warning: Existing symlink points to {symlink_path.parent.name}/{symlink_path.name}, " - f"but latest checkpoint is {newest_file.parent.name}/{newest_file.name}" + f" Warning: symlink in {folder.name} points to {symlink_path.resolve().name}, " + f"but newest is {newest_file.name}" ) if not override_existing: return + symlink_path.unlink(missing_ok=True) + elif symlink_path.exists(): + if not override_existing: + print(f" {symlink_path.name} is a real file in {folder.name}. Use override to remove it.") + return + symlink_path.unlink(missing_ok=True) symlink_path.symlink_to(newest_file.name) if override_existing: - print( - f" Overwriting existing symlink with {symlink_path.parent.name}/{symlink_path.name} -> {newest_file.name}" - ) + print(f" Overwrote symlink {symlink_path.name} -> {newest_file.name}") else: - print(f" Created new symlink {symlink_path.parent.name}/{symlink_path.name} -> {newest_file.name}") + print(f" Created new symlink {symlink_path.name} -> {newest_file.name}") def generate_eval_configs( @@ -391,6 +405,7 @@ def generate_eval_configs( pooling_type: Optional[str], head_class_act: Optional[str], head_class_norm: Optional[str], + eval_batch_count: Optional[str], head_class_dropout: float, tasks: Optional[List[TaskName]], # type: ignore fast_ultrafeedback: bool, @@ -399,6 +414,7 @@ def generate_eval_configs( use_dir_names: Optional[bool], model_size: ModelSize, rope_theta: Optional[float], + gpu_ids: Optional[List[int]] = None, ): """Generate evaluation configs for each checkpoint.""" @@ -456,7 +472,10 @@ def generate_eval_configs( # Add tasks if tasks: for task in tasks: - cmd.extend(["--tasks", task.value]) + if hasattr(task, "value"): + cmd.extend(["--tasks", task.value]) + else: + cmd.extend(["--tasks", str(task)]) if fast_ultrafeedback: cmd.append("--fast-ultrafeedback") @@ -464,9 +483,20 @@ def generate_eval_configs( for seed in seeds: cmd.extend(["--seeds", str(seed)]) - cmd.append("--parallel") if parallel else cmd.append("--single") + if parallel: + cmd.append("--parallel") + + if eval_batch_count: + cmd.extend(["--eval-batch-count", str(eval_batch_count)]) + + if gpu_ids: + if isinstance(gpu_ids, int): + gpu_ids = [gpu_ids] + for g in gpu_ids: + cmd.extend(["--gpu-ids", str(g)]) # Run the config generation process without suppressing output + run_subprocess(cmd, show_errors=True) if not train_config: time.sleep(1) @@ -483,7 +513,6 @@ def download_dataset(dataset_name: str, subset: Optional[str] = None): def download_datasets(tasks: List[TaskName], msg_queue): # type: ignore try: required_datasets = [] - task_to_datasets = { "mlmmlu_amateur_semipro": [["answerdotai/MLMMLU", "Amateur"], ["answerdotai/MLMMLU", "Semipro"]], "mlmmlu_rookie_reserve": [["answerdotai/MLMMLU", "Rookie"], ["answerdotai/MLMMLU", "Reserve"]], @@ -508,7 +537,7 @@ def download_datasets(tasks: List[TaskName], msg_queue): # type: ignore msgs = [] for dataset_name, subset in required_datasets: - datasets.load_dataset(dataset_name, subset) + datasets.load_dataset(dataset_name, subset, trust_remote_code=True) msgs.append(f"Successfully downloaded {dataset_name} {subset}") msg_queue.put(" " + "\n ".join(msgs) + "\n") except Exception as e: @@ -624,14 +653,11 @@ def move_and_flatten_files(local_dir: Path): return downloaded_files -console = Console() - - @app.command() def main( checkpoints: Annotated[Path, Option(help="Path to the directory containing FlexBert checkpoints or location to download checkpoints from Hugging Face Hub to", rich_help_panel="Checkpoint & Config Paths", show_default=False)], train_config: Annotated[Optional[Path], Option(help="Path to a .yaml file containing training configuration. If one is not provided, will attempt to load the config from a wandb run or use defaults.", rich_help_panel="Checkpoint & Config Paths")] = None, - model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.BASE, + model_size: Annotated[ModelSize, Option("--model-size", help="Model to use for default model config", rich_help_panel="Checkpoint & Config Paths")] = ModelSize.HUGE, rope_theta: Annotated[Optional[float], Option("--rope-theta", help="Value for `rotary_emb_base` in the model configuration. If not provided, defaults to pretraining value of 10000.0", rich_help_panel="Checkpoint & Config Paths")] = None, skip_generation: Annotated[bool, Option("--skip-generation", help="Skip generation of evaluation configs. If not true, assumes all existing eval yamls have been already ran.", rich_help_panel="Checkpoint & Config Paths")] = False, run_all_yamls: Annotated[bool, Option("--run-all-yamls", help="Run all evaluation yamls in the `checkpoints` directory, even if some have already been run.", rich_help_panel="Checkpoint & Config Paths")] = False, @@ -656,37 +682,40 @@ def main( delete_eval_yamls: Annotated[bool, Option("--delete/--keep", help="Delete all evaluation YAML files after running the evals. Use `delete_eval_yamls` if passing to `config`", rich_help_panel="Config Options")] = False, use_dir_names: Annotated[Optional[bool], Option("--use-dir-names", help="Use the folder names as the wandb run names. Defaults to true if multiple `checkpoints` are provided with one `train_config`", rich_help_panel="Config Options")] = None, gpu_ids: Annotated[Optional[List[int]], Option(help="List of GPU IDs to use", rich_help_panel="GPU Options")] = None, + eval_batch_count: Annotated[Optional[int], Option("--eval-batch-count", help="Number of batches to evaluate", rich_help_panel="Task Settings")] = None, config: Annotated[Optional[Path], Option(callback=conf_callback, is_eager=True, help="Relative path to YAML config file for setting options. Passing CLI options will supersede config options.", case_sensitive=False, rich_help_panel="Options")] = None, ): # fmt: skip """Run evaluations on model checkpoints.""" + if isinstance(checkpoints, str): + checkpoints = Path(checkpoints) + if isinstance(train_config, str): + train_config = Path(train_config) - # Set the allowed_gpus global variable global allowed_gpus - if gpu_ids is not None: - allowed_gpus = gpu_ids - else: - allowed_gpus = None # Use all GPUs + allowed_gpus = gpu_ids if hub_repo: - print(f"\nDownloading files from {hub_repo}...") + print(f"\nDownloading from {hub_repo} to {checkpoints} ...") downloaded_files = download_hub_files( - repo_id=hub_repo, filenames=hub_files, output_dir=checkpoints, token=hub_token + repo_id=hub_repo, + filenames=hub_files, + output_dir=checkpoints, + token=hub_token, ) if not downloaded_files: print("No files were downloaded successfully. Exiting.") raise Exit(code=1) print(f"Successfully downloaded {len(downloaded_files)} files to {checkpoints}") - # Set default tasks to all tasks if not provided - all_tasks = [task for task in TaskName] - tasks = tasks or all_tasks + if not tasks or len(tasks) == 0: + tasks = [t for t in TaskName] print("\nAsynchronously downloading required datasets...") msg_queue = Queue() download_process = Process(target=download_datasets, args=(tasks, msg_queue)) download_process.start() - print("\nCreating symlinks for latest checkpoints...") + print("\nCreating symlinks for newest checkpoints...") for folder in checkpoints.glob("*"): if folder.is_dir() and not folder.name.startswith("."): create_symlink_for_newest_checkpoint(folder, overwrite_existing_symlinks) @@ -717,13 +746,15 @@ def main( tasks=tasks, fast_ultrafeedback=fast_ultrafeedback, seeds=seeds, + eval_batch_count=eval_batch_count, parallel=parallel, use_dir_names=use_dir_names, model_size=model_size, rope_theta=rope_theta, + gpu_ids=gpu_ids, ) - config_files = list(checkpoints.glob("*_evaluation.yaml")) - config_files = sorted(list(set(config_files) - set(config_files_completed))) + config_files = list(set(checkpoints.glob("*_evaluation.yaml")) - set(config_files_completed)) + config_files = sorted(config_files) else: config_files = list(checkpoints.glob("*_evaluation.yaml")) @@ -738,13 +769,13 @@ def main( while not msg_queue.empty(): print(msg_queue.get()) - if len(config_files) >= 1 and parallel is False: - manage_jobs(configs=config_files, verbose=verbose, delete_eval_yamls=delete_eval_yamls) - elif len(config_files) > 1 and parallel is True: - raise ValueError(f"{parallel=} is only supported for running one config at a time.") - elif len(config_files) == 1 and parallel is True: + if len(config_files) >= 1 and not parallel: + manage_jobs(config_files, verbose=verbose, delete_eval_yamls=delete_eval_yamls) + elif len(config_files) > 1 and parallel: + raise ValueError("Parallel runs only supported for a single config at a time.") + elif len(config_files) == 1 and parallel: if not verbose: - console.print(f"[bold green]Running {config_files[0].name} in parallel on GPUs {', '.join(map(str, gpu_ids))}") # fmt: skip + console.print(f"[bold green]Running {config_files[0].name} in parallel on GPUs: {gpu_ids}") run_job(config_files[0], verbose=verbose, delete_eval_yamls=delete_eval_yamls, gpu_ids=gpu_ids) else: message = "No configuration files found in the specified directory." @@ -760,16 +791,21 @@ def main( else: console.print("[bold green]All jobs completed.") + shutil.rmtree("./checkpoints", ignore_errors=True) + shutil.rmtree("./finetuned-checkpoints", ignore_errors=True) + # Register the signal handler signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) +# Recursively delete ./checkpoints and ignore any problems. + if __name__ == "__main__": try: app() finally: - # Ensure all subprocesses are terminated when the script exits + # Ensure all subprocesses are terminated when the script exits for process in all_processes: if process.poll() is None: process.terminate() @@ -777,4 +813,4 @@ def main( try: process.wait(timeout=5) except subprocess.TimeoutExpired: - process.kill() + process.kill() \ No newline at end of file