diff --git a/.gitignore b/.gitignore index a4b62842b3..b8c7ee47a3 100644 --- a/.gitignore +++ b/.gitignore @@ -277,3 +277,6 @@ python/examples/launch/hello_world/fedml_job_entry_pack.bat **mpi_host_file /python/fedml/workflow/driver_example/customized_job_example/train_job/bootstrap.bat /python/fedml/workflow/driver_example/customized_job_example/train_job/fedml_job_entry_pack.bat + + +venv \ No newline at end of file diff --git a/python/fedml/cross_silo/server/fedml_server_manager.py b/python/fedml/cross_silo/server/fedml_server_manager.py index 9639f4c0e1..eb934508dc 100644 --- a/python/fedml/cross_silo/server/fedml_server_manager.py +++ b/python/fedml/cross_silo/server/fedml_server_manager.py @@ -246,7 +246,22 @@ def handle_message_receive_model_from_client(self, msg_params): if self.is_main_process(): mlops.log_aggregated_model_info(self.args.round_idx, model_url=global_model_url) - logging.info("\n\n==========end {}-th round training===========\n".format(self.args.round_idx)) + # -------------------------------------------------- + # Log global-update frequency in wall-clock terms + # -------------------------------------------------- + current_ts = time.time() + # Compute and print only if this is not the very first round + if hasattr(self, "_last_round_end_ts") and self._last_round_end_ts is not None: + delta = current_ts - self._last_round_end_ts + if delta > 0: + freq = 1.0 / delta + logging.info( + f"Global update frequency: {freq:.4f} updates/sec ({delta:.2f} s per round)" + ) + # Record timestamp for the next round + self._last_round_end_ts = current_ts + + logging.info("\n\n==========end {}/{}-th round training===========\n".format(self.args.round_idx, self.round_num)) if self.args.round_idx < self.round_num: mlops.event("server.wait", event_started=True, event_value=str(self.args.round_idx)) diff --git a/python/spotlight_prj/fedllm/custom_trainer.py b/python/spotlight_prj/fedllm/custom_trainer.py index fd5fab12ea..1bd37123d0 100644 --- a/python/spotlight_prj/fedllm/custom_trainer.py +++ b/python/spotlight_prj/fedllm/custom_trainer.py @@ -6,6 +6,8 @@ This version also integrates GRPO training for GSM8K dataset. """ + + import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) @@ -22,10 +24,154 @@ from fedml.train.llm.distributed import barrier from peft import PeftModel from trl import GRPOTrainer, GRPOConfig +from trl.trainer.utils import prepare_deepspeed + +from fedml.ml.aggregator.agg_operator import FedMLAggOperator from run_fedllm import LLMTrainer, LLMAggregator, save_checkpoint, load_checkpoint from src.peft_utils import set_peft_model_state_dict from src.modeling_utils import load_state_dict +import time, logging +import threading +import subprocess # for launching validation after checkpoints +import shutil # for deleting old checkpoints + +from fractions import Fraction + +# New import for TrainerCallback +from transformers import TrainerCallback, AutoConfig, AutoModelForCausalLM +import transformers + +import wandb +import json + +import warnings +warnings.filterwarnings("ignore") + + +def disable_dropout_in_model(model: torch.nn.Module) -> None: + """ + Disable dropout by setting all torch.nn.Dropout modules to eval mode and + zero probability. 
+ """ + for module in model.modules(): + if isinstance(module, torch.nn.Dropout): + module.p = 0.0 + module.eval() + + +class TimedGRPOTrainer(GRPOTrainer): + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.ref_model is not None: + # Load any model you like as the reference baseline + self.ref_model = AutoModelForCausalLM.from_pretrained( + "Qwen/Qwen3-0.6B", + torch_dtype=torch.bfloat16, + device_map="cpu", + trust_remote_code=True, + use_cache=False, + ) + self.ref_model.eval() + disable_dropout_in_model(self.ref_model) + for p in self.ref_model.parameters(): + p.requires_grad_(False) + + """ + + + def _record_step_stats(self, stats): + # ------------------------------------------------------------- + # Measure *inter-step* wall-clock time: difference between the start + # of this stats call and the previous. This captures the full time + # spent in the GRPO optimisation step (generation + backward pass, + # etc.) rather than just the duration of this method. + # ------------------------------------------------------------- + t_now = time.perf_counter() + step_elapsed = None + if hasattr(self, "_prev_step_t"): + step_elapsed = t_now - self._prev_step_t + self._prev_step_t = t_now # update for next call + + # Call parent implementation *after* timing start so that we include + # all work done before stats are returned. + super()._record_step_stats(stats) + + # ------------------------------------------------------------- + # Compute additional metrics + # ------------------------------------------------------------- + stats["kl_divergence"] = stats["kl"].mean().item() + if step_elapsed is not None: + stats["grpo_step_time"] = step_elapsed # seconds + + # NEW: forward stats to Trainer's logging system so that callbacks + # like GRPOMetricsCallback can record them via the TrainingMetricsLogger. + # This ensures that after every GRPO step the metrics are properly + # captured by the custom logger. + self.log(stats) + + # Override GRPOTrainer internals to measure generation latency per roll-out batch + # NOTE: Upstream `GRPOTrainer` uses `_generate_and_score_completions` (not + # `_make_experience`). The original override therefore never executed. + # We rename the method accordingly so that it is invoked during training. + + def _generate_and_score_completions(self, *args, **kwargs): + + t0 = time.perf_counter() + # Call upstream implementation + result = super()._generate_and_score_completions(*args, **kwargs) + # ------------------------------------------------------------------ + # Compute and log average completion time per generation + # ------------------------------------------------------------------ + elapsed = time.perf_counter() - t0 # total time for this roll-out batch + num_gens = max(1, getattr(self.args, "num_generations", 1)) + self.avg_completion_time = elapsed / num_gens + + # Log the metric so that it is captured by both Accelerate and + # the TrainingMetricsLogger (via GRPOMetricsCallback). + self.accelerator.log({"avg_completion_time": self.avg_completion_time}, step=self.state.global_step) + self.log({"avg_completion_time": self.avg_completion_time}) + + if self.state.global_step % 10 == 0: + torch.cuda.empty_cache() + + return result + + + """ + def _get_per_token_logps_and_entropies(self, model, batch, *args, **kwargs): + + Ensure inputs and model are on the same device before delegating to parent impl. + + This override fixes CPU↔GPU mismatch errors when the reference model is kept + on CPU while the policy lives on GPU. 
We simply move the tensor inputs in + ``batch`` to the device of ``model`` before invoking the upstream helper. + + target_device = next(model.parameters()).device + # Move all tensor values in the batch to the model's device + if torch.is_tensor(batch): + # Upstream may forward a single Tensor instead of a mapping + batch = batch.to(target_device) + else: + # Standard case: mapping of tensors / non-tensor objects + batch = { + k: (v.to(target_device) if torch.is_tensor(v) else v) + for k, v in batch.items() + } + logps, entropies = super()._get_per_token_logps_and_entropies(model, batch, *args, **kwargs) + + policy_device = self.accelerator.device + if torch.is_tensor(logps): + logps = logps.to(torch.float16) + logps = logps.to(policy_device) + if entropies is not None and torch.is_tensor(entropies): + entropies = entropies.to(torch.float16) + entropies = entropies.to(policy_device) + + return logps, entropies + """ class FullModelLLMTrainer(LLMTrainer): @@ -39,6 +185,51 @@ def __init__(self, *args, **kwargs): self.DATASET_ANS = re.compile(r"####\s*([-+]?\d+\.?\d*)") # Regex for model completion format (\boxed{}) self.MODEL_ANS = re.compile(r"\\boxed\{([^}]*)\}") + + self.BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}") # capture content inside \boxed{…} + + # ------------------------------------------------------------------ + # Configuration: enable or disable per-round checkpoints + # ------------------------------------------------------------------ + + # Default: omit per-round checkpoints unless user explicitly enables + # them via the FedML YAML (enable_round_checkpoints: true) + self._enable_round_ckpt = getattr(self.args, "enable_round_checkpoints", False) + + self.exact_match_reward = 2.0 + self.numeric_equivalence_reward=1.5 + self.incorrect_answer_reward=0.0 + + # Instantiate the training metrics logger and keep as an attribute so + # it can be accessed by callbacks. 
+ self.logger = TrainingMetricsLogger( + log_dir=os.path.join(self.args.output_dir, "wandb_logs"), + run_name=f"fl-client{getattr(self.args, 'rank', 'unknown')}_run{getattr(self.args, 'run_id', os.getenv('FEDML_CURRENT_RUN_ID', '0'))}", + enable_wandb=True, + wandb_project="fedllm-grpo-training", + args=self.args, + ) + + + def to_number(self, text: str) -> Optional[float]: + """Convert string to float if possible, handling simple fractions.""" + text = text.replace(",", "").strip() + # Fractions like 3/4 + if "/" in text: + try: + return float(Fraction(text)) + except (ValueError, ZeroDivisionError): + pass + try: + return float(text) + except ValueError: + return None + + + def extract_boxed(self, text: str) -> str: + """Return first \\boxed{...} contents; '' if none.""" + m = self.BOXED_RE.search(text) + return m.group(1) if m else "" def reward_fn(self, completions, answer, **_): """Reward function for GSM8K that checks if the predicted answer matches the true answer.""" @@ -54,14 +245,22 @@ def reward_fn(self, completions, answer, **_): if pred and tru: pred_num = pred.group(1) tru_num = tru.group(1) - out.append(1.0 if pred_num == tru_num else -0.2) + if pred_num == tru_num: + out.append(self.exact_match_reward) + else: + p_num, g_num = self.to_number(pred_num), self.to_number(tru_num) + if (p_num is not None and g_num is not None and abs(p_num - g_num) < 1e-4): + out.append(self.numeric_equivalence_reward) + else: + out.append(self.incorrect_answer_reward) else: - out.append(-0.2) + out.append(self.incorrect_answer_reward) return out def train(self, train_data, device, args): """Override train to use GRPO training on GSM8K dataset.""" self.log("Starting GRPO training on GSM8K") + # Load GSM8K dataset ds = load_dataset("openai/gsm8k", "main", split="train") @@ -74,7 +273,7 @@ def train(self, train_data, device, args): # Calculate effective batch size for GRPO constraint # effective_batch_size = num_gpus * per_device_batch_size * gradient_accumulation_steps - gradient_accumulation_steps = 4 + gradient_accumulation_steps = getattr(args, 'gradient_accumulation_steps', 2) effective_batch_size = 1 * grpo_batch_size * gradient_accumulation_steps # Num generations must evenly divide the effective batch size @@ -90,6 +289,8 @@ def train(self, train_data, device, args): else: num_generations = 2 + num_generations = 4 + # For testing, we can use a very small number of steps if grpo_max_steps > 0: self.log(f"GRPO training for {grpo_max_steps} steps (test mode)") @@ -101,6 +302,10 @@ def train(self, train_data, device, args): # **FIX: Load fresh model and tokenizer for GRPO to avoid FedML state corruption** from transformers import AutoModelForCausalLM, AutoTokenizer import torch + + # ↓↓↓ off-load the FedML copy BEFORE allocating fresh_model + self.model.to("cpu") + torch.cuda.empty_cache() # actually releases the VRAM # Get model name from model_args model_name = self.model_args.model_name_or_path @@ -113,22 +318,25 @@ def train(self, train_data, device, args): fresh_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype=torch.bfloat16, - use_cache=False + use_cache=False, + trust_remote_code=True ) else: fresh_model = AutoModelForCausalLM.from_pretrained( model_name, - torch_dtype=torch.float32, # Use float32 for better stability - use_cache=False + torch_dtype=torch.float16, # Use float32 for better stability + use_cache=False, + trust_remote_code=True ) except Exception as e: self.log(f"Failed to load with requested precision, falling back to float32: {e}") fresh_model = 
AutoModelForCausalLM.from_pretrained( model_name, - torch_dtype=torch.float32, # Fallback to float32 - use_cache=False + torch_dtype=torch.float16, # Fallback to float32 + use_cache=False, + trust_remote_code=True ) - fresh_tokenizer = AutoTokenizer.from_pretrained(model_name) + fresh_tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) fresh_tokenizer.pad_token = fresh_tokenizer.eos_token # Copy current model state to fresh model (to preserve any training from previous rounds) @@ -141,7 +349,11 @@ def train(self, train_data, device, args): current_state = self.model.state_dict() # Load into fresh model - fresh_model.load_state_dict(current_state, strict=False) + incompatible = fresh_model.load_state_dict(current_state, strict=True) + # Log any keys that failed to load for easier debugging + self.log( + f"missing keys: {incompatible.missing_keys}, unexpected keys: {incompatible.unexpected_keys}" + ) # Move fresh model to correct device fresh_model.to(device) @@ -167,26 +379,36 @@ def train(self, train_data, device, args): output_dir=str(self.checkpoint_dir / "grpo"), per_device_train_batch_size=grpo_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, - max_completion_length=1024, + max_completion_length=512, num_generations=num_generations, # Adjusted based on effective batch size num_train_epochs=grpo_num_epochs if grpo_max_steps <= 0 else 1, # Use 1 epoch if max_steps is set max_steps=grpo_max_steps if grpo_max_steps > 0 else -1, # Override epochs with max_steps learning_rate=5e-6, bf16=use_bf16, # Match model precision fp16=not use_bf16, # Use fp16 if not bf16 - gradient_checkpointing=False, # Keep consistent with config - logging_steps=5 if grpo_max_steps > 0 and grpo_max_steps < 50 else 25, # More frequent logging for short runs - log_completions=True, + gradient_checkpointing=getattr(args, 'gradient_checkpointing', False), + #logging_steps=5 if grpo_max_steps > 0 and grpo_max_steps < 50 else 25, # More frequent logging for short runs + logging_steps=1, + log_completions=False, save_steps=grpo_max_steps if grpo_max_steps > 0 else 500, # Save at the end if using max_steps # Add seed for reproducibility in federated setting - seed=42 + self.round_idx * 100 + args.rank, # Different seed per round and client + seed=int(time.perf_counter_ns() % (2**32)), + #report_to="wandb", + scale_rewards=False, + temperature=0.7, + top_p=0.95, + top_k=50, + repetition_penalty=1.1, + epsilon=0.2, + beta=0.1, + #optim="sgd", ) self.log(f"GRPO Config - bf16: {use_bf16}, fp16: {not use_bf16}, batch_size: {grpo_batch_size}") self.log(f"GRPO Config - max_completion_length: 1024, num_generations: {num_generations}") # Create GRPO trainer with fresh model and tokenizer - grpo_trainer = GRPOTrainer( + grpo_trainer = TimedGRPOTrainer( model=fresh_model, # Use fresh model args=cfg, train_dataset=ds.shuffle(seed=cfg.seed), @@ -197,21 +419,23 @@ def train(self, train_data, device, args): # **FIX: Set generation parameters for numerical stability** grpo_trainer.generation_kwargs = { "do_sample": True, - "temperature": 1.0, - "top_p": 0.9, - "top_k": 50, "pad_token_id": fresh_tokenizer.eos_token_id, "eos_token_id": fresh_tokenizer.eos_token_id, + "bos_token_id": fresh_tokenizer.bos_token_id, "max_new_tokens": 512, - "repetition_penalty": 1.1, # Prevent repetition "length_penalty": 1.0, # Neutral length penalty } + # Attach our logging callback so that metrics are recorded every step. 
+ grpo_trainer.add_callback(GRPOMetricsCallback(self.logger)) + self.log(f"Set generation parameters: {grpo_trainer.generation_kwargs}") # Run GRPO training grpo_trainer.train() + + # **Copy trained weights back to FedML's model** self.log("Copying GRPO-trained weights back to FedML model") trained_state = fresh_model.state_dict() @@ -221,23 +445,40 @@ def train(self, train_data, device, args): self.model.base_model.load_state_dict(trained_state, strict=False) else: self.model.load_state_dict(trained_state, strict=False) + self.model.to("cpu") + #del trained_state - # Save the trained model in FedML's expected location + # Optionally save a pre-aggregation checkpoint for this round + self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_before_agg" - self.log(f"Saving GRPO-trained model to \"{self.latest_checkpoint_dir}\"") - - # Save checkpoint using FedML's model + self.log(f"[round-ckpt] Saving GRPO-trained model to \"{self.latest_checkpoint_dir}\"") + save_checkpoint( self.model, self.latest_checkpoint_dir, is_saving_process=self.training_args.should_save, synchronize=True ) + + # After saving the current round checkpoint, clean up older round_* checkpoints + if self.training_args.should_save: + self._cleanup_old_round_checkpoints() + """ + grpo_trainer.accelerator.end_training() + grpo_trainer.accelerator.free_memory() + grpo_trainer.model = None + gc.collect() + torch.cuda.empty_cache() + """ # Clean up fresh model to free memory del fresh_model del fresh_tokenizer - torch.cuda.empty_cache() if torch.cuda.is_available() else None + del grpo_trainer.optimizer + del grpo_trainer.lr_scheduler + del grpo_trainer + self.model.to("cpu") + torch.cuda.empty_cache() self.log("GRPO training finished") @@ -251,6 +492,8 @@ def on_after_local_training(self, train_data, device, args): def set_model_params(self, model_parameters) -> None: self.log("start") + t0 = time.perf_counter() + model_parameters = to_device(model_parameters, device="cpu") barrier() @@ -262,17 +505,21 @@ def set_model_params(self, model_parameters) -> None: load_state_dict(self.model, model_parameters, strict=False) barrier() - if self.round_idx >= 0 and self.should_save: - # save aggregated model checkpoint - self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_after_agg" - self.log(f"saving aggregated model to \"{self.latest_checkpoint_dir}\"") - save_checkpoint( - self.model, - self.latest_checkpoint_dir, - is_saving_process=self.training_args.should_save, - state_dict=model_parameters, - synchronize=True - ) + + + # save aggregated model checkpoint + self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_after_agg" + self.log(f"saving aggregated model to \"{self.latest_checkpoint_dir}\"") + save_checkpoint( + self.model, + self.latest_checkpoint_dir, + is_saving_process=self.training_args.should_save, + state_dict=model_parameters, + synchronize=True + ) + + elapsed = time.perf_counter() - t0 + self.log(f"set_model_params (client) took {elapsed:.3f}s") self.log("finished") @@ -289,6 +536,11 @@ def sync_process_group( if round_idx is None: round_idx = self.round_idx + model_params = to_device(model_params, "cpu") # ensure params live on CPU + + dtypes = set(t.dtype for t in model_params.values()) + print(f"model_params dtypes: {dtypes}") # Should print torch.float32 if FP32 + broadcast_object_list([round_idx, model_params, client_index], from_process=from_process) self.log("finished") @@ -296,18 +548,216 @@ def sync_process_group( def 
await_sync_process_group(self, from_process: int = 0) -> list: self.log("start") + # ---------------------- Timing start ---------------------- + t0 = time.perf_counter() outputs = broadcast_object_list([None, None, None], from_process=from_process) + download_elapsed = time.perf_counter() - t0 + # ---------------------- WandB log ------------------------ + if getattr(self, "logger", None) and self.logger.enable_wandb and self.logger.wandb_run: + # Step keyed by federated round so uploads and downloads align. + self.logger.wandb_run.log({ + "performance/model_download_time": download_elapsed + }, step=self.round_idx) + + # Store for optional moving-average statistics. + self.logger.accumulated_metrics.setdefault("model_download_times", []).append(download_elapsed) + + self.log(f"model download took {download_elapsed:.3f}s") self.log("finished") return outputs + def _cleanup_old_round_checkpoints(self, keep_last: int = 1): + """Delete old round_* checkpoints but keep the most recent `keep_last`. + + Wall-clock checkpoints (wallclock_*) are never removed. + """ + pattern = re.compile(r"round_(\d+)_(before|after)_agg") + # Collect candidate directories and their round numbers + ckpts = [] + for d in self.checkpoint_dir.iterdir(): + m = pattern.fullmatch(d.name) + if m and d != self.latest_checkpoint_dir: + ckpts.append((int(m.group(1)), d)) + + # Sort by round number so oldest come first + ckpts.sort(key=lambda x: x[0]) + + # Remove all but the newest `keep_last` checkpoints + for _, d in ckpts[:-keep_last]: + try: + shutil.rmtree(d, ignore_errors=True) + except Exception as e: + self.log(f"[WARN] Failed to delete old checkpoint {d}: {e}") + class FullModelLLMAggregator(LLMAggregator): """Custom aggregator that properly handles both PEFT and non-PEFT models.""" + # ------------------------------------------------------------------ + # Periodic checkpointing setup + # ------------------------------------------------------------------ + + def __init__(self, *args, **kwargs): + """Extend parent init and start a background thread that creates a + checkpoint every ``server_checkpoint_interval_minutes`` (default 30). + + Notes + ----- + * Only the main process (``self.is_main_process()``) actually writes the + checkpoint to avoid race conditions. + * Checkpoints are written under + ``{self.checkpoint_dir}/wallclock_{unix_ts}`` so they will not + collide with the per-round checkpoints that already exist. + """ + super().__init__(*args, **kwargs) + + # ------------------------------------------------------------- + # WandB logger for aggregator-level (server) statistics – initialize + # EARLY so that it exists even when periodic checkpointing is disabled. 
+ # ------------------------------------------------------------- + self.logger = TrainingMetricsLogger( + log_dir=os.path.join(self.args.output_dir, "wandb_logs"), + run_name=f"fl-server_run{getattr(self.args, 'run_id', os.getenv('FEDML_CURRENT_RUN_ID', '0'))}", + enable_wandb=True, + wandb_project="fedllm-grpo-training", + args=self.args, + ) + self.model_broadcasts = 0 + + # Determine interval (seconds) + interval_min = getattr(self.args, "server_checkpoint_interval_minutes", 30) + if interval_min <= 0: + # Disable if user passes 0 or negative value + self._checkpoint_interval = None + return + + self._checkpoint_interval = interval_min * 60 + + # Background thread is only needed on the main process + if self.is_main_process(): + self._stop_checkpoint_evt = threading.Event() + self._checkpoint_thread = threading.Thread( + target=self._periodic_checkpoint_loop, + name="periodic_ckpt_thread", + daemon=True, + ) + self._checkpoint_thread.start() + + # Whether to save per-round checkpoints (default False) + self._enable_round_ckpt = getattr(self.args, "enable_round_checkpoints", False) + + # ------------------ Nesterov Momentum Setup (NEW) ------------------ + # Learning rate for the server optimizer (default 1.0 so the server fully + # applies the aggregated update when momentum=0) + self._server_lr = getattr(self.args, "server_lr", 1.0) + # Momentum coefficient. Typical values are 0.9 or 0.99 + self._momentum = getattr(self.args, "server_momentum", 0.9) + # Enable / disable Nesterov variant (default=True) + self._nesterov = getattr(self.args, "server_nesterov", True) + # Momentum buffer for each parameter + self._velocity: OrderedDict = OrderedDict() + # ------------------------------------------------------------------- + + # ----- WandB server-side logger (NEW) ----- + # Create a standalone TrainingMetricsLogger so that aggregator-level + # system statistics (e.g. active workers, model broadcasts) are also + # recorded in the same WandB project as the clients. + if not hasattr(self, "logger"): + self.logger = TrainingMetricsLogger( + log_dir=os.path.join(self.args.output_dir, "wandb_logs"), + run_name=f"fl-server_run{getattr(self.args, 'run_id', os.getenv('FEDML_CURRENT_RUN_ID', '0'))}", + enable_wandb=True, + wandb_project="fedllm-grpo-training", + args=self.args, + ) + # Counter for how many times the global model has been broadcast to + # clients – useful for monitoring server throughput. + self.model_broadcasts = 0 + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _cleanup_old_round_checkpoints(self, keep_last: int = 1): + """Delete old round_* checkpoints but keep the most recent `keep_last`. + + Wall-clock checkpoints (wallclock_*) are never removed. 
+ """ + pattern = re.compile(r"round_(\d+)_(before|after)_agg") + # Collect candidate directories and their round numbers + ckpts = [] + for d in self.checkpoint_dir.iterdir(): + m = pattern.fullmatch(d.name) + if m and d != self.latest_checkpoint_dir: + ckpts.append((int(m.group(1)), d)) + + # Sort by round number so oldest come first + ckpts.sort(key=lambda x: x[0]) + + # Remove all but the newest `keep_last` checkpoints + for _, d in ckpts[:-keep_last]: + try: + shutil.rmtree(d, ignore_errors=True) + except Exception as e: + self.log(f"[WARN] Failed to delete old checkpoint {d}: {e}") + + def _periodic_checkpoint_loop(self): + """Loop that sleeps ``_checkpoint_interval`` seconds then writes a + checkpoint until ``_stop_checkpoint_evt`` is set (i.e., program exit). + """ + while not self._stop_checkpoint_evt.wait(self._checkpoint_interval): + try: + ts = int(time.time()) + ckpt_dir = self.checkpoint_dir / f"wallclock_{ts}" + self.log(f"Periodic checkpoint → {ckpt_dir}") + # Always save checkpoints in the standard HuggingFace format so that + # the resulting directory can be loaded with `from_pretrained`. + # For `PeftModel` this will also persist the adapter weights. + # Only the main process writes the checkpoint to avoid race conditions + # (the background thread is spawned exclusively on the main process). + if self.training_args.should_save: + ckpt_dir.mkdir(parents=True, exist_ok=True) + try: + # Try the native HuggingFace save. + # For `PeftModel` this will also persist the adapter weights. + self.model.save_pretrained(str(ckpt_dir), state_dict=self.model.state_dict()) + except AttributeError: + # Fallback to the generic helper if the model doesn't implement + # `save_pretrained` (unlikely for LLMs but safe-guard regardless). + save_checkpoint( + self.model, + checkpoint_dir=ckpt_dir, + is_saving_process=True, + synchronize=False, + ) + # ---------------- New behaviour ---------------- + # After successfully writing the checkpoint, prune older + # wallclock_* checkpoints so that only the latest six are + # kept on disk. 
+ self._cleanup_old_wallclock_checkpoints() + # Run validation on the newly saved checkpoint + try: + script_path = Path(__file__).parent / "validation.py" + log_path = Path(self.args.output_dir) / "validation.log" + with open(log_path, "a") as lf: + subprocess.Popen( + [sys.executable, str(script_path), "--model", str(ckpt_dir)], + stdout=lf, + stderr=subprocess.STDOUT, + close_fds=True, + ) + except Exception as e: + self.log(f"[WARN] Failed to launch validation: {e}") + except Exception as e: + # Log and continue – do not crash training due to checkpoint failure + self.log(f"[WARN] Periodic checkpoint failed: {e}") + def set_model_params(self, model_parameters) -> None: self.log("start") + t0 = time.perf_counter() + model_parameters = to_device(model_parameters, device="cpu") barrier() @@ -319,16 +769,501 @@ def set_model_params(self, model_parameters) -> None: load_state_dict(self.model, model_parameters, strict=False) barrier() - if self.round_idx >= 0 and self.should_save: - # save aggregated model checkpoint - self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_after_agg" - self.log(f"saving aggregated model to \"{self.latest_checkpoint_dir}\"") - save_checkpoint( - self.model, - self.latest_checkpoint_dir, - is_saving_process=self.training_args.should_save, - state_dict=model_parameters, - synchronize=True + # save aggregated model checkpoint + self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_after_agg" + self.log(f"saving aggregated model to \"{self.latest_checkpoint_dir}\"") + save_checkpoint( + self.model, + self.latest_checkpoint_dir, + is_saving_process=self.training_args.should_save, + state_dict=model_parameters, + synchronize=True + ) + + # Clean up old round checkpoints on the server as well + if self.training_args.should_save: + self._cleanup_old_round_checkpoints() + + elapsed = time.perf_counter() - t0 + self.log(f"set_model_params (server) took {elapsed:.3f}s") + + # ------------------------------------------------------------- + # NEW: push aggregator-level system statistics to WandB + # ------------------------------------------------------------- + self.model_broadcasts += 1 + self.logger.log_server_statistics( + stats={ + "server_statistics": { + "active_workers": getattr(self.args, "client_num_in_total", 0), + "model_subscribers": [], + "service_status": { + "current_model_version": self.round_idx, + "buffer_statistics": {}, + }, + }, + "current_pipeline_depth": 0, # placeholder – update if pipeline depth is tracked elsewhere + "model_broadcasts": self.model_broadcasts, + }, + global_step=self.round_idx, + ) + + self.log("finished") + + def _cleanup_old_wallclock_checkpoints(self, keep_last: int = 12): + """Delete old wallclock_* checkpoints but keep the most recent ``keep_last``. + + This complements the round-based checkpoint cleanup by pruning time-based + checkpoints created by the periodic background thread. The newest + ``keep_last`` checkpoints are retained; older ones are removed to avoid + unbounded disk usage on long-running servers. 
+ """ + pattern = re.compile(r"wallclock_(\d+)$") + valid_ckpts = [] # (timestamp, Path) + invalid_ckpts = [] # Path(s) that lack model files + + # Determine candidate checkpoints and group by validity + for d in self.checkpoint_dir.iterdir(): + m = pattern.fullmatch(d.name) + if not m: + continue # skip non-wallclock dirs + + # Heuristic: consider checkpoint *valid* if it contains at least one + # model weight file produced by ``save_pretrained`` or our fallback + # helper (i.e. *.bin or *.safetensors). This covers both HF and PEFT. + has_model_file = any(d.glob("*.bin")) or any(d.glob("*.safetensors")) or any(d.glob("*.pt")) + + if has_model_file: + valid_ckpts.append((int(m.group(1)), d)) + else: + invalid_ckpts.append(d) + + # Remove *all* invalid checkpoints immediately as they are unusable + for d in invalid_ckpts: + try: + shutil.rmtree(d, ignore_errors=True) + except Exception as e: + self.log(f"[WARN] Failed to delete incomplete wallclock checkpoint {d}: {e}") + + # Sort valid checkpoints chronologically (oldest first) + valid_ckpts.sort(key=lambda x: x[0]) + + # Keep only the most recent ``keep_last`` valid checkpoints + for _, d in valid_ckpts[:-keep_last]: + try: + shutil.rmtree(d, ignore_errors=True) + except Exception as e: + self.log(f"[WARN] Failed to delete old wallclock checkpoint {d}: {e}") + + +class TrainingMetricsLogger: + """Comprehensive logging for GRPO training with WandB support""" + + def __init__( + self, + log_dir: str, + run_name: Optional[str] = None, + enable_wandb: bool = False, + wandb_project: Optional[str] = None, + wandb_entity: Optional[str] = None, + wandb_config: Optional[dict] = None, + args: Optional[Any] = None, + ): + """Parameters + ---------- + log_dir : str + Directory where auxiliary JSON / txt logs will be written. + run_name : str, optional + Human-readable name that will appear in the WandB UI. + enable_wandb : bool, default False + If ``True`` a WandB run is initialised, otherwise the logger will + operate in offline mode and simply discard `.log*()` calls. + wandb_project, wandb_entity, wandb_config : Optional[str | dict] + Passed through to :pyfunc:`wandb.init` unchanged. + args : Any, optional + (FedML) *args* namespace used throughout the project. We only + use it to derive a *unique* WandB run *id* so that the server and + every client write to **separate** runs instead of clobbering one + another. + """ + + self.log_dir = log_dir + self.run_name = run_name or f"grpo_training_{int(time.time())}" + self.enable_wandb = enable_wandb + self.args = args # may be ``None`` for unit tests / offline runs + + # ------------------------------------------------------------------ + # WandB setup – ensure that each process (server / client-rank-N) gets + # its *own* run. Re-using the same run *id* from multiple processes + # causes metrics to silently overwrite each other and leads to exactly + # the "not everything we log shows up" behaviour that we observed on + # the dashboard. + # ------------------------------------------------------------------ + self.wandb_run = None + if self.enable_wandb: + wandb_kwargs = { + "project": wandb_project or "grpo-training", + "entity": wandb_entity, + "name": self.run_name, + "config": wandb_config or {}, + "reinit": True, + } + + # Use a *group* so that the server run and all client runs are + # nicely collated in the WandB UI, while still receiving unique + # run IDs. 
+ if args is not None and hasattr(args, "run_id"): + wandb_kwargs["group"] = str(args.run_id) + + # Derive a UNIQUE id: "-server" or "-client" + role_suffix = ( + "-server" + if getattr(args, "role", "server") == "server" + else f"-client{getattr(args, 'rank', '0')}" + ) + wandb_kwargs["id"] = f"{args.run_id}{role_suffix}" + + # Remove None entries so wandb.init does not complain. + wandb_kwargs = {k: v for k, v in wandb_kwargs.items() if v is not None} + + self.wandb_run = wandb.init(**wandb_kwargs) + print( + f"[WandB] Logging initialised → " + f"project={wandb_kwargs.get('project')}, run_name={self.run_name}" ) - self.log("finished") \ No newline at end of file + # Metrics tracking + self.step_count = 0 + self.training_start_time = time.time() + self.last_log_time = time.time() + # Stores the most recent average completion time reported by the trainer. + # Initialised here so that attribute always exists and we avoid AttributeError + # if the metric is accessed before the first value is logged. + self.avg_completion_time: Optional[float] = None + + # Accumulated metrics for averaging + self.accumulated_metrics = { + 'losses': [], + 'rewards': [], + 'kl_divergences': [], + 'policy_losses': [], + 'value_losses': [], + 'advantages': [], + 'rollout_lengths': [], + 'completion_times': [], + 'step_times': [], + } + + def log_training_step(self, step_id: str, train_result: dict, global_step: int): + """Log metrics for a single training step""" + + # Prepare metrics dict for wandb + wandb_metrics = {} + + # Core training metrics + if 'loss' in train_result: + wandb_metrics['training/loss'] = train_result['loss'] + self.accumulated_metrics['losses'].append(train_result['loss']) + + if 'avg_reward' in train_result: + wandb_metrics['training/avg_reward'] = train_result['avg_reward'] + self.accumulated_metrics['rewards'].append(train_result['avg_reward']) + + # Advanced GRPO metrics + if 'kl_divergence' in train_result: + wandb_metrics['training/kl_divergence'] = train_result['kl_divergence'] + self.accumulated_metrics['kl_divergences'].append(train_result['kl_divergence']) + + if 'policy_loss' in train_result: + wandb_metrics['training/policy_loss'] = train_result['policy_loss'] + self.accumulated_metrics['policy_losses'].append(train_result['policy_loss']) + + if 'value_loss' in train_result: + wandb_metrics['training/value_loss'] = train_result['value_loss'] + self.accumulated_metrics['value_losses'].append(train_result['value_loss']) + + if 'advantage_mean' in train_result: + wandb_metrics['training/advantage_mean'] = train_result['advantage_mean'] + self.accumulated_metrics['advantages'].append(train_result['advantage_mean']) + + # Rollout statistics + if 'rollout_count' in train_result: + wandb_metrics['rollouts/count_per_step'] = train_result['rollout_count'] + + if 'avg_rollout_length' in train_result: + wandb_metrics['rollouts/avg_length'] = train_result['avg_rollout_length'] + self.accumulated_metrics['rollout_lengths'].append(train_result['avg_rollout_length']) + + # Average completion time (per generation) + # Update the cached value if the trainer provided a fresh measurement. 
+ if 'avg_completion_time' in train_result: + self.avg_completion_time = train_result['avg_completion_time'] + + if self.avg_completion_time is not None: + wandb_metrics['performance/avg_completion_time'] = self.avg_completion_time + self.accumulated_metrics['completion_times'].append(self.avg_completion_time) + + if 'rollout_time' in train_result: + wandb_metrics['performance/rollout_time'] = train_result['rollout_time'] + + if 'training_time' in train_result: + wandb_metrics['performance/training_step_time'] = train_result['training_time'] + + # Weight update timing metrics + if 'weight_update_time' in train_result: + wandb_metrics['performance/weight_update_time'] = train_result['weight_update_time'] + + if 'backward_time' in train_result: + wandb_metrics['performance/backward_pass_time'] = train_result['backward_time'] + + if 'optimizer_time' in train_result: + wandb_metrics['performance/optimizer_step_time'] = train_result['optimizer_time'] + + if 'wait_time' in train_result: + wandb_metrics['performance/batch_wait_time'] = train_result['wait_time'] + + # Gradient metrics + if 'grad_norm' in train_result: + wandb_metrics['training/grad_norm'] = train_result['grad_norm'] + + # Learning rate + if 'learning_rate' in train_result: + wandb_metrics['training/learning_rate'] = train_result['learning_rate'] + + # GRPO step time + if 'grpo_step_time' in train_result: + wandb_metrics['performance/grpo_step_time'] = train_result['grpo_step_time'] + self.accumulated_metrics['step_times'].append(train_result['grpo_step_time']) + + # Log to wandb + if self.enable_wandb and self.wandb_run and wandb_metrics: + # Replace the Trainer-provided ``global_step`` (which resets every + # round) with an internal monotonically-increasing counter so + # that WandB treats each update as a new step instead of + # overwriting previous values. + wandb_step = self.step_count # 0-based running counter + wandb_metrics['global_step'] = wandb_step + self.wandb_run.log(wandb_metrics, step=wandb_step) + + # Advance our own monotonically-increasing counter by exactly one + # because this method is invoked once per call to `Trainer.log`. 
+ self.step_count += 1 + + def log_server_statistics(self, stats: dict, global_step: int): + """Log server and system statistics""" + wandb_metrics = {} + + if 'server_statistics' in stats: + server_stats = stats['server_statistics'] + + # Handle double nesting + if 'server_statistics' in server_stats: + server_stats = server_stats['server_statistics'] + + # Active workers + if 'active_workers' in server_stats: + wandb_metrics['system/active_workers'] = server_stats['active_workers'] + + # Model subscribers + if 'model_subscribers' in server_stats: + inference_workers = [w for w in server_stats['model_subscribers'] if 'trainer' not in w.lower()] + wandb_metrics['system/inference_workers'] = len(inference_workers) + wandb_metrics['system/total_subscribers'] = len(server_stats['model_subscribers']) + + # Service status + if 'service_status' in server_stats: + service_status = server_stats['service_status'] + + # Buffer statistics + if 'buffer_statistics' in service_status: + buffer_stats = service_status['buffer_statistics'] + + if 'pending_steps' in buffer_stats: + wandb_metrics['system/pending_steps'] = buffer_stats['pending_steps'] + + if 'ready_batches' in buffer_stats: + wandb_metrics['system/ready_batches'] = buffer_stats['ready_batches'] + + if 'total_rollouts_received' in buffer_stats: + wandb_metrics['system/total_rollouts_received'] = buffer_stats['total_rollouts_received'] + + # Model version tracking + if 'current_model_version' in service_status: + wandb_metrics['system/current_model_version'] = service_status['current_model_version'] + + # Pipeline statistics + if 'current_pipeline_depth' in stats: + wandb_metrics['system/pipeline_depth'] = stats['current_pipeline_depth'] + + if 'model_broadcasts' in stats: + wandb_metrics['system/model_broadcasts'] = stats['model_broadcasts'] + + # Log to wandb + if self.enable_wandb and self.wandb_run and wandb_metrics: + self.wandb_run.log(wandb_metrics, step=self.step_count) + + def log_performance_metrics(self, global_step: int, training_rate: Optional[float] = None): + """Log performance and timing metrics""" + wandb_metrics = {} + + current_time = time.time() + elapsed_time = current_time - self.training_start_time + + # Training rate + if training_rate is not None: + wandb_metrics['performance/training_rate_steps_per_hour'] = training_rate + + # Overall training time + wandb_metrics['performance/elapsed_time_hours'] = elapsed_time / 3600 + + # Steps per second (recent) + time_since_last_log = current_time - self.last_log_time + if time_since_last_log > 0 and hasattr(self, 'last_step_count'): + steps_since_last = global_step - self.last_step_count + steps_per_second = steps_since_last / time_since_last_log + wandb_metrics['performance/steps_per_second'] = steps_per_second + + # Log to wandb + if self.enable_wandb and wandb_metrics: + self.wandb_run.log(wandb_metrics, step=global_step) + + self.last_log_time = current_time + self.last_step_count = global_step + + + def get_moving_average(values, window): + if len(values) == 0: + return 0 + window = min(window, len(values)) + return sum(values[-window:]) / window + + def log_moving_averages(self, global_step: int, window_size: int = 100): + """Log moving averages of key metrics""" + wandb_metrics = {} + + # Moving averages + if self.accumulated_metrics['losses']: + avg_loss = self.get_moving_average(self.accumulated_metrics['losses'], window_size) + wandb_metrics[f'moving_avg/loss_{window_size}'] = avg_loss + + if self.accumulated_metrics['rewards']: + avg_reward = 
self.get_moving_average(self.accumulated_metrics['rewards'], window_size) + wandb_metrics[f'moving_avg/reward_{window_size}'] = avg_reward + + if self.accumulated_metrics['kl_divergences']: + avg_kl = self.get_moving_average(self.accumulated_metrics['kl_divergences'], window_size) + wandb_metrics[f'moving_avg/kl_divergence_{window_size}'] = avg_kl + + if self.accumulated_metrics['rollout_lengths']: + avg_length = self.get_moving_average(self.accumulated_metrics['rollout_lengths'], window_size) + wandb_metrics[f'moving_avg/rollout_length_{window_size}'] = avg_length + + if self.accumulated_metrics['completion_times']: + avg_ct = self.get_moving_average(self.accumulated_metrics['completion_times'], window_size) + wandb_metrics[f'moving_avg/completion_time_{window_size}'] = avg_ct + + if self.accumulated_metrics['step_times']: + avg_st = self.get_moving_average(self.accumulated_metrics['step_times'], window_size) + wandb_metrics[f'moving_avg/step_time_{window_size}'] = avg_st + + # Log to wandb + if self.enable_wandb and wandb_metrics: + # Use our internal monotonically-increasing counter so that these + # points are not overwritten when `global_step` resets each round. + wandb_step = max(0, self.step_count - 1) + self.wandb_run.log(wandb_metrics, step=wandb_step) + + def log_hyperparameters(self, hparams: dict): + """Log hyperparameters""" + # Convert all values to scalars for TensorBoard + scalar_hparams = {} + for key, value in hparams.items(): + if isinstance(value, (int, float)): + scalar_hparams[key] = value + elif isinstance(value, (str,list)): + # TensorBoard doesn't handle strings well, so we'll just log them as text + continue + else: + scalar_hparams[key] = float(value) if value is not None else 0.0 + + # Log to wandb (wandb handles different types better) + if self.enable_wandb: + # Update wandb config with hyperparameters + self.wandb_run.config.update(hparams) + + def log_model_statistics(self, model, global_step: int): + """Log model-specific statistics""" + wandb_metrics = {} + + # Model parameter statistics + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + + wandb_metrics['model/total_parameters'] = total_params + wandb_metrics['model/trainable_parameters'] = trainable_params + + # Parameter norms + total_norm = 0 + for p in model.parameters(): + if p.grad is not None: + total_norm += p.grad.data.norm(2).item() ** 2 + total_norm = total_norm ** 0.5 + + if total_norm > 0: + wandb_metrics['model/gradient_norm'] = total_norm + + # Weight norms by layer (sample a few to avoid too many metrics) + layer_count = 0 + for name, param in model.named_parameters(): + if param.requires_grad and param.data is not None: + # Only log first few layers to wandb to avoid clutter + if layer_count < 10: + wandb_metrics[f'model_weights/{name}_norm'] = param.data.norm().item() + layer_count += 1 + + # Log to wandb + if self.enable_wandb and wandb_metrics: + self.wandb_run.log(wandb_metrics, step=global_step) + + def log_reward_distribution(self, rewards: list, global_step: int): + """Log reward distribution""" + if rewards: + if self.enable_wandb: + wandb_metrics = { + 'rewards/min': min(rewards), + 'rewards/max': max(rewards), + 'rewards/std': torch.tensor(rewards).std().item(), + 'rewards/mean': sum(rewards) / len(rewards) + } + # Create histogram for wandb + wandb_metrics['rewards/histogram'] = wandb.Histogram(rewards) + self.wandb_run.log(wandb_metrics, step=global_step) + + def 
save_training_config(self, config: dict):
+        """Save training configuration to file"""
+        config_path = os.path.join(self.log_dir, "training_config.json")
+        with open(config_path, 'w') as f:
+            json.dump(config, f, indent=2, default=str)
+        print(f"Training configuration saved to: {config_path}")
+
+    def close(self):
+        """Close logging connections"""
+        if self.enable_wandb and self.wandb_run:
+            self.wandb_run.finish()
+            print("WandB logging closed")
+
+# -------------------- New Callback --------------------
+class GRPOMetricsCallback(TrainerCallback):
+    """HuggingFace Trainer callback that forwards log events to our
+    TrainingMetricsLogger instance so that each GRPO step is recorded."""
+
+    def __init__(self, logger: "TrainingMetricsLogger"):
+        super().__init__()
+        self.logger = logger
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        # Forward the metrics dictionary to the TrainingMetricsLogger. This
+        # fires after every call to `Trainer.log`, i.e. after each GRPO step.
+        if logs:
+            # Use a generic step_id; users can differentiate by global_step.
+            self.logger.log_training_step("grpo_step", logs, state.global_step)
diff --git a/python/spotlight_prj/fedllm/data_formatting.py b/python/spotlight_prj/fedllm/data_formatting.py
new file mode 100644
index 0000000000..047892cc38
--- /dev/null
+++ b/python/spotlight_prj/fedllm/data_formatting.py
@@ -0,0 +1,158 @@
+
+from datasets import load_dataset
+
+
+class DataFormatting:
+
+    def __init__(self):
+
+        self.system_prompt = """
+        Respond in the following format:
+
+        <reasoning>
+        ...
+        </reasoning>
+        <answer>
+        ...
+        </answer>
+        """
+
+    def extract_answer_from_model_output(self, text):
+        """
+        Extracts the value from the last <answer> tag in the text.
+
+        Args:
+            text (str): The model-generated text containing XML-style <answer> tags.
+
+        Returns:
+            str or None: The content inside the <answer> tags, or None if no valid answer is found.
+
+        Explanation:
+            1. Splits the text on the <answer> tag to isolate content after the tag.
+            2. Checks if at least one <answer> tag exists in the text.
+            3. For the last segment:
+               - Verifies it contains a closing </answer> tag.
+               - Extracts only the content between the tags.
+            4. Returns None if the answer is empty (just "...") or if tags are missing.
+        """
+
+        # Split on <answer> and take everything after the last occurrence.
+        parts = text.split("<answer>")
+
+        if len(parts) < 2:  # No <answer> tag found
+            return None
+
+        last_part = parts[-1]
+
+        # Extract the content up to </answer>
+        if "</answer>" not in last_part:
+            return None
+
+        answer = last_part.split("</answer>")[0].strip()
+
+        return None if answer == "..." else answer
+
+    def extract_answer_from_dataset(self, text):
+        """
+        Extracts the answer from the GSM8K dataset examples.
+
+        Args:
+            text (str): The dataset example text containing a question and answer.
+
+        Returns:
+            str or None: The extracted answer part after the '####' delimiter, or None.
+
+        Explanation:
+            1. Checks if the text contains the '####' delimiter that separates questions from answers.
+            2. If found, splits the text at this delimiter and returns the second part.
+            3. The answer is stripped of leading and trailing whitespace.
+            4. Returns None if no delimiter is present.
+        """
+
+        if "####" not in text:
+            return None
+
+        return text.split("####")[1].strip()
+
+    def prepare_dataset(self, split="train"):
+        """
+        Load and prepare the GSM8K dataset for training with string prompts.
+
+        Args:
+            split (str): The dataset split to load ("train" or "test"). Defaults to "train".
+
+        Returns:
+            list: A list of formatted examples, each containing a prompt string and the extracted answer.
+
+        Explanation:
+            1. Loads the GSM8K dataset from the Hugging Face dataset hub.
+            2. For each example in the dataset:
+               - Creates a list of messages with the system prompt and the question.
+               - Converts this list into a single string prompt using build_prompt().
+               - Extracts the answer from the dataset example.
+               - Creates a formatted example with the prompt and answer.
+            3. Returns the list of formatted examples ready for model training or evaluation.
+        """
+
+        data = load_dataset('openai/gsm8k', 'main')[split]
+
+        formatted_data = []
+
+        for example in data:
+            # Convert the list of messages to a single string prompt
+            prompt_str = self.build_prompt([
+                {"role": "system", "content": self.system_prompt},
+                {"role": "user", "content": example["question"]}
+            ])
+
+            formatted_example = {
+                "prompt": prompt_str,  # string rather than a list
+                "answer": self.extract_answer_from_dataset(example["answer"])
+            }
+            formatted_data.append(formatted_example)
+
+        return formatted_data
+
+    def build_prompt(self, messages):
+        """
+        Build a single prompt string from a list of messages.
+
+        Args:
+            messages (list): A list of message dictionaries, each with 'role' and 'content'.
+
+        Returns:
+            str: A concatenated string of all message content.
+
+        Explanation:
+            1. Takes a list of message dictionaries in typical chat format.
+            2. Extracts the 'content' field from each message and strips whitespace.
+            3. Joins all content strings with newlines to create a single prompt.
+            4. This preserves the training format while converting from structured messages.
+        """
+
+        return "\n".join(msg["content"].strip() for msg in messages)
diff --git a/python/spotlight_prj/fedllm/evaluation.py b/python/spotlight_prj/fedllm/evaluation.py
new file mode 100644
index 0000000000..c1cfa94f38
--- /dev/null
+++ b/python/spotlight_prj/fedllm/evaluation.py
@@ -0,0 +1,206 @@
+
+import re
+
+import torch
+from data_formatting import DataFormatting
+
+
+class Evaluation:
+
+    def __init__(self):
+        self.dat_fmt = DataFormatting()
+
+    def extract_last_number(self, text):
+        """
+        Extracts the last number appearing in the text.
+
+        Args:
+            text (str): The text to extract a number from.
+
+        Returns:
+            float or None: The last number in the text, or None if no number is found.
+
+        Explanation:
+            1. Removes dollar signs and percentage symbols from the text.
+            2. Uses a regex to find a number that appears at the end of the text.
+            3. The pattern matches numbers that appear at the end of the string.
+            4. Returns the found number as a float, or None if no match is found.
+        """
+
+        text = text.replace('$', '').replace('%', '')
+
+        pattern = r'(?:^|\s|=)\s*(-?\d*\.?\d+)\s*$'
+
+        match = re.search(pattern, text)
+
+        return float(match.group(1)) if match else None
+
+    def extract_single_number(self, text):
+        """
+        Extracts a single number from text if exactly one number is present.
+
+        Args:
+            text (str): The text to extract a number from.
+
+        Returns:
+            float or None: The single number in the text, or None if zero or multiple numbers are found.
+
+        Explanation:
+            1. Uses a regex to find all numbers in the text, including negative numbers.
+            2. If exactly one number is found, returns it as a float.
+            3. If zero or multiple numbers are found, returns None.
+ + """ + + numbers =re.findall(r'-?\d*\.?\d+', text) + #print("NUMBERS ARE:::", numbers) + + if len(numbers)==0: + return None + elif len(numbers)==1: + return float(numbers[0]) + + else: + return None + + + + def evaluate_model(self, model, tokenizer, eval_samples, device): + + """ + Evaluates the model on a set of examples and prints detailed results. + + Args: + model: The language model to evaluate. + tokenizer: The tokenizer for encoding inputs and decoding outputs. + eval_samples (list): List of evaluation examples each containing "prompt" and "answer" + device: The device (CPU or GPU) to run evaluation on + + Return: + float: The accuracy percentage (correct predictions / total examples * 100) + + + Explanation: + 1. Sets the model to evaluation mode. + 2. For each example in the evaluation set: + - Encodes the prompt and generates a respnse using the model + - Extracts the predicted answer from the generated response + - Compares the predicted answer with the expected answer using multiple methods + + a. Extract string matching + b. Single number extraction and comparion. + c. Last number extraction and comparison + -Prints detailed information about each example + 3. Calculates and returns the overall accuracy. + 4. Returns the model to training mode. + + """ + + + model.eval() + + correct = 0 + + total = len(eval_samples) + + print("\n" + "="*50) + print("EVALUATION ON", total, "EXAMPLES") + print("="*50) + + + for example in eval_samples: + + #get the prompt and expected answer + + full_prompt = example["prompt"] + expected = example["answer"] + + #Tokenize and generate response + + inputs = tokenizer(full_prompt, return_tensors='pt', padding=False, truncation=False, return_attention_mask=True).to(device) + + with torch.no_grad(): + + outputs = model.generate( + input_ids = inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=512, + temperature=0.7, + num_return_sequences=1, + pad_token_id = tokenizer.pad_token_id, + eos_token_id = tokenizer.eos_token_id, + forced_eos_token_id = tokenizer.eos_token_id, + early_stopping = False, + ) + + response = tokenizer.decode(outputs[0], skip_special_tokens=True) + + try: + #Extract answers and check correctness + predicted = self.dat_fmt.extract_answer_from_model_output(response) + + #Try different matching method + + if predicted == expected : # Exact match + + is_correct = True + + else: + # Try single number matchin + pred_num = self.extract_single_number(str(predicted)) + exp_num = self.extract_single_number(str(expected)) + + if pred_num is not None and exp_num is not None and pred_num==exp_num: + + is_correct = True + else: + #Try the last number matchin + pre_num = self.extract_last_number(str(predicted)) + exp_num = self.extract_last_number(str(expected)) + + is_correct = (pred_num is not None and exp_num is not None and pred_num == exp_num) + + if is_correct: + correct+=1 + + + # Print evaluation results + + print("\nPrompt:") + print(full_prompt) + print("\nExpected Answer:") + print(expected) + print("\nExtracted Answer:") + print(predicted) + print("\nFull Generated Response:") + print(response) + print("\nCorrect:", "✓" if is_correct else "✗") + print("--"*50) + + except Exception as e: + + print("\nFailed to parse the model output from prompt:") + print(full_prompt) + print('Error:',e) + print('-'*50) + + + accuracy = (correct / total) * 100 + + print(f"\nAccuracy: {accuracy:.2f}% ({correct}/{total})" ) + + # return the model to training mode + model.train() + + return accuracy + + diff --git 
diff --git a/python/spotlight_prj/fedllm/fedml_config/grpo_gsm8k_test_config.yaml b/python/spotlight_prj/fedllm/fedml_config/grpo_gsm8k_test_config.yaml
index f2c3549db6..ed02856073 100644
--- a/python/spotlight_prj/fedllm/fedml_config/grpo_gsm8k_test_config.yaml
+++ b/python/spotlight_prj/fedllm/fedml_config/grpo_gsm8k_test_config.yaml
@@ -19,6 +19,7 @@ data_args:
 model_args:
   skip_log_model_net: True
   model_name_or_path: "Qwen/Qwen3-0.6B"
+  model_dtype: "bfloat16"
   peft_type: "none" # Full model fine-tuning
   use_flash_attention: False
 
@@ -26,12 +27,12 @@ train_args:
   federated_optimizer: "FedAvg"
   client_optimizer: "adamw_torch"
   server_optimizer: "FedAvg"
-  client_num_in_total: 1 # Single client setup
-  client_num_per_round: 1 # Single client setup
-  comm_round: 3 # Reduced to 3 rounds for testing
+  client_num_in_total: 4 # Four-client setup
+  client_num_per_round: 4 # All four clients participate each round
+  comm_round: 30 # Number of federated rounds
   # GRPO-specific settings for testing
-  grpo_max_steps: 10 # Only 10 training steps per round for quick testing
-  grpo_num_epochs: 1 # Ignored when grpo_max_steps > 0
+  grpo_max_steps: 50
+  grpo_num_epochs: 2 # Ignored when grpo_max_steps > 0
   grpo_batch_size: 2 # Smaller batch size for faster testing
   # FedML training settings (ignored when using GRPO)
   local_num_train_epochs: 1
@@ -40,14 +41,14 @@ train_args:
   deepspeed: null # Disable DeepSpeed for GRPO compatibility
   ddp_find_unused_parameters: False
   seed: 1234
-  fp16: True # Use fp16 instead of bf16 for GPU compatibility
-  bf16: False
+  fp16: False # Disabled in favour of bf16
+  bf16: True
   gradient_checkpointing: False # Match GRPO config
-  per_device_train_batch_size: 4 # Will be overridden by GRPO
+  per_device_train_batch_size: 2 # Will be overridden by GRPO
   per_device_eval_batch_size: 8
-  gradient_accumulation_steps: 1 # Will be overridden by GRPO
+  gradient_accumulation_steps: 2 # Will be overridden by GRPO
   eval_accumulation_steps: 4
-  learning_rate: 5e-6 # Will be overridden by GRPO
+  learning_rate: 5e-6 # Will be overridden by GRPO
   warmup_steps: 0
   output_dir: ".logs/FedML/{run_id}"
   logging_steps: 5 # Frequent logging for testing
diff --git a/python/spotlight_prj/fedllm/run_fedllm.py b/python/spotlight_prj/fedllm/run_fedllm.py
index b726450ca7..c09268e2a5 100644
--- a/python/spotlight_prj/fedllm/run_fedllm.py
+++ b/python/spotlight_prj/fedllm/run_fedllm.py
@@ -157,15 +157,22 @@ def _save_checkpoint(
     if state_dict is None:
         state_dict = model.state_dict()
 
-    if isinstance(model, (PeftModel, PreTrainedModel)):
-        model.save_pretrained(
-            save_directory=str(checkpoint_dir),
-            state_dict=state_dict
-        )
+    # Always store a **single** weight file so that downstream logic can
+    # reliably load it without having to handle Hugging Face sharded
+    # checkpoints. For PEFT (LoRA/Adapter) models we keep the original
+    # filename expected by `load_checkpoint()` (``adapter_model.bin``),
+    # otherwise we save using the standard Hugging Face filename
+    # ``pytorch_model.bin``.
+ + checkpoint_dir = Path(checkpoint_dir) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + if isinstance(model, PeftModel): + filename = PEFT_WEIGHTS_NAME # "adapter_model.bin" else: - checkpoint_dir = Path(checkpoint_dir) - checkpoint_dir.mkdir(parents=True, exist_ok=True) - torch.save(state_dict, str(checkpoint_dir / HF_WEIGHTS_NAME)) + filename = HF_WEIGHTS_NAME # "pytorch_model.bin" + + torch.save(state_dict, str(checkpoint_dir / filename)) def save_checkpoint( @@ -208,9 +215,24 @@ def save_checkpoint( f" \"{type(model_or_trainer)}\"." ) - # save model checkpoint + # Save model checkpoint if isinstance(model_or_trainer, HFTrainer): - model_or_trainer.save_checkpoint(checkpoint_dir) + # Hugging Face Trainer normally creates sharded checkpoints. To keep + # downstream logic simple we instead persist a **single** weight file + # for the underlying model, re-using the same helper that `Module` + # path employs. + + underlying_model = model_or_trainer.model + + if is_saving_process: + # Prefer caller-provided `state_dict` when given (e.g. aggregated + # weights from the server); otherwise pull fresh weights from the + # model. + _save_checkpoint( + underlying_model, + checkpoint_dir, + state_dict or underlying_model.state_dict() + ) elif isinstance(model_or_trainer, Module): if is_saving_process: @@ -381,7 +403,12 @@ def on_after_local_training(self, train_data, device, args: Arguments) -> None: self.latest_checkpoint_dir = self.checkpoint_dir / f"round_{self.round_idx}_before_agg" self.log(f"saving model to \"{self.latest_checkpoint_dir}\"") - save_checkpoint(self.trainer, self.latest_checkpoint_dir) + # Force checkpoint creation even if TrainingArguments.save_strategy == "no" + save_checkpoint( + self.trainer, + self.latest_checkpoint_dir, + is_saving_process=True, + ) self.log("finished") return outputs diff --git a/python/spotlight_prj/fedllm/run_fedllm_custom.py b/python/spotlight_prj/fedllm/run_fedllm_custom.py index 1ea3706962..825c101633 100644 --- a/python/spotlight_prj/fedllm/run_fedllm_custom.py +++ b/python/spotlight_prj/fedllm/run_fedllm_custom.py @@ -44,17 +44,19 @@ def _save_checkpoint( if state_dict is None: state_dict = model.state_dict() - if isinstance(model, (PeftModel, PreTrainedModel)): - # Force safe_serialization=False to get pytorch_model.bin instead of model.safetensors - model.save_pretrained( - save_directory=str(checkpoint_dir), - state_dict=state_dict, - safe_serialization=False # This ensures pytorch_model.bin is created - ) + # Always produce a single-file checkpoint so that downstream loading logic + # can simply look for ``adapter_model.bin`` (PEFT) or ``pytorch_model.bin`` + # without worrying about Hugging Face sharding. 
+ + checkpoint_dir = Path(checkpoint_dir) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + if isinstance(model, PeftModel): + filename = "adapter_model.bin" else: - checkpoint_dir = Path(checkpoint_dir) - checkpoint_dir.mkdir(parents=True, exist_ok=True) - torch.save(state_dict, str(checkpoint_dir / HF_WEIGHTS_NAME)) + filename = HF_WEIGHTS_NAME # "pytorch_model.bin" + + torch.save(state_dict, str(checkpoint_dir / filename)) # Monkey patch the _save_checkpoint function in the imported module diff --git a/python/spotlight_prj/fedllm/scripts/run_fedml_client_custom.sh b/python/spotlight_prj/fedllm/scripts/run_fedml_client_custom.sh index 2d4be76789..09fa207cbc 100755 --- a/python/spotlight_prj/fedllm/scripts/run_fedml_client_custom.sh +++ b/python/spotlight_prj/fedllm/scripts/run_fedml_client_custom.sh @@ -25,7 +25,7 @@ LAUNCHER="${6:-"auto"}" CONFIG_PATH="${7:-"fedml_config/grpo_gsm8k_test_config.yaml"}" # Use the custom launcher that properly handles non-PEFT models -python3 launch_fedllm_custom.py \ +timeout --signal=SIGINT --kill-after=30s 22200 python3 launch_fedllm_custom.py \ --cf "${CONFIG_PATH}" \ --rank "${RANK}" \ --role client \ diff --git a/python/spotlight_prj/fedllm/scripts/run_fedml_server_custom.sh b/python/spotlight_prj/fedllm/scripts/run_fedml_server_custom.sh index f32a85e9fa..08f7d83bda 100755 --- a/python/spotlight_prj/fedllm/scripts/run_fedml_server_custom.sh +++ b/python/spotlight_prj/fedllm/scripts/run_fedml_server_custom.sh @@ -25,8 +25,10 @@ LAUNCHER="${6:-"auto"}" # FedML config CONFIG_PATH="${7:-"fedml_config/fedml_config.yaml"}" +python scripts/save_initial_checkpoint.py + # Use the custom launcher that properly handles non-PEFT models -python3 launch_fedllm_custom.py \ +timeout --signal=SIGINT --kill-after=30s 22200 python3 launch_fedllm_custom.py \ --cf "${CONFIG_PATH}" \ --rank "${RANK}" \ --role server \ diff --git a/python/spotlight_prj/fedllm/scripts/save_initial_checkpoint.py b/python/spotlight_prj/fedllm/scripts/save_initial_checkpoint.py index 4b7ef2147b..2b00909dd1 100644 --- a/python/spotlight_prj/fedllm/scripts/save_initial_checkpoint.py +++ b/python/spotlight_prj/fedllm/scripts/save_initial_checkpoint.py @@ -9,7 +9,7 @@ # Configuration RUN_ID = os.environ.get("RUN_ID", "test_run") -MODEL_NAME = "Qwen/Qwen3-0.6B" +MODEL_NAME = "Qwen/Qwen3-1.7B" OUTPUT_DIR = f"/workspace/FedML/python/spotlight_prj/fedllm/.logs/FedML/{RUN_ID}/node_0/init" print(f"Saving initial checkpoint for model: {MODEL_NAME}") @@ -20,8 +20,8 @@ # Load model and tokenizer print("Loading model...") -model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16) -tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) # Save model in the format expected by FedML (pytorch_model.bin) print("Saving model checkpoint...") diff --git a/python/spotlight_prj/fedllm/stat_test.py b/python/spotlight_prj/fedllm/stat_test.py new file mode 100644 index 0000000000..d9a97b2c95 --- /dev/null +++ b/python/spotlight_prj/fedllm/stat_test.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" +perm_test_files.py – Paired permutation test for model rewards (multi-rollout format) + +File format +----------- +Each line = one evaluation question. +Each line contains comma-separated rewards (0, 1.5, 2) – one per rollout. 
+ +Example (Three questions, two rollouts each): +2,1.5 +2,2 +0,0 + +Usage +----- +python perm_test_files.py +""" +from __future__ import annotations +import sys +import pathlib +import numpy as np +from typing import Callable + +# ---------------------------------------------------------------------- +# CONFIGURATION: choose how to collapse multiple roll-outs into one value. +# ---------------------------------------------------------------------- +AGG_FUNC: Callable[[np.ndarray], float] = np.mean # or np.max, etc. + +# ---------------------------------------------------------------------- +def load_rewards(path: str | pathlib.Path) -> np.ndarray: + """ + Read *path* and return a 1-D array of per-question aggregated rewards. + Each line is split on commas, converted to floats, then collapsed with AGG_FUNC. + """ + try: + lines = pathlib.Path(path).read_text().strip().splitlines() + except OSError as err: + sys.exit(f"Error reading '{path}': {err}") + + if not lines: + sys.exit(f"Error: '{path}' is empty.") + + per_question = [] + for lineno, line in enumerate(lines, start=1): + if not line.strip(): + sys.exit(f"Error: blank line at {path}:{lineno}.") + try: + values = np.fromstring(line, sep=",", dtype=float) + except ValueError as err: + sys.exit(f"Error parsing numbers in '{path}' line {lineno}: {err}") + if values.size == 0: + sys.exit(f"Error: no numeric values in '{path}' line {lineno}.") + per_question.append(AGG_FUNC(values)) + + return np.asarray(per_question, dtype=float) + + +def permutation_test(rA: np.ndarray, + rB: np.ndarray, + B: int = 100_000, + seed: int = 42) -> tuple[float, float]: + """ + Paired permutation test on per-question reward differences. + + Returns + ------- + gap : float mean(rA − rB) + p_two_sided : float permutation p-value + """ + d = rA - rB + gap = d.mean() + + rng = np.random.default_rng(seed) + signs = rng.choice([1, -1], size=(B, d.size)) + perm_gaps = (signs * d).mean(axis=1) + p_two_sided = (np.abs(perm_gaps) >= abs(gap)).mean() + return gap, p_two_sided + + +def main() -> None: + if len(sys.argv) != 3: + print("Usage: python perm_test_files.py ") + sys.exit(1) + + rA = load_rewards(sys.argv[1]) + rB = load_rewards(sys.argv[2]) + + if rA.size != rB.size: + sys.exit("Error: the two files contain different numbers of questions.") + + gap, p = permutation_test(rA, rB) + + print(f"# questions : {rA.size}") + print(f"Aggregation over roll-outs : {AGG_FUNC.__name__}") + print(f"Mean reward difference (A-B) : {gap:.6f}") + print(f"Two-sided permutation p-value : {p:.6g}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/python/spotlight_prj/fedllm/validation.py b/python/spotlight_prj/fedllm/validation.py new file mode 100644 index 0000000000..ffae86eb2e --- /dev/null +++ b/python/spotlight_prj/fedllm/validation.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python +""" +Evaluate Qwen3‑0.6B on GSM8K (test split) with vLLM. 
+ +Usage examples +-------------- +# Default: download model, 1 rollout, 8‑question batches, 100 examples +python eval_qwen3_gsm8k.py + +# Local checkpoint, 4 rollouts, 16 questions per batch, 200 examples +python eval_qwen3_gsm8k.py \ + --model /path/to/Qwen3-0.6B-local \ + --rollouts 4 \ + --batch-examples 16 \ + --num-examples 200 + +# Using custom weights with base model config/tokenizer +python eval_qwen3_gsm8k.py \ + --model /path/to/custom_weights.safetensors \ + --base-model Qwen/Qwen3-0.6B \ + --rollouts 4 \ + --batch-examples 16 \ + --num-examples 200 +""" + +import argparse +import os +import re +import shutil +import tempfile +import time +from fractions import Fraction +from pathlib import Path +from typing import Optional, List + +from datasets import load_dataset # pip install datasets +from transformers import AutoConfig, AutoTokenizer # pip install transformers +from vllm import LLM, SamplingParams # pip install vllm + +# --------------------------- reward configuration --------------------------- + +BOXED_RE = re.compile(r"\\boxed\{([^}]*)\}") # capture content inside \boxed{…} + +EXACT_MATCH_REWARD = 2.0 +NUM_EQ_REWARD = 1.5 +INCORRECT_REWARD = 0.0 + + +# ------------------------------- utilities --------------------------------- + +def is_weight_file(path: str) -> bool: + """Check if path points to a weight file (.bin, .safetensors, .pt, .pth).""" + if not os.path.isfile(path): + return False + return Path(path).suffix.lower() in {'.bin', '.safetensors', '.pt', '.pth'} + + +def is_complete_checkpoint(path: str) -> bool: + """Check if path is a directory containing config.json (indicating a complete checkpoint).""" + if not os.path.isdir(path): + return False + return os.path.exists(os.path.join(path, 'config.json')) + + +def setup_model_with_custom_weights(weight_path: str, base_model: str) -> str: + """ + Create a temporary directory with base model config/tokenizer and custom weights. + Returns the path to the temporary directory. 
+ """ + # Create temporary directory + temp_dir = tempfile.mkdtemp(prefix="qwen_custom_weights_") + + try: + print(f"[INFO] Setting up temporary model directory at {temp_dir}") + print(f"[INFO] Loading config and tokenizer from base model: {base_model}") + + # Download and save config + config = AutoConfig.from_pretrained(base_model, trust_remote_code=True) + config.save_pretrained(temp_dir) + + # Download and save tokenizer + tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) + tokenizer.save_pretrained(temp_dir) + + # Copy weight file to temporary directory + weight_filename = Path(weight_path).name + dest_weight_path = os.path.join(temp_dir, weight_filename) + print(f"[INFO] Copying weights from {weight_path} to {dest_weight_path}") + shutil.copy2(weight_path, dest_weight_path) + + print(f"[INFO] Custom model setup complete in {temp_dir}") + return temp_dir + + except Exception as e: + # Clean up on error + shutil.rmtree(temp_dir, ignore_errors=True) + raise RuntimeError(f"Failed to setup custom model: {e}") + + +def to_number(text: str) -> Optional[float]: + """Convert string to float if possible, handling simple fractions.""" + text = text.replace(",", "").strip() + # Fractions like 3/4 + if "/" in text: + try: + return float(Fraction(text)) + except (ValueError, ZeroDivisionError): + pass + try: + return float(text) + except ValueError: + return None + + +def extract_boxed(text: str) -> str: + """Return first \\boxed{...} contents; '' if none.""" + m = BOXED_RE.search(text) + return m.group(1) if m else "" + + +def reward(pred: str, gold: str) -> float: + """Assign reward based on exact match or numeric equivalence.""" + pred, gold = pred.strip(), gold.strip() + if pred == gold: + return EXACT_MATCH_REWARD + p_num, g_num = to_number(pred), to_number(gold) + if (p_num is not None and g_num is not None + and abs(p_num - g_num) < 1e-4): + return NUM_EQ_REWARD + return INCORRECT_REWARD + + +def batched(lst: List, n: int): + """Yield successive n‑sized chunks from *lst*.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + + +def get_output_filename(model_path: str) -> str: + """Generate a filesystem-safe filename based on the model path.""" + if model_path is None: + return "Qwen_Qwen3-0.6B_rewards.csv" + + # Extract meaningful name from different model path formats + if "/" in model_path: + # HuggingFace model ID (e.g., "Qwen/Qwen3-0.6B") or file path + name = model_path.split("/")[-1] + if "." in name: # Remove file extension for weight files + name = Path(name).stem + else: + name = model_path + + # Replace invalid filename characters + name = re.sub(r'[<>:"/\\|?*]', '_', name) + + return f"{name}_rewards.csv" + + +# ------------------------------- main -------------------------------------- + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser() + p.add_argument("--rollouts", type=int, default=4, + help="completions per example (default: 4)") + p.add_argument("--batch-examples", type=int, default=16, + help="examples per vLLM inference call (default: 2)") + p.add_argument("--num-examples", type=int, default=-1, + help="total GSM8K test examples to evaluate (default: 100, use -1 for full dataset)") + p.add_argument( + "--model", + default=None, + help=("HF repo ID, local checkpoint dir, or path to weight file(s). " + "If omitted, downloads Qwen/Qwen3-0.6B automatically."), + ) + p.add_argument( + "--base-model", + default="Qwen/Qwen3-0.6B", + help=("Base model for config/tokenizer when using custom weight files. 
" + "Ignored when --model is a full checkpoint directory."), + ) + p.add_argument("--max-tokens", type=int, default=512, + help="generation length cap (tokens)") + p.add_argument("--temperature", type=float, default=0.7) + p.add_argument("--top-p", type=float, default=0.95) + return p.parse_args() + + +def main() -> None: + args = parse_args() + temp_model_dir = None + + try: + # ------------------------- resolve model path -------------------------- + original_model_arg = args.model # Store original for filename + if args.model is None: + args.model = "Qwen/Qwen3-0.6B" + print(f"[INFO] No --model given → downloading '{args.model}' " + "from Hugging Face Hub…") + elif is_weight_file(args.model): + print(f"[INFO] Detected weight file: {args.model}") + print(f"[INFO] Using base model: {args.base_model}") + temp_model_dir = setup_model_with_custom_weights(args.model, args.base_model) + args.model = temp_model_dir + elif is_complete_checkpoint(args.model): + print(f"[INFO] Using complete checkpoint directory: {args.model}") + else: + # Assume it's a HuggingFace model ID + print(f"[INFO] Using HuggingFace model: {args.model}") + + # ----------------------- initialize LLM & sampler ---------------------- + llm = LLM(model=args.model, + trust_remote_code=True, # Qwen uses custom code + dtype="bfloat16") # let vLLM choose BF16 / FP16 / FP32 + + sampler = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_tokens, + n=args.rollouts, + seed=42 + ) + + # --------------------------- load dataset ----------------------------- + ds = load_dataset("openai/gsm8k", "main", split="test") + if args.num_examples == -1: + # Use full dataset without shuffling + print(f"[INFO] Using full dataset ({len(ds)} examples)") + else: + # Shuffle and select specified number of examples + num_to_select = min(args.num_examples, len(ds)) + ds = ds.shuffle(seed=42).select(range(num_to_select)) + print(f"[INFO] Using {len(ds)} examples (shuffled)") + + total_reward = 0.0 + total_completions = len(ds) * args.rollouts + all_example_rewards = [] # Track all rollout rewards for each example + + # --------------------------- evaluation ------------------------------- + print(f"[INFO] Starting generation for {len(ds)} examples in batches of {args.batch_examples}...") + start_time = time.time() + + batch_count = 0 + for batch in batched(list(ds), args.batch_examples): + batch_start = time.time() + prompts = [ex["question"] for ex in batch] # **raw questions only** + outputs = llm.generate(prompts, sampler) + batch_end = time.time() + + batch_count += 1 + batch_time = batch_end - batch_start + print(f"[TIMING] Batch {batch_count} ({len(batch)} examples): {batch_time:.2f}s") + + for ex, gen in zip(batch, outputs): + gold = ex["answer"].split("####")[-1].strip() + example_rollout_rewards = [] + for out in gen.outputs: + pred = extract_boxed(out.text) + rollout_reward = reward(pred, gold) + total_reward += rollout_reward + example_rollout_rewards.append(rollout_reward) + + # Store all rollout rewards for this example + all_example_rewards.append(example_rollout_rewards) + + end_time = time.time() + total_generation_time = end_time - start_time + + avg_reward = total_reward / total_completions + print(f"\n[TIMING] Total generation time: {total_generation_time:.2f}s") + print(f"[TIMING] Average time per batch: {total_generation_time / batch_count:.2f}s") + print(f"[TIMING] Average time per example: {total_generation_time / len(ds):.2f}s") + print(f"[TIMING] Average time per completion: 
{total_generation_time / total_completions:.3f}s") + print(f"\nEvaluated {len(ds)} examples × {args.rollouts} rollouts " + f"(batch size = {args.batch_examples}).") + print(f"Average reward: {avg_reward:.4f}") + + # ------------------------ write rewards to file ------------------------- + output_filename = get_output_filename(original_model_arg) + + with open(output_filename, 'w') as f: + for example_rewards in all_example_rewards: + line = ",".join(str(r) for r in example_rewards) + f.write(line + "\n") + + print(f"[INFO] Individual example rewards written to: {output_filename}") + print(f"[INFO] File contains {len(all_example_rewards)} lines, one per example") + + # Calculate max reward per example for summary stats + max_rewards_per_example = [max(rewards) for rewards in all_example_rewards] + print(f"[INFO] Max reward per example average: {sum(max_rewards_per_example) / len(max_rewards_per_example):.4f}") + + finally: + # Clean up temporary directory if it was created + if temp_model_dir and os.path.exists(temp_model_dir): + print(f"[INFO] Cleaning up temporary directory: {temp_model_dir}") + shutil.rmtree(temp_model_dir, ignore_errors=True) + + +if __name__ == "__main__": + main() \ No newline at end of file