diff --git a/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py b/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py index 45b2af0c..e38c9c0d 100644 --- a/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py +++ b/benchmarks/dpo_ewc/continual_dpo_EWC_trainer.py @@ -1,6 +1,9 @@ +from contextlib import nullcontext +from copy import deepcopy from dataclasses import dataclass, field from typing import Any, Dict, Optional, Union +import accelerate import deepspeed import torch import torch.nn as nn @@ -36,16 +39,10 @@ class ContinualDPOEWCConfig(ContinualDPOConfig): class ContinualDPOEWCTrainer(ContinualDPOTrainer): - """DPO Trainer enhanced with Elastic Weight Consolidation (EWC) for continual learning. + """DPO Trainer enhanced with Elastic Weight Consolidation (EWC) for continual learning.""" - EWC keeps a memory of parameter importance from previous tasks and penalizes - changes to important parameters when learning new tasks. - """ - - # Class-level variables to store Fisher Information and old parameters across tasks - class_fisher_information: Dict[str, torch.Tensor] = {} - class_old_params: Dict[str, torch.Tensor] = {} - current_task_index: int = 0 + fisher: Dict[str, torch.Tensor] = {} + old_params: Dict[str, torch.Tensor] = {} def __init__( self, @@ -55,17 +52,38 @@ def __init__( args: Optional[ContinualDPOEWCConfig] = None, **kwargs: Any, ): + self.model_copy = deepcopy(model) # Copy before deepspeed initialization + self.ewc_lambda = args.ewc_lambda if args is not None else 100.0 super().__init__(model, ref_model, reward_model, args, **kwargs) - # Store EWC-specific parameters - self.ewc_lambda = args.ewc_lambda if args is not None else 100.0 + def train(self) -> Any: + result = super().train() + + ContinualDPOEWCTrainer.old_params = { + name: param.detach().clone() + for name, param in self.accelerator.unwrap_model( + self.model + ).named_parameters() + if param.requires_grad + } - # Track if we're on the first task - is_first_task = ContinualDPOEWCTrainer.current_task_index == 0 - if is_first_task: - # Initialize empty dictionaries for first task - ContinualDPOEWCTrainer.class_fisher_information = {} - ContinualDPOEWCTrainer.class_old_params = {} + if self.accelerator.is_main_process: + fisher = self.compute_fisher() + else: + # The secondary processes need to have a dictionary initialized + # on the current device, and with the corrct tensor shapes to enable + # broadcast to _copy() the tensors. + fisher = { + name: torch.zeros_like(param, device=self.accelerator.device) + for name, param in self.accelerator.unwrap_model( + self.model + ).named_parameters() + if param.requires_grad + } + + # TODO: Double check that this broadcast is actually correct + ContinualDPOEWCTrainer.fisher = accelerate.utils.broadcast(fisher) + return result def compute_loss( self, @@ -74,119 +92,57 @@ def compute_loss( return_outputs: bool = False, num_items_in_batch: Optional[int] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]: - """Compute the DPO loss with additional EWC regularization to prevent - catastrophic forgetting of previously learned tasks. 
- """ - # Regular DPO loss calculation - regular_loss, outputs = super().compute_loss( + dpo_loss, outputs = super().compute_loss( model, inputs, return_outputs=True, num_items_in_batch=num_items_in_batch ) - # Skip EWC loss on first task since there's nothing to preserve yet - is_first_task = ContinualDPOEWCTrainer.current_task_index == 0 - if is_first_task: - return (regular_loss, outputs) if return_outputs else regular_loss - - # Calculate EWC regularization loss - ewc_loss = self.compute_ewc_loss() - - # Combine losses - total_loss = regular_loss + ewc_loss - + def compute_ewc_loss() -> torch.Tensor: + ewc_loss = torch.tensor(0.0, device=self.accelerator.device) + model = self.accelerator.unwrap_model(self.model) + for name, param in model.named_parameters(): + if name in ContinualDPOEWCTrainer.fisher and param.requires_grad: + with ( + deepspeed.zero.GatheredParameters([param], modifier_rank=None) + if hasattr(param, 'ds_id') + else nullcontext() + ): + fisher = ContinualDPOEWCTrainer.fisher[name].to(param.device) + old_param = ContinualDPOEWCTrainer.old_params[name].to( + param.device + ) + ewc_loss += (fisher * (param - old_param).pow(2)).sum() + return 0.5 * self.ewc_lambda * ewc_loss + + ewc_loss = compute_ewc_loss() + total_loss = dpo_loss + ewc_loss self.log( { 'ewc_loss': ewc_loss.item(), - 'dpo_loss': regular_loss.item(), + 'dpo_loss': dpo_loss.item(), 'total_loss': total_loss.item(), } ) - return (total_loss, outputs) if return_outputs else total_loss - def compute_ewc_loss(self) -> torch.Tensor: - """Compute the EWC regularization loss. - - This loss penalizes changes to parameters that were important for previous tasks, - as determined by their Fisher information matrix. - - Returns: - EWC regularization loss tensor - """ - if not ContinualDPOEWCTrainer.class_fisher_information: - # No previous tasks, so no regularization needed - return torch.tensor(0.0, device=self.accelerator.device) - - ewc_loss = torch.tensor(0.0, device=self.accelerator.device) - - # Calculate the EWC penalty for each parameter - model = self.accelerator.unwrap_model(self.model) - - for name, param in model.named_parameters(): - if name not in ContinualDPOEWCTrainer.class_fisher_information: - continue - if not param.requires_grad: - continue - - if ( - name in ContinualDPOEWCTrainer.class_fisher_information - and param.requires_grad - ): - # Get the Fisher information and old parameter values - fisher = ContinualDPOEWCTrainer.class_fisher_information[name].to( - param.device - ) - - with deepspeed.zero.GatheredParameters([param], modifier_rank=0): - if self.accelerator.is_main_process: - old_param = ContinualDPOEWCTrainer.class_old_params[name].to( - param.device - ) - # Calculate squared distance weighted by Fisher information - delta = param - old_param - ewc_loss = ewc_loss + (fisher * delta.pow(2)).sum() - - # Apply the EWC lambda coefficient and return - return 0.5 * self.ewc_lambda * ewc_loss - - def compute_fisher_information( - self, num_samples: int = 120 + def compute_fisher( + self, num_samples: int = 120, device: str = 'cuda:0' ) -> Dict[str, torch.Tensor]: - """Compute Fisher Information matrix for the current model parameters. 
- - Args: - num_samples: Number of samples to use for Fisher computation - - Returns: - Dictionary mapping parameter names to their Fisher information values - """ - # Get unwrapped model for computing Fisher - model = self.accelerator.unwrap_model(self.model) - self.accelerator.device - - # Make sure parameters require gradients - for param in model.parameters(): - param.requires_grad_(True) - - # Initialize fisher information dictionary - fisher_info = {} - for name, param in model.named_parameters(): - if param.requires_grad: - fisher_info[name] = torch.zeros_like(param) - - # Create a dataloader for Fisher estimation - fisher_dataloader = self.get_train_dataloader() - - # Collect samples for Fisher estimation - sample_count = 0 - for batch in fisher_dataloader: - if sample_count >= num_samples: - break + # Computing fisher outside the deepspeed context + model = deepcopy(self.model_copy) + model.load_state_dict( + self.accelerator.unwrap_model(self.model).state_dict(), strict=False + ) + model = model.to(device) + model.eval() - # Check what keys are available in the batch (for debugging) - batch_keys = list(batch.keys()) + fisher = { + name: torch.zeros_like(param, device=device) + for name, param in model.named_parameters() + if param.requires_grad + } - # Try to determine the batch size from available keys - batch_size = None + def guess_batch_size(batch: Any) -> int: + batch_size = 1 for key in ['input_ids', 'chosen_input_ids', 'policy_input_ids']: if ( key in batch @@ -195,102 +151,74 @@ def compute_fisher_information( ): batch_size = batch[key].shape[0] break + return batch_size - if batch_size is None: - print( - f'Warning: Could not determine batch size. Available keys: {batch_keys}' - ) - batch_size = 1 # Default fallback - - # Forward pass with gradients - model.zero_grad() - - try: - loss, _ = self.compute_loss(model, batch, return_outputs=True) - - # Check if loss requires gradient - if not loss.requires_grad: - print( - "Warning: Loss doesn't require gradients. Adding requires_grad=True" - ) - loss = loss.clone().detach().requires_grad_(True) - - self.accelerator.backward(loss) - - # Accumulate squared gradients as Fisher information estimate - for name, param in model.named_parameters(): - if param.requires_grad and param.grad is not None: - fisher_info[name] += param.grad.detach().pow(2) - except Exception as e: - print(f'Error during Fisher computation: {e}') - continue - - # Safely increment sample count - sample_count += batch_size - - # Normalize by the number of samples - if sample_count > 0: - for name in fisher_info.keys(): - fisher_info[name] /= sample_count - else: - print('Warning: No samples were processed for Fisher computation') - - print(f'Computed Fisher information for {sample_count} examples') - return fisher_info + def move_to_device(batch: Any) -> Any: + if isinstance(batch, dict): + return { + k: v.to(device) if hasattr(v, 'to') else v for k, v in batch.items() + } + elif isinstance(batch, (list, tuple)): + return type(batch)(move_to_device(x) for x in batch) + else: + return batch.to(device) if hasattr(batch, 'to') else batch - def store_current_parameters(self) -> Dict[str, torch.Tensor]: - """Store the current model parameters. 
+ sample_count = 0 + for batch in self.get_train_dataloader(): + if sample_count >= num_samples: + break - Returns: - Dictionary mapping parameter names to their current values - """ - model = self.accelerator.unwrap_model(self.model) - old_params = {} + sample_count += guess_batch_size(batch) + batch = move_to_device(batch) - for name, param in model.named_parameters(): - with deepspeed.zero.GatheredParameters([param], modifier_rank=0): - if self.accelerator.is_main_process: - if param.requires_grad: - old_params[name] = param.data.clone().detach() - return old_params + model.zero_grad(set_to_none=True) + loss = self.compute_dpo_loss_for_fisher(model, batch) + loss.backward() - def train(self) -> Any: - """Override train method to incorporate EWC regularization.""" - # Regular training - result = super().train() + for name, param in model.named_parameters(): + with ( + deepspeed.zero.GatheredParameters([param], modifier_rank=None) + if hasattr(param, 'ds_id') + else nullcontext() + ): + if param.grad is not None: + fisher[name] += param.grad.detach().clone().pow(2) - # After training completes, update the Fisher information and old parameters - # for the next task - self.accelerator.print( - 'Computing Fisher information matrix for the next task...' - ) + for name in fisher: + fisher[name] /= sample_count + return fisher - # Calculate and log EWC loss statistics before computing new Fisher information - if ContinualDPOEWCTrainer.current_task_index > 0: - ewc_loss = self.compute_ewc_loss() - # Log EWC loss details - self.log( - { - 'ewc_stats/total_loss': ewc_loss.item(), - 'ewc_stats/per_param_avg': ewc_loss.item() - / sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ), - 'ewc_stats/lambda': self.ewc_lambda, - 'ewc_stats/task_index': ContinualDPOEWCTrainer.current_task_index, - } - ) - self.accelerator.print( - f'EWC loss for task {ContinualDPOEWCTrainer.current_task_index}: {ewc_loss.item():.4f}' + def compute_dpo_loss_for_fisher( + self, + model: Union[PreTrainedModel, nn.Module], + batch: dict[str, Union[torch.Tensor, Any]], + ) -> torch.Tensor: + # The following is a patch of super().compute_loss which avoid metrics computation + # and can be run outside of the accelerate context. It primarily specializes the + # implementation of get_batch_loss_metrics to skip calls to `self.accelerator.gather_for_metrics`. 
+ with ( + torch.amp.autocast('cuda') + if self._peft_has_been_casted_to_bf16 + else nullcontext() + ): + model_output = self.concatenated_forward(model, batch) + + if 'ref_chosen_logps' in batch and 'ref_rejected_logps' in batch: + ref_chosen_logps = batch['ref_chosen_logps'] + ref_rejected_logps = batch['ref_rejected_logps'] + else: + ref_chosen_logps, ref_rejected_logps = self.compute_ref_log_probs(batch) + + losses, _, _ = self.dpo_loss( + model_output['chosen_logps'], + model_output['rejected_logps'], + ref_chosen_logps, + ref_rejected_logps, ) - - # Compute new Fisher information and parameters - ContinualDPOEWCTrainer.class_fisher_information = ( - self.compute_fisher_information() - ) - ContinualDPOEWCTrainer.class_old_params = self.store_current_parameters() - - # Increment task index for next time - ContinualDPOEWCTrainer.current_task_index += 1 - - return result + if self.args.rpo_alpha is not None: + losses = losses + self.args.rpo_alpha * model_output['nll_loss'] + if self.use_weighting: + losses = losses * model_output['policy_weights'] + if self.aux_loss_enabled: + losses = losses + self.aux_loss_coef * model_output['aux_loss'] + return losses.mean().to(self.args.device) diff --git a/benchmarks/parallel_eval_checkpoints.py b/benchmarks/parallel_eval_checkpoints.py index 158deb65..01924388 100644 --- a/benchmarks/parallel_eval_checkpoints.py +++ b/benchmarks/parallel_eval_checkpoints.py @@ -9,6 +9,7 @@ ContinualDPOConfig, ContinualDPOTrainer, ) +from safetensors import safe_open from transformers import ( AutoModelForCausalLM, AutoModelForSequenceClassification, @@ -30,6 +31,7 @@ def main( model_args: ModelConfig, ) -> None: # Determine torch dtype and quantization configs + torch_dtype = ( model_args.torch_dtype if model_args.torch_dtype in ['auto', None] @@ -45,14 +47,17 @@ def main( revision=model_args.model_revision, attn_implementation=model_args.attn_implementation, torch_dtype=torch_dtype, - use_cache=False if training_args.gradient_checkpointing else True, device_map=get_kbit_device_map() if quantization_config is not None else None, quantization_config=quantization_config, ) # Checkpoint loop checkpoint_path = script_args.checkpoint_dir - dataset_name = checkpoint_path.split('/')[-2].replace('.', '') + if 'DPO' not in checkpoint_path: + dataset_name = 'dataset-' + checkpoint_path.split('/')[-2].split('_')[-1] + else: + dataset_name = checkpoint_path.split('/')[-2].replace('.', '') + checkpoint_step = checkpoint_path.split('/')[-1].replace('.', '') print( f'Evaluating checkpoint: {checkpoint_step} trained on dataset: {dataset_name} on all tasks' @@ -60,11 +65,29 @@ def main( checkpoint_name = dataset_name + '_' + checkpoint_step print('checkpoint_name', checkpoint_name) - model = AutoModelForCausalLM.from_pretrained( - checkpoint_path, - trust_remote_code=model_args.trust_remote_code, - **model_kwargs, - ) + if 'DPO' not in checkpoint_path: + base_model_name = model_args.model_name_or_path # Use the base model path for config + + # Load config from base model first + from transformers import AutoConfig + config = AutoConfig.from_pretrained( + base_model_name, + trust_remote_code=model_args.trust_remote_code, + ) + # remove the prefix 'policy.' 
from the keys to load the model; skip the critic and value model + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + config=config, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + else: + model = AutoModelForCausalLM.from_pretrained( + checkpoint_path, + trust_remote_code=model_args.trust_remote_code, + local_files_only=True, + **model_kwargs, + ) peft_config = get_peft_config(model_args) ref_model = AutoModelForCausalLM.from_pretrained( @@ -92,12 +115,12 @@ def main( output_dir = training_args.output_dir # Validate reward model paths if provided - for i, _ in enumerate(continual_dataset): - reward_path = training_args.reward_model_path + '_' + str(i) - if not os.path.exists(reward_path): - raise FileNotFoundError( - f'Reward model not found for dataset {i} at {reward_path}' - ) + # for i, _ in enumerate(continual_dataset): + # reward_path = training_args.reward_model_path + '_' + str(i) + # if not os.path.exists(reward_path): + # raise FileNotFoundError( + # f'Reward model not found for dataset {i} at {reward_path}' + # ) # Task Loop for i, dataset in enumerate(continual_dataset): diff --git a/jobs/cppo/cppo_domain_shift_multi_gpu.sh b/jobs/cppo/cppo_domain_shift_multi_gpu.sh index e4f9f26e..2b10c00e 100644 --- a/jobs/cppo/cppo_domain_shift_multi_gpu.sh +++ b/jobs/cppo/cppo_domain_shift_multi_gpu.sh @@ -35,5 +35,5 @@ accelerate launch --config_file benchmarks/cppo/accelerate_configs/deepspeed_zer --eval_steps 200 \ --save_steps 300 \ --bf16 \ - --output_dir "$HOME/Qwen2-0.5B-CPPO-${dataset_name}" \ + --output_dir "/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-CPPO-${dataset_name}" \ --no_remove_unused_columns diff --git a/jobs/cppo/cppo_lipschitz_multi_gpu.sh b/jobs/cppo/cppo_lipschitz_multi_gpu.sh index 06a533af..43509299 100644 --- a/jobs/cppo/cppo_lipschitz_multi_gpu.sh +++ b/jobs/cppo/cppo_lipschitz_multi_gpu.sh @@ -36,5 +36,5 @@ accelerate launch --config_file benchmarks/cppo/accelerate_configs/deepspeed_zer --eval_steps 200 \ --save_steps 300 \ --bf16 \ - --output_dir "$HOME/Qwen2-0.5B-CPPO-${dataset_name}" \ + --output_dir "/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-CPPO-${dataset_name}" \ --no_remove_unused_columns diff --git a/jobs/cppo/cppo_long_piecewise_multi_gpu.sh b/jobs/cppo/cppo_long_piecewise_multi_gpu.sh index 2614ad53..1e2fbfa1 100644 --- a/jobs/cppo/cppo_long_piecewise_multi_gpu.sh +++ b/jobs/cppo/cppo_long_piecewise_multi_gpu.sh @@ -35,5 +35,5 @@ accelerate launch --config_file benchmarks/cppo/accelerate_configs/deepspeed_zer --eval_steps 200 \ --save_steps 300 \ --bf16 \ - --output_dir "$HOME/Qwen2-0.5B-CPPO-${dataset_name}" \ + --output_dir "/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-CPPO-${dataset_name}" \ --no_remove_unused_columns diff --git a/jobs/cppo/cppo_piecewise_multi_gpu.sh b/jobs/cppo/cppo_piecewise_multi_gpu.sh index ea84fae2..3b6c84d4 100644 --- a/jobs/cppo/cppo_piecewise_multi_gpu.sh +++ b/jobs/cppo/cppo_piecewise_multi_gpu.sh @@ -30,11 +30,11 @@ accelerate launch --config_file benchmarks/cppo/accelerate_configs/deepspeed_zer --response_length 256 \ --num_train_epochs 4 \ --gradient_checkpointing \ - --per_device_train_batch_size 8 \ + --per_device_train_batch_size 4 \ --logging_steps 10 \ --eval_strategy steps \ --eval_steps 200 \ --save_steps 300 \ --bf16 \ - --output_dir "$HOME/Qwen2-0.5B-CPPO-${dataset_name}" \ + --output_dir 
"/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-CPPO-${dataset_name}" \ --no_remove_unused_columns diff --git a/jobs/dpo/dpo_cppo_multi_gpu.sh b/jobs/dpo/dpo_cppo_multi_gpu.sh index d427e72e..30316a4b 100644 --- a/jobs/dpo/dpo_cppo_multi_gpu.sh +++ b/jobs/dpo/dpo_cppo_multi_gpu.sh @@ -18,20 +18,20 @@ dataset_name='CPPO-RL' accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ benchmarks/dpo/dpo_continual.py \ - --dataset_name 'CPPO-RL' \ + --dataset_name $dataset_name \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ --reward_model_path LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD \ - --learning_rate 5.0e-6 \ + --learning_rate 1.0e-6 \ --num_train_epochs 4 \ - --per_device_train_batch_size 8 \ + --per_device_train_batch_size 16 \ --gradient_checkpointing \ - --logging_steps 20 \ + --logging_steps 10 \ --eval_strategy steps \ --response_length 256 \ - --eval_steps 500 \ - --save_steps 500 \ + --eval_steps 50000 \ + --save_steps 300 \ --bf16 \ --output_dir "$SCRATCH/projects/Qwen2-0.5B-DPO-${dataset_name}" \ --no_remove_unused_columns \ - --wandb_project $dataset_name \ + --wandb_project "$dataset_name-post-May-19" \ --wandb_run_name "Qwen2-0.5B-DPO-${dataset_name}-multi-gpu" diff --git a/jobs/dpo_ewc/dpo_ewc_long_piecewise_multi_gpu.sh b/jobs/dpo_ewc/dpo_ewc_long_piecewise_multi_gpu.sh index 39ae8eea..f2222897 100644 --- a/jobs/dpo_ewc/dpo_ewc_long_piecewise_multi_gpu.sh +++ b/jobs/dpo_ewc/dpo_ewc_long_piecewise_multi_gpu.sh @@ -5,7 +5,7 @@ #SBATCH --ntasks-per-node=4 # One task per GPU #SBATCH --cpus-per-task=6 #SBATCH --mem=64G -#SBATCH --time=24:00:00 +#SBATCH --time=1:00:00 #SBATCH --output=out/%x.%j.out # Include job name + job ID #SBATCH --error=out/%x.%j.err # Include job name + job ID #SBATCH --mail-type=ALL @@ -16,22 +16,22 @@ source .env dataset_name='aifgen-long-piecewise' -accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero2.yaml \ +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ benchmarks/dpo_ewc/dpo_EWC_continual.py \ - --dataset_name $dataset_name \ + --dataset_name benchmarks/continual_data_debug.json \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ - --reward_model_path LifelongAlignment/Qwen2.5-0.5B-Instruct_${dataset_name}_REWARD \ - --learning_rate 5.0e-6 \ + --reward_model_path LifelongAlignment/Qwen2-0.5B-Instruct_${dataset_name}_REWARD \ + --learning_rate 1.0e-6 \ --num_train_epochs 4 \ - --per_device_train_batch_size 8 \ + --per_device_train_batch_size 16 \ --gradient_checkpointing \ --logging_steps 20 \ --eval_strategy steps \ --response_length 256 \ - --eval_steps 500 \ - --save_steps 500 \ + --eval_steps 50000 \ + --save_steps 300 \ --bf16 \ - --output_dir "$SCRATCH/projects/Qwen2-0.5B-DPO-EWC-${dataset_name}" \ + --output_dir "/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-DPO-EWC-${dataset_name}" \ --no_remove_unused_columns \ - --wandb_project $dataset_name \ - --wandb_run_name "Qwen2-0.5B-DPO-EWC-${dataset_name}-multi-gpu" + --wandb_project "$dataset_name-post-May-19" \ + --wandb_run_name "Qwen2-0.5B-DPO-EWC-${dataset_name}-multi-gpu-debug" diff --git a/jobs/parallel_eval.sh b/jobs/parallel_eval.sh new file mode 100755 index 00000000..095e2ee3 --- /dev/null +++ b/jobs/parallel_eval.sh @@ -0,0 +1,88 @@ +# EVAL +datasets="aifgen-long-piecewise" +dataset_indices="0 1" +checkpoint_indices="300 600" + +for dataset_index in $dataset_indices +do + for dataset_name in $datasets + do + for 
checkpoint in $checkpoint_indices + do + job_name="${dataset_name}-${dataset_index}-${checkpoint}" + mkdir -p out/ + run_cmd="jobs/schedule_eval.sh ${dataset_name} ${dataset_index} ${checkpoint}" + sbatch_cmd="sbatch --job-name $job_name ${run_cmd}" + cmd="$sbatch_cmd" + echo -e "${cmd}" + ${cmd} + sleep 1 + done + done +done + +datasets="aifgen-domain-preference-shift" +dataset_indices="0 1 2 3" +checkpoint_indices="300 531" + +for dataset_index in $dataset_indices +do + for dataset_name in $datasets + do + for checkpoint in $checkpoint_indices + do + job_name="${dataset_name}-${dataset_index}-${checkpoint}" + mkdir -p out/ + run_cmd="jobs/schedule_eval.sh ${dataset_name} ${dataset_index} ${checkpoint}" + sbatch_cmd="sbatch --job-name $job_name ${run_cmd}" + cmd="$sbatch_cmd" + echo -e "${cmd}" + ${cmd} + sleep 1 + done + done +done + +datasets="aifgen-lipschitz" +dataset_indices="0 1 2" +checkpoint_indices="300 900 1063" + +for dataset_index in $dataset_indices +do + for dataset_name in $datasets + do + for checkpoint in $checkpoint_indices + do + job_name="${dataset_name}-${dataset_index}-${checkpoint}" + mkdir -p out/ + run_cmd="jobs/schedule_eval.sh ${dataset_name} ${dataset_index} ${checkpoint}" + sbatch_cmd="sbatch --job-name $job_name ${run_cmd}" + cmd="$sbatch_cmd" + echo -e "${cmd}" + ${cmd} + sleep 1 + done + done +done + +datasets="aifgen-piecewise-preference-shift" +dataset_indices="0 1 2 3 4 5 6 7" +checkpoint_indices="300 1200 2100" + +for dataset_index in $dataset_indices +do + for dataset_name in $datasets + do + for checkpoint in $checkpoint_indices + do + job_name="${dataset_name}-${dataset_index}-${checkpoint}" + mkdir -p out/ + run_cmd="jobs/schedule_eval.sh ${dataset_name} ${dataset_index} ${checkpoint}" + sbatch_cmd="sbatch --job-name $job_name ${run_cmd}" + cmd="$sbatch_cmd" + echo -e "${cmd}" + ${cmd} + sleep 1 + done + done +done \ No newline at end of file diff --git a/jobs/parallel_eval_cppo_dataset.sh b/jobs/parallel_eval_cppo_dataset.sh new file mode 100755 index 00000000..8a79b689 --- /dev/null +++ b/jobs/parallel_eval_cppo_dataset.sh @@ -0,0 +1,22 @@ +# EVAL +datasets="CPPO-RL" +dataset_indices="0 1" +checkpoint_indices="300 1800 2100" + +for dataset_index in $dataset_indices +do + for dataset_name in $datasets + do + for checkpoint in $checkpoint_indices + do + job_name="${dataset_name}-${dataset_index}-${checkpoint}" + mkdir -p out/ + run_cmd="jobs/schedule_eval_cppo_dataset.sh ${dataset_name} ${dataset_index} ${checkpoint}" + sbatch_cmd="sbatch --job-name $job_name ${run_cmd}" + cmd="$sbatch_cmd" + echo -e "${cmd}" + ${cmd} + sleep 1 + done + done +done \ No newline at end of file diff --git a/jobs/ppo/ppo_piecewise_multi_gpu.sh b/jobs/ppo/ppo_piecewise_multi_gpu.sh index 5078949f..f4ac032f 100644 --- a/jobs/ppo/ppo_piecewise_multi_gpu.sh +++ b/jobs/ppo/ppo_piecewise_multi_gpu.sh @@ -30,7 +30,7 @@ accelerate launch --config_file benchmarks/cppo/accelerate_configs/deepspeed_zer --response_length 256 \ --num_train_epochs 4 \ --gradient_checkpointing \ - --per_device_train_batch_size 8 \ + --per_device_train_batch_size 4 \ --logging_steps 10 \ --eval_strategy steps \ --eval_steps 200 \ diff --git a/jobs/schedule_eval.sh b/jobs/schedule_eval.sh new file mode 100644 index 00000000..8111f49d --- /dev/null +++ b/jobs/schedule_eval.sh @@ -0,0 +1,77 @@ +#!/bin/bash +#SBATCH --job-name=aif-gen-evaluation +#SBATCH --nodes=1 # Request 2 nodes +#SBATCH --gpus-per-node=h100:4 # Request 4 H100 GPUs per node +#SBATCH --ntasks-per-node=4 # One task per GPU +#SBATCH 
--cpus-per-task=6 +#SBATCH --mem=64G +#SBATCH --time=2:00:00 +#SBATCH --output=out/%x.%j.out # Include job name + job ID +#SBATCH --error=out/%x.%j.err # Include job name + job ID +#SBATCH --mail-type=ALL +#SBATCH --account=aip-rrabba +#SBATCH --mail-user=shahrad_m@icloud.com # Update with your email +source .env + +dataset_name=${1:-'aifgen-lipschitz'} +dataset_index=${2:-'0'} +checkpoint=${3:-'300'} + +#DPO on CPPO dataset - DIFFERENT FILE +# accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ +# benchmarks/parallel_eval_checkpoints.py \ +# --checkpoint_dir "/scratch/s/shahradm/${dataset_name}/Qwen2-0.5B-DPO-/dataset-${dataset_index}/checkpoint-${checkpoint}" \ +# --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ +# --wandb_run_name "test_eval_Qwen2-0.5B-DPO-rl256-v5-dataset-${dataset_index}-checkpoint-${checkpoint}" \ +# --reward_model_path "/lustre/orion/bif151/scratch/ivan.anokhin/AIF-Gen/${dataset_name}/Qwen2-0.5B-Reward-8gpus/Qwen2-0.5B-Instruct_${dataset_name}_REWARD" \ +# --wandb_project eval_${dataset_name} \ +# --learning_rate 0. \ +# --response_length 256 \ +# --dataset_name $dataset_name \ +# --per_device_eval_batch_size 16 \ +# --per_device_train_batch_size 1 \ +# --gradient_accumulation_steps 1 \ +# --bf16 \ +# --output_dir "/lustre/orion/bif151/scratch/ivan.anokhin/AIF-Gen/${dataset_name}/eval_Qwen2-0.5B-DPO-rl256-v5-8gpus-s${dataset_index}" \ +# --no_remove_unused_columns + + +#PPO - not on CPPO dataset +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ + benchmarks/parallel_eval_checkpoints.py \ + --checkpoint_dir "/scratch/s/shahradm/Qwen2-0.5B-PPO-${dataset_name}/Qwen2-0.5B-Instruct_${dataset_name}_PPO_${dataset_index}/checkpoint-${checkpoint}" \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --wandb_run_name "test_eval_Qwen2-0.5B-PPO-rl256-v1-dataset-${dataset_index}-checkpoint-${checkpoint}" \ + --reward_model_path "LifelongAlignment/Qwen2-0.5B-Instruct_${dataset_name}_REWARD" \ + --wandb_project eval_${dataset_name}_post_may_19 \ + --learning_rate 0. \ + --response_length 256 \ + --dataset_name $dataset_name \ + --per_device_eval_batch_size 32 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --bf16 \ + --output_dir "/scratch/s/shahradm/${dataset_name}/eval_Qwen2-0.5B-PPO-8gpus-rl256-v1-s${dataset_index}" \ + --no_remove_unused_columns + +# PPO - on CPPO dataset - DIFFERENT FILE + +# CPPO - not on CPPO dataset +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ + benchmarks/parallel_eval_checkpoints.py \ + --checkpoint_dir "/home/s/shahradm/links/projects/aip-rrabba/shared/aifgen_experiments/Qwen2-0.5B-CPPO-${dataset_name}/Qwen2-0.5B-Instruct_${dataset_name}_CPPO_${dataset_index}/checkpoint-${checkpoint}" \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --wandb_run_name "test_eval_Qwen2-0.5B-CPPO-rl256-v1-dataset-${dataset_index}-checkpoint-${checkpoint}" \ + --reward_model_path "LifelongAlignment/Qwen2-0.5B-Instruct_${dataset_name}_REWARD" \ + --wandb_project eval_${dataset_name}_post_may_19 \ + --learning_rate 0. 
\ + --response_length 256 \ + --dataset_name $dataset_name \ + --per_device_eval_batch_size 32 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --bf16 \ + --output_dir "/scratch/s/shahradm/${dataset_name}/eval_Qwen2-0.5B-CPPO-8gpus-rl256-v1-s${dataset_index}" \ + --no_remove_unused_columns + +# CPPO - on CPPO dataset - DIFFERENT FILE \ No newline at end of file diff --git a/jobs/schedule_eval_cppo_dataset.sh b/jobs/schedule_eval_cppo_dataset.sh new file mode 100644 index 00000000..ab8d5474 --- /dev/null +++ b/jobs/schedule_eval_cppo_dataset.sh @@ -0,0 +1,74 @@ +#!/bin/bash +#SBATCH --job-name=aif-gen-evaluation +#SBATCH --nodes=1 # Request 2 nodes +#SBATCH --gpus-per-node=h100:4 # Request 4 H100 GPUs per node +#SBATCH --ntasks-per-node=4 # One task per GPU +#SBATCH --cpus-per-task=6 +#SBATCH --mem=64G +#SBATCH --time=3:00:00 +#SBATCH --output=out/%x.%j.out # Include job name + job ID +#SBATCH --error=out/%x.%j.err # Include job name + job ID +#SBATCH --mail-type=ALL +#SBATCH --account=aip-rrabba +#SBATCH --mail-user=shahrad_m@icloud.com # Update with your email +source .env + +dataset_name=${1:-'CPPO-RL'} +dataset_index=${2:-'0'} +checkpoint=${3:-'300'} + +#DPO +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ + benchmarks/parallel_eval_checkpoints.py \ + --checkpoint_dir "/scratch/s/shahradm/projects/Qwen2-0.5B-DPO-CPPO-RL/dataset-${dataset_index}/checkpoint-${checkpoint}" \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --wandb_run_name "test_eval_Qwen2-0.5B-DPO-rl256-v5-CPPO-${dataset_index}-checkpoint-${checkpoint}" \ + --reward_model_path "LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD" \ + --wandb_project eval_${dataset_name}_post_may_19 \ + --learning_rate 0. \ + --response_length 256 \ + --dataset_name $dataset_name \ + --per_device_eval_batch_size 32 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --bf16 \ + --output_dir "/scratch/s/shahradm/${dataset_name}/eval_Qwen2-0.5B-DPO-8gpus-rl256-v1-s${dataset_index}" \ + --no_remove_unused_columns + + +#PPO +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ + benchmarks/parallel_eval_checkpoints.py \ + --checkpoint_dir "/home/s/shahradm/Qwen2-0.5B-PPO-CPPO-RL/Qwen2-0.5B-Instruct_CPPO-RL_PPO_${dataset_index}/checkpoint-${checkpoint}" \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --wandb_run_name "test_eval_Qwen2-0.5B-PPO-rl256-v5-CPPO-${dataset_index}-checkpoint-${checkpoint}" \ + --reward_model_path "LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD" \ + --wandb_project eval_${dataset_name}_post_may_19 \ + --learning_rate 0. 
\ + --response_length 256 \ + --dataset_name $dataset_name \ + --per_device_eval_batch_size 32 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --bf16 \ + --output_dir "/scratch/s/shahradm/${dataset_name}/eval_Qwen2-0.5B-PPO-8gpus-rl256-v1-s${dataset_index}" \ + --no_remove_unused_columns + + +# CPPO +accelerate launch --config_file benchmarks/dpo/accelerate_configs/deepspeed_zero3.yaml \ + benchmarks/parallel_eval_checkpoints.py \ + --checkpoint_dir "/home/s/shahradm/Qwen2-0.5B-CPPO-CPPO-RL/Qwen2-0.5B-Instruct_CPPO-RL_CPPO_${dataset_index}/checkpoint-${checkpoint}" \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --wandb_run_name "test_eval_Qwen2-0.5B-CPPO-rl256-v5-CPPO-${dataset_index}-checkpoint-${checkpoint}" \ + --reward_model_path "LifelongAlignment/Qwen2.5-0.5B-Instruct_CPPO_REWARD" \ + --wandb_project eval_${dataset_name}_post_may_19 \ + --learning_rate 0. \ + --response_length 256 \ + --dataset_name $dataset_name \ + --per_device_eval_batch_size 32 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --bf16 \ + --output_dir "/scratch/s/shahradm/${dataset_name}/eval_Qwen2-0.5B-CPPO-8gpus-rl256-v1-s${dataset_index}" \ + --no_remove_unused_columns
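For reference, the EWC recipe that continual_dpo_EWC_trainer.py implements above reduces to two pieces: a diagonal Fisher information estimate accumulated from squared gradients after a task finishes, and a quadratic penalty 0.5 * lambda * sum_i F_i * (theta_i - theta_i*)^2 added to the DPO loss on subsequent tasks. Below is a minimal, framework-free sketch of that recipe in plain PyTorch (no DeepSpeed or Accelerate), under the assumption of a single device; the helper names (estimate_diag_fisher, ewc_penalty, _batch_size) are illustrative only and do not appear in the patch.

# Illustrative sketch only -- not part of the patch above.
from typing import Callable, Dict, Iterable

import torch
import torch.nn as nn


def _batch_size(batch: dict) -> int:
    # Mirrors guess_batch_size in the patch: probe known input keys, else assume 1.
    for key in ('input_ids', 'chosen_input_ids', 'policy_input_ids'):
        if key in batch and hasattr(batch[key], 'shape'):
            return batch[key].shape[0]
    return 1


def estimate_diag_fisher(
    model: nn.Module,
    batches: Iterable[dict],
    loss_fn: Callable[[nn.Module, dict], torch.Tensor],
    num_samples: int = 120,
) -> Dict[str, torch.Tensor]:
    # Diagonal Fisher approximation: average of squared per-parameter gradients
    # over roughly num_samples examples drawn from the task just completed.
    fisher = {
        name: torch.zeros_like(p)
        for name, p in model.named_parameters()
        if p.requires_grad
    }
    seen = 0
    for batch in batches:
        if seen >= num_samples:
            break
        model.zero_grad(set_to_none=True)
        loss = loss_fn(model, batch)
        loss.backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach().pow(2)
        seen += _batch_size(batch)
    for name in fisher:
        fisher[name] /= max(seen, 1)
    return fisher


def ewc_penalty(
    model: nn.Module,
    fisher: Dict[str, torch.Tensor],
    old_params: Dict[str, torch.Tensor],
    ewc_lambda: float = 100.0,
) -> torch.Tensor:
    # 0.5 * lambda * sum_i F_i * (theta_i - theta_i*)^2 over trainable parameters,
    # added to the task loss (here, the DPO loss) on every task after the first.
    penalty = torch.tensor(0.0, device=next(model.parameters()).device)
    for name, p in model.named_parameters():
        if name in fisher and p.requires_grad:
            penalty = penalty + (fisher[name] * (p - old_params[name]).pow(2)).sum()
    return 0.5 * ewc_lambda * penalty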