From 1ce21ce91815d70f41867959207ca03424f37b89 Mon Sep 17 00:00:00 2001
From: avecplezir
Date: Thu, 13 Mar 2025 17:22:10 -0400
Subject: [PATCH 1/3] grpo init

---
 benchmarks/continual_eval_checkpoints.py  |   1 +
 benchmarks/dpo/dpo_continual.py           |   7 +-
 benchmarks/grpo/README.md                 |  47 +++++++
 benchmarks/grpo/continual_grpo_trainer.py | 143 ++++++++++++++++++++++
 benchmarks/grpo/grpo_continual.py         | 136 ++++++++++++++++++++
 5 files changed, 329 insertions(+), 5 deletions(-)
 create mode 100644 benchmarks/grpo/README.md
 create mode 100644 benchmarks/grpo/continual_grpo_trainer.py
 create mode 100644 benchmarks/grpo/grpo_continual.py

diff --git a/benchmarks/continual_eval_checkpoints.py b/benchmarks/continual_eval_checkpoints.py
index f5ffe555..d7f3307b 100644
--- a/benchmarks/continual_eval_checkpoints.py
+++ b/benchmarks/continual_eval_checkpoints.py
@@ -114,6 +114,7 @@ def main(
         training_args.output_dir = f'{output_dir}/dataset-{i}'
 
         # using ContinualDPOTrainer for all pipelines (PPO, DPO, COPR, ..) only for evaluation
+        # ToDo: train_dataset is never used here, pass a dummy dataset to make it clear
         trainer = ContinualDPOTrainer(
             args=training_args,
             processing_class=tokenizer,
diff --git a/benchmarks/dpo/dpo_continual.py b/benchmarks/dpo/dpo_continual.py
index 7abcb5b8..6e1f3639 100644
--- a/benchmarks/dpo/dpo_continual.py
+++ b/benchmarks/dpo/dpo_continual.py
@@ -1,12 +1,9 @@
 """Adaptation of the DPO TRL training script for continual learning."""
 
+import os
+
 import torch
 import wandb as wb
-from continual_dpo_trainer import (
-    ContinualDPOArguments,
-    ContinualDPOConfig,
-    ContinualDPOTrainer,
-)
 from datasets import Dataset
 from transformers import (
     AutoModelForCausalLM,
diff --git a/benchmarks/grpo/README.md b/benchmarks/grpo/README.md
new file mode 100644
index 00000000..0b025c51
--- /dev/null
+++ b/benchmarks/grpo/README.md
@@ -0,0 +1,47 @@
+# Adaptation of TRL for Continual Learning
+
+This repository adapts TRL for continual learning. The commands below use a consistent set of parameters that have been verified to work. Any of the entrypoints (`uv run`, `accelerate launch`, `wandb`) can be used with the following commands.
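+
+For multi-GPU jobs the same flags can also be passed through `accelerate launch`. A minimal sketch (the `--num_processes` value is a placeholder for your hardware and the default Accelerate config is assumed; reuse the full set of script flags from the examples below):
+
+```sh
+accelerate launch --num_processes 2 benchmarks/grpo/grpo_continual.py \
+    --dataset_name benchmarks/continual_data_debug.json \
+    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
+    --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \
+    --output_dir "$SCRATCH/Qwen2-0.5B-GRPO-test"
+```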
+ +### Sync additional dependencies + +```sh +uv sync --group benchmarks.ppo +``` + +## Run GRPO + +### Using uv run (vanilla Python with PEFT) + +```sh +uv run benchmarks/grpo/grpo_continual.py \ + --dataset_name benchmarks/continual_data_debug.json \ + --sft_model_path Qwen/Qwen2-0.5B-Instruct \ + --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \ + --learning_rate 5.0e-6 \ + --num_train_epochs 1 \ + --gradient_accumulation_steps 8 \ + --gradient_checkpointing \ + --logging_steps 20 \ + --eval_strategy steps \ + --eval_steps 20 \ + --save_steps 20 \ + --bf16 \ + --output_dir "$SCRATCH/Qwen2-0.5B-PPO-test" \ + --no_remove_unused_columns \ + --use_peft \ + --lora_r 32 \ + --lora_alpha 16 +``` + +```sh +python benchmarks/grpo/grpo_continual.py \ + --dataset_name debug \ + --dataset_train_split descriptiveness \ + --learning_rate 3e-6 \ + --output_dir models/minimal/grpo \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 1 \ + --use_peft \ + --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ + --reward_model_path /home/mila/i/ivan.anokhin/AIF-Gen/Qwen/Qwen2-0.5B-Reward/debug +``` diff --git a/benchmarks/grpo/continual_grpo_trainer.py b/benchmarks/grpo/continual_grpo_trainer.py new file mode 100644 index 00000000..8a7df633 --- /dev/null +++ b/benchmarks/grpo/continual_grpo_trainer.py @@ -0,0 +1,143 @@ +import functools +import inspect +import os +from dataclasses import dataclass, field +from typing import Optional, Union + +import torch.nn as nn +from accelerate import Accelerator +from datasets import Dataset +from transformers import ( + BaseImageProcessor, + FeatureExtractionMixin, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, +) +from trl import GRPOConfig, ScriptArguments +from trl.trainer.ppo_trainer import PPOTrainer + + +@dataclass +class GRPOScriptArguments(ScriptArguments): + """Script arguments for the GRPO training script. + + Args: + reward_model_path (`str` or `None`): + Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a + directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. + """ + + reward_model_path: Optional[str] = field( + default=None, + metadata={ + 'help': 'Reward model id of a pretrained model hosted inside a model repo on huggingface.co or ' + 'local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`.' 
+ }, + ) + dataset_name: str = field( + default='debug', + metadata={'help': 'The name or path of the continual dataset to use.'}, + ) + wandb_project: Optional[str] = field( + default='AIFGen-ppo-continual-test', + metadata={'help': 'Override the default WandB project name.'}, + ) + wandb_entity: Optional[str] = field( + default=None, + metadata={'help': 'The WandB entity (team) to use.'}, + ) + wandb_run_name: Optional[str] = field( + default=None, + metadata={'help': 'The WandB run name.'}, + ) + + def __post_init__(self) -> None: + if self.wandb_project: + os.environ['WANDB_PROJECT'] = self.wandb_project + if self.wandb_entity: + os.environ['WANDB_ENTITY'] = self.wandb_entity + + +@dataclass +class ContinualGRPOConfig(GRPOConfig): + mock: bool = field( + default=False, + metadata={'help': 'Whether to use mock dataset.'}, + ) + eval_greedy_policy: bool = field( + default=False, + metadata={'help': 'Whether to use greedy policy for evaluation.'}, + ) + + +class ContinualGRPOTrainer(PPOTrainer): + # Shared accelerator instance across all trainer instances + shared_accelerator: Optional[Accelerator] = None + accelerator: Accelerator # now non-optional after creation + + def __init__( + self, + args: Optional[ContinualGRPOConfig] = None, + processing_class: Optional[ + Union[ + PreTrainedTokenizerBase, + BaseImageProcessor, + FeatureExtractionMixin, + ProcessorMixin, + ] + ] = None, + model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + reward_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None, + train_dataset: Optional[Dataset] = None, + eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, + peft_config: Optional[dict] = None, + ): + # catching this here to test our implementation of the configs + if args is None: + raise ValueError('`args` cannot be None') + + if ContinualGRPOTrainer.shared_accelerator is None: + ContinualGRPOTrainer.shared_accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps + ) + self.accelerator = ContinualGRPOTrainer.shared_accelerator + + super().__init__( + args=args, + processing_class=processing_class, + model=model, + reward_funcs=reward_model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + peft_config=peft_config, + ) + + # No need for anything else as PPO itself is already set up with the reward model + self.accelerator = ( + ContinualGRPOTrainer.shared_accelerator + ) # turn the accelerator back to the shared one + + def create_accelerator_and_postprocess(self) -> None: + # Only initialize a new Accelerator if one does not exist + if ContinualGRPOTrainer.shared_accelerator is None: + super().create_accelerator_and_postprocess() + ContinualGRPOTrainer.shared_accelerator = self.accelerator + else: + # Reuse the shared accelerator + self.accelerator = ContinualGRPOTrainer.shared_accelerator + self.gather_function = self.accelerator.gather_for_metrics + if ( + 'use_gather_object' + in inspect.signature(self.gather_function).parameters.keys() + ): + self.gather_function = functools.partial( + self.gather_function, + use_gather_object=self.args.eval_use_gather_object, + ) + self.is_deepspeed_enabled = ( + getattr(self.accelerator.state, 'deepspeed_plugin', None) is not None + ) + self.is_fsdp_enabled = ( + getattr(self.accelerator.state, 'fsdp_plugin', None) is not None + ) diff --git a/benchmarks/grpo/grpo_continual.py b/benchmarks/grpo/grpo_continual.py new file mode 100644 index 00000000..bd7b18e0 --- /dev/null +++ b/benchmarks/grpo/grpo_continual.py @@ -0,0 +1,136 @@ +# 
Adaptation of the GRPO TRL training script for continual learning. + +import argparse +import os + +import torch +from continual_grpo_trainer import ( + ContinualGRPOConfig, + ContinualGRPOTrainer, + GRPOScriptArguments, +) +from datasets import Dataset +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, +) +from trl import ( + ModelConfig, + TrlParser, + get_kbit_device_map, + get_peft_config, + get_quantization_config, +) +from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE + +from benchmarks.dataloading import init_continual_dataset + + +# The code is based on TRL DPO script https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py +def main(script_args, training_args, model_args): + # Determine torch dtype and quantization configs + torch_dtype = ( + model_args.torch_dtype + if model_args.torch_dtype in ['auto', None] + else getattr(torch, model_args.torch_dtype) + ) + if script_args.wandb_run_name is not None: + training_args.run_name = script_args.wandb_run_name + + quantization_config = get_quantization_config(model_args) + + # Model & Tokenizer Setup + model_kwargs = dict( + revision=model_args.model_revision, + attn_implementation=model_args.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + model = AutoModelForCausalLM.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=model_args.trust_remote_code, + **model_kwargs, + ) + peft_config = get_peft_config(model_args) + + # Load tokenizer and set chat template if needed + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE + + # Initialize continual dataset + continual_dataset: list[dict[str, Dataset]] = init_continual_dataset( + script_args.dataset_name, + mock=training_args.mock, + tokenizer=tokenizer, + tools=training_args.tools, + ) + output_dir = training_args.output_dir + + # Validate reward model paths if provided + if training_args.reward_model_path is not None: + for i, _ in enumerate(continual_dataset): + reward_path = os.path.join(training_args.reward_model_path, str(i)) + if not os.path.exists(reward_path): + raise FileNotFoundError( + f'Reward model not found for dataset {i} at {reward_path}' + ) + + # Task Loop + for i, dataset in enumerate(continual_dataset): + # Dataset + # dataset = dataset[script_args.dataset_train_split] + # dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) + # train_dataset = dataset.select(range(len(dataset) - eval_samples)) + # eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + + # Reward model + reward_model = AutoModelForSequenceClassification.from_pretrained( + f'{script_args.reward_model_path}/{i}', num_labels=1 + ) + training_args.output_dir = f'{output_dir}/dataset-{i}' + + # Initialize the GRPO trainer + trainer = ContinualGRPOTrainer( + args=training_args, + processing_class=tokenizer, + model=model, + reward_model=reward_model, + train_dataset=dataset[script_args.dataset_train_split], + eval_dataset=dataset[script_args.dataset_test_split], + peft_config=peft_config, + ) + + # Train and push the model to the Hub + 
trainer.train() + + # ToDo: GRPOTrainer doesn't have a evaluate method, so we need to implement it to track the performance at each dataset + + # Save and push to hub + trainer.save_model(training_args.output_dir + f'/dataset-{i}') + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name + f'/dataset-{i}') + + +def make_parser(subparsers: argparse._SubParsersAction = None): + dataclass_types = (GRPOScriptArguments, ContinualGRPOConfig, ModelConfig) + if subparsers is not None: + parser = subparsers.add_parser( + 'grpo', help='Run the GRPO training script', dataclass_types=dataclass_types + ) + else: + parser = TrlParser(dataclass_types) + return parser + + +if __name__ == '__main__': + parser = make_parser() + script_args, training_args, model_args = parser.parse_args_and_config() + main(script_args, training_args, model_args) From d9cfdac65251e4d9e4e57614b907410d00f30107 Mon Sep 17 00:00:00 2001 From: avecplezir Date: Thu, 13 Mar 2025 21:55:49 -0400 Subject: [PATCH 2/3] grpo fixes --- benchmarks/dpo/dpo_continual.py | 2 +- benchmarks/grpo/README.md | 37 ++--- benchmarks/grpo/continual_grpo_trainer.py | 166 +++++++++++++++++++--- benchmarks/grpo/grpo_continual.py | 31 ++-- benchmarks/ppo/continual_ppo_trainer.py | 2 +- 5 files changed, 181 insertions(+), 57 deletions(-) diff --git a/benchmarks/dpo/dpo_continual.py b/benchmarks/dpo/dpo_continual.py index 6e1f3639..8f7a844b 100644 --- a/benchmarks/dpo/dpo_continual.py +++ b/benchmarks/dpo/dpo_continual.py @@ -109,7 +109,7 @@ def main( # Load reward model if path provided if training_args.reward_model_path is not None: reward_model = AutoModelForSequenceClassification.from_pretrained( - training_args.reward_model_path + f'_{str(i)}', num_labels=1 + os.path.join(training_args.reward_model_path, str(i)), num_labels=1 ) trainer = ContinualDPOTrainer( diff --git a/benchmarks/grpo/README.md b/benchmarks/grpo/README.md index 0b025c51..3bf45e0b 100644 --- a/benchmarks/grpo/README.md +++ b/benchmarks/grpo/README.md @@ -5,43 +5,28 @@ This repository adapts TRL for continual learning. 
The commands below use a cons ### Sync additional dependencies ```sh -uv sync --group benchmarks.ppo +uv sync --group benchmarks.grpo ``` ## Run GRPO ### Using uv run (vanilla Python with PEFT) -```sh -uv run benchmarks/grpo/grpo_continual.py \ - --dataset_name benchmarks/continual_data_debug.json \ - --sft_model_path Qwen/Qwen2-0.5B-Instruct \ - --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD \ - --learning_rate 5.0e-6 \ - --num_train_epochs 1 \ - --gradient_accumulation_steps 8 \ - --gradient_checkpointing \ - --logging_steps 20 \ - --eval_strategy steps \ - --eval_steps 20 \ - --save_steps 20 \ - --bf16 \ - --output_dir "$SCRATCH/Qwen2-0.5B-PPO-test" \ - --no_remove_unused_columns \ - --use_peft \ - --lora_r 32 \ - --lora_alpha 16 -``` - ```sh python benchmarks/grpo/grpo_continual.py \ --dataset_name debug \ - --dataset_train_split descriptiveness \ + --mock \ + --bf16 \ --learning_rate 3e-6 \ - --output_dir models/minimal/grpo \ - --per_device_train_batch_size 2 \ + --output_dir "$SCRATCH/Qwen2-0.5B-GRPO-test" \ + --num_train_epochs 1 \ + --per_device_train_batch_size 1 \ --gradient_accumulation_steps 1 \ + --logging_steps 20 \ + --per_device_eval_batch_size 2 \ --use_peft \ + --lora_r 32 \ + --lora_alpha 16 \ --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ - --reward_model_path /home/mila/i/ivan.anokhin/AIF-Gen/Qwen/Qwen2-0.5B-Reward/debug + --reward_model_path Shahradmz/Qwen2-0.5B-Instruct_continual_data_debug_REWARD ``` diff --git a/benchmarks/grpo/continual_grpo_trainer.py b/benchmarks/grpo/continual_grpo_trainer.py index 8a7df633..5895d4bf 100644 --- a/benchmarks/grpo/continual_grpo_trainer.py +++ b/benchmarks/grpo/continual_grpo_trainer.py @@ -1,46 +1,40 @@ import functools import inspect import os +from collections import defaultdict from dataclasses import dataclass, field from typing import Optional, Union +import numpy as np +import torch import torch.nn as nn -from accelerate import Accelerator +from accelerate import Accelerator, PartialState from datasets import Dataset +from torch.utils.data import DataLoader from transformers import ( BaseImageProcessor, + DataCollatorWithPadding, FeatureExtractionMixin, + GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, ) -from trl import GRPOConfig, ScriptArguments -from trl.trainer.ppo_trainer import PPOTrainer +from trl import GRPOConfig, GRPOTrainer, ScriptArguments +from trl.models.utils import unwrap_model_for_generation +from trl.trainer.utils import batch_generation, get_reward @dataclass class GRPOScriptArguments(ScriptArguments): - """Script arguments for the GRPO training script. + """Script arguments for the GRPO training script.""" - Args: - reward_model_path (`str` or `None`): - Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a - directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`]. - """ - - reward_model_path: Optional[str] = field( - default=None, - metadata={ - 'help': 'Reward model id of a pretrained model hosted inside a model repo on huggingface.co or ' - 'local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`.' 
- }, - ) dataset_name: str = field( default='debug', metadata={'help': 'The name or path of the continual dataset to use.'}, ) wandb_project: Optional[str] = field( - default='AIFGen-ppo-continual-test', + default='AIFGen-grpo-continual-test', metadata={'help': 'Override the default WandB project name.'}, ) wandb_entity: Optional[str] = field( @@ -61,6 +55,13 @@ def __post_init__(self) -> None: @dataclass class ContinualGRPOConfig(GRPOConfig): + reward_model_path: Optional[str] = field( + default='AIF-Gen/Qwen/Qwen2-0.5B-Reward/debug_REWARD', + metadata={ + 'help': 'Reward model id of a pretrained model hosted inside a model repo on huggingface.co or ' + 'local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`.' + }, + ) mock: bool = field( default=False, metadata={'help': 'Whether to use mock dataset.'}, @@ -69,9 +70,19 @@ class ContinualGRPOConfig(GRPOConfig): default=False, metadata={'help': 'Whether to use greedy policy for evaluation.'}, ) + dataset_num_proc: int = field( + default=1, + metadata={'help': 'Number of processes to use for dataset preprocessing.'}, + ) + response_length: int = field( + default=53, + metadata={ + 'help': 'Length of the response. Borrowed from PPOCOnfig and used only for evaluation.' + }, + ) -class ContinualGRPOTrainer(PPOTrainer): +class ContinualGRPOTrainer(GRPOTrainer): # Shared accelerator instance across all trainer instances shared_accelerator: Optional[Accelerator] = None accelerator: Accelerator # now non-optional after creation @@ -118,6 +129,24 @@ def __init__( ContinualGRPOTrainer.shared_accelerator ) # turn the accelerator back to the shared one + self.eval_policy_dataset = self.preprocess_policy_dataset(eval_dataset) + # using the same data_collator as in PPO trainer + data_collator = DataCollatorWithPadding(self.processing_class) + self.eval_policy_dataloader = DataLoader( + self.eval_policy_dataset, + batch_size=self.args.per_device_eval_batch_size, + collate_fn=data_collator, + drop_last=True, + ) # no need to shuffle eval dataset + # Ensure accelerator is available + # TODO remove the check once ruff issues are resolved + # fmt: off + assert self.accelerator is not None, 'Accelerator must be assigned before prepare()' + # fmt: on + self.eval_policy_dataloader = self.accelerator.prepare( + self.eval_policy_dataloader + ) + def create_accelerator_and_postprocess(self) -> None: # Only initialize a new Accelerator if one does not exist if ContinualGRPOTrainer.shared_accelerator is None: @@ -141,3 +170,102 @@ def create_accelerator_and_postprocess(self) -> None: self.is_fsdp_enabled = ( getattr(self.accelerator.state, 'fsdp_plugin', None) is not None ) + + def preprocess_policy_dataset(self, dataset: Dataset) -> Dataset: + # The code is from TRL PPO script https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py + dataset_text_field = 'prompt' + + def tokenize(element: dict) -> dict[str, list[int]]: + outputs = self.processing_class( + element[dataset_text_field], + padding=False, + ) + return {'input_ids': outputs['input_ids']} + + def prepare_dataset(ds: Dataset) -> Dataset: + return ds.map( + tokenize, + batched=True, + remove_columns=ds.column_names, + num_proc=self.args.dataset_num_proc, + ) + + # Compute only on main process for faster data processing. + with PartialState().local_main_process_first(): + dataset = prepare_dataset(dataset) + return dataset + + def evaluate_policy(self) -> dict: + """Evaluate the policy using the evaluation policy dataloader. 
+ + Returns: + dict: A dictionary containing evaluation metrics. + """ + # The code is heavily based on the training loop of TRL PPOTrainer function https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py#L677 + mode = self.model.training + # there is no self.model? TODO + self.model.eval() + eval_metrics = defaultdict(list) + processing_class = self.processing_class + if self.args.eval_greedy_policy: + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + top_k=None, + do_sample=False, + ) + else: + # Using the same hyperpaprams as during PPO training + generation_config = GenerationConfig( + max_new_tokens=self.args.response_length, + temperature=(self.args.temperature + 1e-7), + top_k=0.0, + top_p=1.0, + do_sample=True, + ) + + with torch.no_grad(): + if self.eval_policy_dataloader is not None: + for batch in self.eval_policy_dataloader: + query = batch['input_ids'].to(self.accelerator.device) + context_length = query.shape[1] + with unwrap_model_for_generation( + self.model, + self.accelerator, + gather_deepspeed3_params=None, + ) as unwrapped_model: + query_response, _ = batch_generation( + unwrapped_model, + query, + query.shape[0], + processing_class.pad_token_id, + generation_config, + ) + response = query_response[:, context_length:] + postprocessed_response = response + postprocessed_query_response = torch.cat( + (query, postprocessed_response), 1 + ) + _, score, _ = get_reward( + # self.reward_model, + self.reward_funcs[0], + postprocessed_query_response, + processing_class.pad_token_id, + context_length, + ) + eval_metrics['score'].extend( + self.accelerator.gather_for_metrics(score).float().cpu().numpy() + ) + self.model.train(mode) + return {'eval_' + k: float(np.mean(v)) for k, v in eval_metrics.items()} + + def log( + self, logs: dict[str, Union[float, dict]], start_time: Optional[float] = None + ) -> None: + """Log `logs` on the various objects watching training, including stored metrics.""" + train_eval = 'train' if 'loss' in logs else 'eval' + print(f'Logging {train_eval} metrics...') + if train_eval == 'eval': + print('Computing policy metrics...') + eval_policy_metrics = self.evaluate_policy() + logs.update(eval_policy_metrics) + return super().log(logs, start_time) diff --git a/benchmarks/grpo/grpo_continual.py b/benchmarks/grpo/grpo_continual.py index bd7b18e0..e885ec4f 100644 --- a/benchmarks/grpo/grpo_continual.py +++ b/benchmarks/grpo/grpo_continual.py @@ -4,6 +4,7 @@ import os import torch +import wandb as wb from continual_grpo_trainer import ( ContinualGRPOConfig, ContinualGRPOTrainer, @@ -70,7 +71,7 @@ def main(script_args, training_args, model_args): script_args.dataset_name, mock=training_args.mock, tokenizer=tokenizer, - tools=training_args.tools, + tools=None, ) output_dir = training_args.output_dir @@ -85,17 +86,13 @@ def main(script_args, training_args, model_args): # Task Loop for i, dataset in enumerate(continual_dataset): - # Dataset - # dataset = dataset[script_args.dataset_train_split] - # dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer}) - # train_dataset = dataset.select(range(len(dataset) - eval_samples)) - # eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset))) + current_dataset_name: str = f'dataset-{i}' + training_args.output_dir = f'{output_dir}/dataset-{i}' # Reward model reward_model = AutoModelForSequenceClassification.from_pretrained( - f'{script_args.reward_model_path}/{i}', num_labels=1 + f'{training_args.reward_model_path}/{i}', 
num_labels=1 ) - training_args.output_dir = f'{output_dir}/dataset-{i}' # Initialize the GRPO trainer trainer = ContinualGRPOTrainer( @@ -108,16 +105,30 @@ def main(script_args, training_args, model_args): peft_config=peft_config, ) - # Train and push the model to the Hub + # Train trainer.train() - # ToDo: GRPOTrainer doesn't have a evaluate method, so we need to implement it to track the performance at each dataset + # Evaluate + metrics = trainer.evaluate_policy() + print(f'eval/dataset/{i}') + metrics['dataset'] = i + trainer.log_metrics(f'eval/dataset/{i}', metrics) + trainer.save_metrics(f'eval', metrics) + wb.log({'eval': {'last': metrics}}) # type: ignore[attr-defined] + wb.log({f'task/{current_dataset_name}/last': metrics}) # type: ignore[attr-defined] # Save and push to hub trainer.save_model(training_args.output_dir + f'/dataset-{i}') if training_args.push_to_hub: trainer.push_to_hub(dataset_name=script_args.dataset_name + f'/dataset-{i}') + # If using DeepSpeed through Accelerate, tear down the engine after training. + if hasattr(trainer, 'deepspeed') and trainer.deepspeed is not None: + # Remove reference to the DeepSpeed engine to allow proper cleanup. + del trainer.deepspeed + # Free cached GPU memory. + torch.cuda.empty_cache() + def make_parser(subparsers: argparse._SubParsersAction = None): dataclass_types = (GRPOScriptArguments, ContinualGRPOConfig, ModelConfig) diff --git a/benchmarks/ppo/continual_ppo_trainer.py b/benchmarks/ppo/continual_ppo_trainer.py index a2d4a0a8..b4d3d758 100644 --- a/benchmarks/ppo/continual_ppo_trainer.py +++ b/benchmarks/ppo/continual_ppo_trainer.py @@ -142,7 +142,7 @@ def __init__( peft_config, ) - # No need for anything else as PPO itself is already set up with the reward model + # # No need for anything else as PPO itself is already set up with the reward model self.accelerator = ( ContinualPPOTrainer.shared_accelerator ) # turn the accelerator back to the shared one From a9a5a4a8c03c769dbc4070cea018595b2fd1bc72 Mon Sep 17 00:00:00 2001 From: avecplezir Date: Thu, 13 Mar 2025 22:01:14 -0400 Subject: [PATCH 3/3] fixes --- benchmarks/grpo/continual_grpo_trainer.py | 1 - benchmarks/grpo/grpo_continual.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmarks/grpo/continual_grpo_trainer.py b/benchmarks/grpo/continual_grpo_trainer.py index 5895d4bf..10a2864d 100644 --- a/benchmarks/grpo/continual_grpo_trainer.py +++ b/benchmarks/grpo/continual_grpo_trainer.py @@ -124,7 +124,6 @@ def __init__( peft_config=peft_config, ) - # No need for anything else as PPO itself is already set up with the reward model self.accelerator = ( ContinualGRPOTrainer.shared_accelerator ) # turn the accelerator back to the shared one diff --git a/benchmarks/grpo/grpo_continual.py b/benchmarks/grpo/grpo_continual.py index e885ec4f..1713bc21 100644 --- a/benchmarks/grpo/grpo_continual.py +++ b/benchmarks/grpo/grpo_continual.py @@ -28,7 +28,7 @@ from benchmarks.dataloading import init_continual_dataset -# The code is based on TRL DPO script https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py +# The code is based on TRL GRPO script https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py def main(script_args, training_args, model_args): # Determine torch dtype and quantization configs torch_dtype = (
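A note on the shared `Accelerator` handling in `ContinualGRPOTrainer` (and `ContinualPPOTrainer`): the class keeps a single class-level `Accelerator`, and every trainer built for a later task reuses it instead of re-initialising distributed state. Below is a standalone toy sketch of that pattern only, not the trainers themselves; it assumes nothing beyond `accelerate` being installed.

```python
# Toy illustration of the shared-Accelerator pattern used by the continual
# trainers: the first instance creates the Accelerator, later instances reuse it.
from typing import Optional

from accelerate import Accelerator


class SharedAcceleratorMixin:
    # Class-level cache so sequential per-task trainers share one Accelerator.
    shared_accelerator: Optional[Accelerator] = None

    def __init__(self, gradient_accumulation_steps: int = 1) -> None:
        if SharedAcceleratorMixin.shared_accelerator is None:
            SharedAcceleratorMixin.shared_accelerator = Accelerator(
                gradient_accumulation_steps=gradient_accumulation_steps
            )
        # Later instances silently reuse the existing Accelerator.
        self.accelerator = SharedAcceleratorMixin.shared_accelerator


if __name__ == '__main__':
    first = SharedAcceleratorMixin(gradient_accumulation_steps=8)
    second = SharedAcceleratorMixin(gradient_accumulation_steps=8)
    # Both "tasks" see the exact same Accelerator object.
    assert first.accelerator is second.accelerator
    print(first.accelerator.device)
```

This is the same reason `create_accelerator_and_postprocess` in the patch only builds a new `Accelerator` when the shared one is missing, and otherwise rewires `gather_function`, `is_deepspeed_enabled`, and `is_fsdp_enabled` from the existing instance.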