diff --git a/benchmarks/continual_eval_checkpoints.py b/benchmarks/continual_eval_checkpoints.py
index f5ffe555..d7f3307b 100644
--- a/benchmarks/continual_eval_checkpoints.py
+++ b/benchmarks/continual_eval_checkpoints.py
@@ -114,6 +114,7 @@ def main(
         training_args.output_dir = f'{output_dir}/dataset-{i}'
 
         # using ContinualDPOTrainer for all pipelines (PPO, DPO, COPR, ..) only for evaluation
+        # TODO: train_dataset is never used here; pass a dummy dataset to make that explicit
         trainer = ContinualDPOTrainer(
             args=training_args,
             processing_class=tokenizer,
diff --git a/benchmarks/dpo/dpo_continual.py b/benchmarks/dpo/dpo_continual.py
index 7abcb5b8..8f7a844b 100644
--- a/benchmarks/dpo/dpo_continual.py
+++ b/benchmarks/dpo/dpo_continual.py
@@ -1,12 +1,9 @@
 """Adaptation of the DPO TRL training script for continual learning."""
 
+import os
+
 import torch
 import wandb as wb
-from continual_dpo_trainer import (
-    ContinualDPOArguments,
-    ContinualDPOConfig,
-    ContinualDPOTrainer,
-)
 from datasets import Dataset
 from transformers import (
     AutoModelForCausalLM,
@@ -112,7 +109,7 @@ def main(
         # Load reward model if path provided
         if training_args.reward_model_path is not None:
             reward_model = AutoModelForSequenceClassification.from_pretrained(
-                training_args.reward_model_path + f'_{str(i)}', num_labels=1
+                os.path.join(training_args.reward_model_path, str(i)), num_labels=1
             )
 
         trainer = ContinualDPOTrainer(
diff --git a/benchmarks/grpo/README.md b/benchmarks/grpo/README.md
new file mode 100644
index 00000000..3bf45e0b
--- /dev/null
+++ b/benchmarks/grpo/README.md
@@ -0,0 +1,32 @@
+# Adaptation of TRL for Continual Learning
+
+This repository adapts TRL for continual learning. The commands below use a consistent set of parameters that are known to work. You can run them with any of the supported entrypoints (`uv run`, `accelerate launch`, wandb).
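+
+For example, assuming `accelerate` has already been configured with `accelerate config`, a multi-GPU launch of the same script might look like the sketch below (illustrative only; the process count is arbitrary and the flags should mirror the `uv run` example further down):
+
+```sh
+accelerate launch --num_processes 2 benchmarks/grpo/grpo_continual.py \
+    --dataset_name debug \
+    --output_dir "$SCRATCH/Qwen2-0.5B-GRPO-test" \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD
+```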
+
+### Sync additional dependencies
+
+```sh
+uv sync --group benchmarks.grpo
+```
+
+## Run GRPO
+
+### Using `uv run` or vanilla Python (with PEFT)
+
+```sh
+python benchmarks/grpo/grpo_continual.py \
+    --dataset_name debug \
+    --mock \
+    --bf16 \
+    --learning_rate 3e-6 \
+    --output_dir "$SCRATCH/Qwen2-0.5B-GRPO-test" \
+    --num_train_epochs 1 \
+    --per_device_train_batch_size 1 \
+    --gradient_accumulation_steps 1 \
+    --logging_steps 20 \
+    --per_device_eval_batch_size 2 \
+    --use_peft \
+    --lora_r 32 \
+    --lora_alpha 16 \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD
+```
diff --git a/benchmarks/grpo/continual_grpo_trainer.py b/benchmarks/grpo/continual_grpo_trainer.py
new file mode 100644
index 00000000..10a2864d
--- /dev/null
+++ b/benchmarks/grpo/continual_grpo_trainer.py
@@ -0,0 +1,270 @@
+import functools
+import inspect
+import os
+from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Optional, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+from accelerate import Accelerator, PartialState
+from datasets import Dataset
+from torch.utils.data import DataLoader
+from transformers import (
+    BaseImageProcessor,
+    DataCollatorWithPadding,
+    FeatureExtractionMixin,
+    GenerationConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase,
+    ProcessorMixin,
+)
+from trl import GRPOConfig, GRPOTrainer, ScriptArguments
+from trl.models.utils import unwrap_model_for_generation
+from trl.trainer.utils import batch_generation, get_reward
+
+
+@dataclass
+class GRPOScriptArguments(ScriptArguments):
+    """Script arguments for the GRPO training script."""
+
+    dataset_name: str = field(
+        default='debug',
+        metadata={'help': 'The name or path of the continual dataset to use.'},
+    )
+    wandb_project: Optional[str] = field(
+        default='AIFGen-grpo-continual-test',
+        metadata={'help': 'Override the default WandB project name.'},
+    )
+    wandb_entity: Optional[str] = field(
+        default=None,
+        metadata={'help': 'The WandB entity (team) to use.'},
+    )
+    wandb_run_name: Optional[str] = field(
+        default=None,
+        metadata={'help': 'The WandB run name.'},
+    )
+
+    def __post_init__(self) -> None:
+        if self.wandb_project:
+            os.environ['WANDB_PROJECT'] = self.wandb_project
+        if self.wandb_entity:
+            os.environ['WANDB_ENTITY'] = self.wandb_entity
+
+
+@dataclass
+class ContinualGRPOConfig(GRPOConfig):
+    reward_model_path: Optional[str] = field(
+        default='AIF-Gen/Qwen/Qwen2-0.5B-Reward/debug_REWARD',
+        metadata={
+            'help': 'Reward model id of a pretrained model hosted inside a model repo on huggingface.co or '
+            'local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`.'
+        },
+    )
+    mock: bool = field(
+        default=False,
+        metadata={'help': 'Whether to use a mock dataset.'},
+    )
+    eval_greedy_policy: bool = field(
+        default=False,
+        metadata={'help': 'Whether to use a greedy policy for evaluation.'},
+    )
+    dataset_num_proc: int = field(
+        default=1,
+        metadata={'help': 'Number of processes to use for dataset preprocessing.'},
+    )
+    response_length: int = field(
+        default=53,
+        metadata={
+            'help': 'Length of the response. Borrowed from PPOConfig and used only for evaluation.'
+ }, + ) + + +class ContinualGRPOTrainer(GRPOTrainer): + # Shared accelerator instance across all trainer instances + shared_accelerator: Optional[Accelerator] = None + accelerator: Accelerator # now non-optional after creation + + def __init__( + self, + args: Optional[ContinualGRPOConfig] = None, + processing_class: Optional[ + Union[ + PreTrainedTokenizerBase, + BaseImageProcessor, + FeatureExtractionMixin, + ProcessorMixin, + ] + ] = None, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + reward_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + peft_config: Optional[dict] = None, + ): + # catching this here to test our implementation of the configs + if args is None: + raise ValueError('`args` cannot be None') + + if ContinualGRPOTrainer.shared_accelerator is None: + ContinualGRPOTrainer.shared_accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps + ) + self.accelerator = ContinualGRPOTrainer.shared_accelerator + + super().__init__( + args=args, + processing_class=processing_class, + model=model, + reward_funcs=reward_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + ) + + self.accelerator = ( + ContinualGRPOTrainer.shared_accelerator + ) # turn the accelerator back to the shared one + + self.eval_policy_dataset = self.preprocess_policy_dataset(eval_dataset) + # using the same data_collator as in PPO trainer + data_collator = DataCollatorWithPadding(self.processing_class) + self.eval_policy_dataloader = DataLoader( + self.eval_policy_dataset, + batch_size=self.args.per_device_eval_batch_size, + collate_fn=data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + # Ensure accelerator is available + # TODO remove the check once ruff issues are resolved + # fmt: off + assert self.accelerator is not None, 'Accelerator must be assigned before prepare()' + # fmt: on + self.eval_policy_dataloader = self.accelerator.prepare( + self.eval_policy_dataloader + ) + + def create_accelerator_and_postprocess(self) -> None: + # Only initialize a new Accelerator if one does not exist + if ContinualGRPOTrainer.shared_accelerator is None: + super().create_accelerator_and_postprocess() + ContinualGRPOTrainer.shared_accelerator = self.accelerator + else: + # Reuse the shared accelerator + self.accelerator = ContinualGRPOTrainer.shared_accelerator + self.gather_function = self.accelerator.gather_for_metrics + if ( + 'use_gather_object' + in inspect.signature(self.gather_function).parameters.keys() + ): + self.gather_function = functools.partial( + self.gather_function, + use_gather_object=self.args.eval_use_gather_object, + ) + self.is_deepspeed_enabled = ( + getattr(self.accelerator.state, 'deepspeed_plugin', None) is not None + ) + self.is_fsdp_enabled = ( + getattr(self.accelerator.state, 'fsdp_plugin', None) is not None + ) + + def preprocess_policy_dataset(self, dataset: Dataset) -> Dataset: + # The code is from TRL PPO script https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py + dataset_text_field = 'prompt' + + def tokenize(element: dict) -> dict[str, list[int]]: + outputs = self.processing_class( + element[dataset_text_field], + padding=False, + ) + return {'input_ids': outputs['input_ids']} + + def prepare_dataset(ds: Dataset) -> Dataset: + return ds.map( + tokenize, + batched=True, + remove_columns=ds.column_names, + 
+                num_proc=self.args.dataset_num_proc,
+            )
+
+        # Compute only on main process for faster data processing.
+        with PartialState().local_main_process_first():
+            dataset = prepare_dataset(dataset)
+        return dataset
+
+    def evaluate_policy(self) -> dict:
+        """Evaluate the policy using the evaluation policy dataloader.
+
+        Returns:
+            dict: A dictionary containing evaluation metrics.
+        """
+        # The code is heavily based on the training loop of TRL's PPOTrainer: https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L677
+        mode = self.model.training
+        # TODO: confirm that self.model is the policy module we want to evaluate here
+        self.model.eval()
+        eval_metrics = defaultdict(list)
+        processing_class = self.processing_class
+        if self.args.eval_greedy_policy:
+            generation_config = GenerationConfig(
+                max_new_tokens=self.args.response_length,
+                top_k=None,
+                do_sample=False,
+            )
+        else:
+            # Using the same hyperparameters as during PPO training
+            generation_config = GenerationConfig(
+                max_new_tokens=self.args.response_length,
+                temperature=(self.args.temperature + 1e-7),
+                top_k=0.0,
+                top_p=1.0,
+                do_sample=True,
+            )
+
+        with torch.no_grad():
+            if self.eval_policy_dataloader is not None:
+                for batch in self.eval_policy_dataloader:
+                    query = batch['input_ids'].to(self.accelerator.device)
+                    context_length = query.shape[1]
+                    with unwrap_model_for_generation(
+                        self.model,
+                        self.accelerator,
+                        gather_deepspeed3_params=None,
+                    ) as unwrapped_model:
+                        query_response, _ = batch_generation(
+                            unwrapped_model,
+                            query,
+                            query.shape[0],
+                            processing_class.pad_token_id,
+                            generation_config,
+                        )
+                    response = query_response[:, context_length:]
+                    postprocessed_response = response
+                    postprocessed_query_response = torch.cat(
+                        (query, postprocessed_response), 1
+                    )
+                    _, score, _ = get_reward(
+                        # self.reward_model,
+                        self.reward_funcs[0],
+                        postprocessed_query_response,
+                        processing_class.pad_token_id,
+                        context_length,
+                    )
+                    eval_metrics['score'].extend(
+                        self.accelerator.gather_for_metrics(score).float().cpu().numpy()
+                    )
+        self.model.train(mode)
+        return {'eval_' + k: float(np.mean(v)) for k, v in eval_metrics.items()}
+
+    def log(
+        self, logs: dict[str, Union[float, dict]], start_time: Optional[float] = None
+    ) -> None:
+        """Log `logs` on the various objects watching training, including stored metrics."""
+        train_eval = 'train' if 'loss' in logs else 'eval'
+        print(f'Logging {train_eval} metrics...')
+        if train_eval == 'eval':
+            print('Computing policy metrics...')
+            eval_policy_metrics = self.evaluate_policy()
+            logs.update(eval_policy_metrics)
+        return super().log(logs, start_time)
diff --git a/benchmarks/grpo/grpo_continual.py b/benchmarks/grpo/grpo_continual.py
new file mode 100644
index 00000000..1713bc21
--- /dev/null
+++ b/benchmarks/grpo/grpo_continual.py
@@ -0,0 +1,147 @@
+"""Adaptation of the GRPO TRL training script for continual learning."""
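+# The script loads the policy model and tokenizer once, builds the continual
+# dataset, and then loops over tasks: for each task it loads the matching reward
+# model, trains with ContinualGRPOTrainer, evaluates the policy, logs the metrics
+# to WandB, and saves the checkpoint before moving on.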
+
+import argparse
+import os
+
+import torch
+import wandb as wb
+from continual_grpo_trainer import (
+    ContinualGRPOConfig,
+    ContinualGRPOTrainer,
+    GRPOScriptArguments,
+)
+from datasets import Dataset
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
+from trl import (
+    ModelConfig,
+    TrlParser,
+    get_kbit_device_map,
+    get_peft_config,
+    get_quantization_config,
+)
+from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
+
+from benchmarks.dataloading import init_continual_dataset
+
+
+# The code is based on the TRL GRPO script https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py
+def main(script_args, training_args, model_args):
+    # Determine torch dtype and quantization configs
+    torch_dtype = (
+        model_args.torch_dtype
+        if model_args.torch_dtype in ['auto', None]
+        else getattr(torch, model_args.torch_dtype)
+    )
+    if script_args.wandb_run_name is not None:
+        training_args.run_name = script_args.wandb_run_name
+
+    quantization_config = get_quantization_config(model_args)
+
+    # Model & Tokenizer Setup
+    model_kwargs = dict(
+        revision=model_args.model_revision,
+        attn_implementation=model_args.attn_implementation,
+        torch_dtype=torch_dtype,
+        use_cache=False if training_args.gradient_checkpointing else True,
+        device_map=get_kbit_device_map() if quantization_config is not None else None,
+        quantization_config=quantization_config,
+    )
+    model = AutoModelForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        trust_remote_code=model_args.trust_remote_code,
+        **model_kwargs,
+    )
+    peft_config = get_peft_config(model_args)
+
+    # Load tokenizer and set chat template if needed
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+    )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    if tokenizer.chat_template is None:
+        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
+
+    # Initialize continual dataset
+    continual_dataset: list[dict[str, Dataset]] = init_continual_dataset(
+        script_args.dataset_name,
+        mock=training_args.mock,
+        tokenizer=tokenizer,
+        tools=None,
+    )
+    output_dir = training_args.output_dir
+
+    # Validate reward model paths if provided
+    if training_args.reward_model_path is not None:
+        for i, _ in enumerate(continual_dataset):
+            reward_path = os.path.join(training_args.reward_model_path, str(i))
+            if not os.path.exists(reward_path):
+                raise FileNotFoundError(
+                    f'Reward model not found for dataset {i} at {reward_path}'
+                )
+
+    # Task Loop
+    for i, dataset in enumerate(continual_dataset):
+        current_dataset_name: str = f'dataset-{i}'
+        training_args.output_dir = f'{output_dir}/dataset-{i}'
+
+        # Reward model
+        reward_model = AutoModelForSequenceClassification.from_pretrained(
+            os.path.join(training_args.reward_model_path, str(i)), num_labels=1
+        )
+
+        # Initialize the GRPO trainer
+        trainer = ContinualGRPOTrainer(
+            args=training_args,
+            processing_class=tokenizer,
+            model=model,
+            reward_model=reward_model,
+            train_dataset=dataset[script_args.dataset_train_split],
+            eval_dataset=dataset[script_args.dataset_test_split],
+            peft_config=peft_config,
+        )
+
+        # Train
+        trainer.train()
+
+        # Evaluate
+        metrics = trainer.evaluate_policy()
+        print(f'eval/dataset/{i}')
+        metrics['dataset'] = i
+        trainer.log_metrics(f'eval/dataset/{i}', metrics)
+        trainer.save_metrics('eval', metrics)
+        wb.log({'eval': {'last': metrics}})  # type: ignore[attr-defined]
+        wb.log({f'task/{current_dataset_name}/last': metrics})  # type: ignore[attr-defined]
ignore[attr-defined] + + # Save and push to hub + trainer.save_model(training_args.output_dir + f'/dataset-{i}') + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name + f'/dataset-{i}') + + # If using DeepSpeed through Accelerate, tear down the engine after training. + if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None: + # Remove reference to the DeepSpeed engine to allow proper cleanup. + del trainer.deepspeed + # Free cached GPU memory. + torch.cuda.empty_cache() + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (GRPOScriptArguments, ContinualGRPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser( + 'grpo', help='Run the GRPO training script', dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == '__main__': + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) diff --git a/benchmarks/ppo/continual_ppo_trainer.py b/benchmarks/ppo/continual_ppo_trainer.py index a2d4a0a8..b4d3d758 100644 --- a/benchmarks/ppo/continual_ppo_trainer.py +++ b/benchmarks/ppo/continual_ppo_trainer.py @@ -142,7 +142,7 @@ def __init__( peft_config, ) - # No need for anything else as PPO itself is already set up with the reward model + # # No need for anything else as PPO itself is already set up with the reward model self.accelerator = ( ContinualPPOTrainer.shared_accelerator ) # turn the accelerator back to the shared one