From a34219c410b88cf04ad4ac804bd1f6af28ef8c76 Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Thu, 7 Sep 2023 15:03:01 +0000 Subject: [PATCH 01/11] Initial commit DPO support --- trlx/pipeline/offline_pipeline.py | 110 +++++++++++ trlx/trainer/accelerate_dpo_trainer.py | 241 +++++++++++++++++++++++++ 2 files changed, 351 insertions(+) create mode 100644 trlx/trainer/accelerate_dpo_trainer.py diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index cee900cfc..1974dacde 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -277,3 +277,113 @@ def create_loader(self, batch_size: int): collate_fn=ilql_seq2seq_collate_fn, drop_last=torch.distributed.is_initialized(), ) + + +@dataclass +class DPOPreferences: + prompt_tokens: Tuple + chosen_tokens: Tuple + rejected_tokens: Tuple + + +class DPOStore(BaseRolloutStore): + # Adapted from TRL + def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedTokenizer): + super().__init__() + self.tokenizer = tokenizer + + self.history = [ + self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences + ] + + @staticmethod + def tokenize_preferences(samples, tokenizer, max_length=2048): + chosen_tokens = tokenizer(samples[0], add_special_tokens=False) + rejected_tokens = tokenizer(samples[1], add_special_tokens=False) + prompt_tokens = tokenizer(samples[2], add_special_tokens=False) + + chosen_tokens["input_ids"].append(tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt only + if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: + if tokenizer.truncation_side == "right": + prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} + elif tokenizer.truncation_side == "left": + prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} + + # if that's still too long, truncate the response + if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: + chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()} + rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()} + + return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) + + def _build_batch_from_preference_tokens(self, preference_tokens: DPOPreferences): + # Create labels + chosen_sequence_tokens = { + k: preference_tokens.prompt_tokens[k] + preference_tokens.chosen_tokens[k] + for k in preference_tokens.chosen_tokens + } + rejected_sequence_tokens = { + k: preference_tokens.prompt_tokens[k] + preference_tokens.rejected_tokens[k] + for k in preference_tokens.rejected_tokens + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [ + self.label_pad_token_id + ] * len(preference_tokens.prompt_tokens["input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [ + self.label_pad_token_id + ] * len(preference_tokens.prompt_tokens["input_ids"]) + + batch = {} 
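+        # Flatten the prompt/chosen/rejected encodings into per-field keys such as
+        # "chosen_input_ids" and "rejected_labels" so collate_fn can pad each field;
+        # token_type_ids are skipped below since causal LM scoring does not use them.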
+ + for k, toks in { + "chosen": chosen_sequence_tokens, + "rejected": rejected_sequence_tokens, + "prompt": preference_tokens.prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}_{type_key}"] = tokens + + return batch + + def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: + def collate_fn(batch): + # first, pad everything to the same length + padded_batch = {} + for k in batch[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + # adapted from https://stackoverflow.com/questions/73256206 + if "prompt" in k: + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + else: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + if k.endswith("_input_ids"): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = self.padding_value + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # for the prompt, flip back so padding is on left side + if "prompt" in k: + padded_batch[k] = padded_batch[k].flip(dims=[1]) + else: + padded_batch[k] = [ex[k] for ex in batch] + + return padded_batch + + return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py new file mode 100644 index 000000000..d4613d8b9 --- /dev/null +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -0,0 +1,241 @@ +from dataclasses import dataclass +from typing import Dict, List, Tuple, Union + +import torch +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, PretrainedConfig + +from trlx.data.configs import TRLConfig +from trlx.data.method_configs import MethodConfig, register_method +from trlx.pipeline.offline_pipeline import DPOStore +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer + + +@dataclass +@register_method +class DPOConfig(MethodConfig): + """ + Config for DPO training + + :param gen_kwargs: kwargs for generation + :type gen_kwargs: Dict[str, Any] + """ + + gen_kwargs: dict + + +@register_trainer +class AccelerateDPOTrainer(AccelerateRLTrainer): + def __init__(self, config: TRLConfig, **kwargs): + super().__init__(config, **kwargs) + + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + def get_arch(self, config): + from_fn = AutoModelForCausalLM.from_pretrained + if issubclass(type(config.model.model_path), PretrainedConfig): + from_fn = AutoModelForCausalLM.from_config + + model = from_fn(config.model.model_path) + + if config.model.peft_config is not None: + # Initialize the peft adapter + import peft + + peft_config = config.model.peft_config + if not isinstance(peft_config, peft.PeftConfig): + if isinstance(peft_config, dict): + peft_config = peft.get_peft_config(peft_config) + else: + raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") + model = peft.get_peft_model(model, peft_config) + if self.accelerator.is_main_process: + model.print_trainable_parameters() + + return model + + def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and 
rejected inputs into a single tensor. + Args: + batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of + shape (batch_size, sequence_length). + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + concatenated_batch = {} + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k else self.padding_value + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k else self.padding_value + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ) + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_free: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape:(batch_size,) + beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the + reference model as beta -> 0. + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns + equal probability to all responses. + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, + respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = 0 + + logits = pi_logratios - ref_logratios + + losses = -F.logsigmoid(self.beta * logits) + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. 
Label tokens with a value of label_pad_token_id are + ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum + of the log probabilities of the (non-masked) tokens. + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given + logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == self.label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + concatenated_batch = self.concatenated_inputs(batch) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + ).logits.to(torch.float32) + all_logps = self._get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + ) + chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]] + rejected_logps = all_logps[batch["chosen_input_ids"].shape[0] :] + + chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]] + rejected_logits = all_logits[batch["chosen_input_ids"].shape[0] :] + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def loss(self, batch): + stats = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self.model, batch) + with torch.no_grad(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + stats["rewards/chosen"] = chosen_rewards.cpu().numpy().mean() + stats["rewards/rejected"] = rejected_rewards.cpu().numpy().mean() + stats["rewards/accuracies"] = reward_accuracies.cpu().numpy().mean() + stats["rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean() + stats["logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean() + stats["logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean() + + stats["logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean() + stats["logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean() + + stats["loss"] = losses.detach().cpu().numpy().mean() + + return losses.mean(), stats + + def prepare_learning(self): + train_dataloader = self.store.create_loader(self.config.train.batch_size) + eval_dataloader = 
self.eval_pipeline.create_loader(self.config.train.batch_size) + + ( + self.model, + self.opt, + self.train_dataloader, + self.eval_dataloader, + ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) + + self.n_updates_per_batch = 1 + self.total_steps = self.config.train.epochs * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def make_experience(self, samples, seq_length): + preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples] + self.store = DPOStore(preferences, self.tokenizer) From cd923c11e8142aa22158e9d8f02059cf2b0cafea Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Fri, 8 Sep 2023 14:59:36 +0000 Subject: [PATCH 02/11] Add default config for DPO and trainer functionality --- trlx/data/default_configs.py | 27 ++++++++++++++++++++++++++ trlx/trainer/accelerate_dpo_trainer.py | 12 ++++++++++++ trlx/utils/modeling.py | 16 +++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 5277d7010..22255d4b8 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ -2,6 +2,7 @@ from trlx.models.modeling_ilql import ILQLConfig from trlx.models.modeling_ppo import PPOConfig +from trlx.trainer.accelerate_dpo_trainer import DPOConfig from trlx.trainer.accelerate_sft_trainer import SFTConfig from .configs import ( @@ -146,3 +147,29 @@ def default_nemo_1_3b_config(): here = Path(__file__).parent return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_1.3b.yaml") + + +def default_dpo_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=8, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="DPOTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps + ), + method=DPOConfig( + name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), beta=0.1 + ), + ) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index d4613d8b9..6e895ab18 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -2,6 +2,7 @@ from typing import Dict, List, Tuple, Union import torch +import torch.nn as nn import torch.nn.functional as F from transformers import AutoModelForCausalLM, PretrainedConfig @@ -10,6 +11,7 @@ from trlx.pipeline.offline_pipeline import DPOStore from trlx.trainer import register_trainer from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils.modeling import pad_to_length @dataclass @@ -23,6 +25,7 @@ class DPOConfig(MethodConfig): """ gen_kwargs: dict + beta: float = 0.1 @register_trainer @@ -30,12 +33,21 @@ class AccelerateDPOTrainer(AccelerateRLTrainer): def __init__(self, config: TRLConfig, **kwargs): super().__init__(config, **kwargs) + # Set up a reference model when hydra heads are not used + if not hasattr(self.model, "frozen_head") and not self.model.peft_type: + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) + 
self.ref_model.eval() + self.generate_kwargs = dict( config.method.gen_kwargs, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id, ) + # `beta` corresponding to the DPO hyperparameter + self.beta = config.method.beta + def get_arch(self, config): from_fn = AutoModelForCausalLM.from_pretrained if issubclass(type(config.model.model_path), PretrainedConfig): diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 47688f553..e5de3d990 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -265,6 +265,22 @@ def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int): ) +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + # From original TRL + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + class RunningMoments: def __init__(self): """ From deb71c195b656b34e0d701298536ce803451735b Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Tue, 12 Sep 2023 15:30:15 +0000 Subject: [PATCH 03/11] Add DPO training example and fix minor bugs --- examples/hh/dpo_hh.py | 99 ++++++++++++++++++++++++++ trlx/pipeline/__init__.py | 4 +- trlx/pipeline/offline_pipeline.py | 24 ++++--- trlx/trainer/accelerate_dpo_trainer.py | 17 +++-- trlx/trlx.py | 4 +- trlx/utils/loading.py | 1 + 6 files changed, 130 insertions(+), 19 deletions(-) create mode 100644 examples/hh/dpo_hh.py diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py new file mode 100644 index 000000000..acddecc82 --- /dev/null +++ b/examples/hh/dpo_hh.py @@ -0,0 +1,99 @@ +import json +import sys +from collections import defaultdict + +import tqdm +from datasets import Dataset, load_dataset + +import trlx +from trlx.data.default_configs import ( + DPOConfig, + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + +default_config = TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=1, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateDPOTrainer", + checkpoint_dir="checkpoints/dpo_hh", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps + method=DPOConfig( + name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), beta=0.1 + ), +) + + +def get_hh(split: str, sanity_check=False, silent=False): + dataset = load_dataset("Anthropic/hh-rlhf", split=split) + if sanity_check: + dataset = dataset.select(range(min(len(dataset), 1000))) + + def extract_anthropic_prompt(prompt_and_response): + """Extract the anthropic prompt from a prompt and response pair.""" + search_term = "\n\nAssistant:" + search_term_idx = prompt_and_response.rfind(search_term) + assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" + return prompt_and_response[: search_term_idx + len(search_term)] + + def split_prompt_and_responses(ex): + prompt = extract_anthropic_prompt(ex["chosen"]) + chosen_response = 
ex["chosen"][len(prompt) :] + rejected_response = ex["rejected"][len(prompt) :] + return prompt, chosen_response, rejected_response + + data = defaultdict(lambda: defaultdict(list)) + for row in tqdm.tqdm(dataset, desc="Processing HH", disable=silent): + prompt, chosen, rejected = split_prompt_and_responses(row) + responses = [chosen, rejected] + n_responses = len(data[prompt]["responses"]) + data[prompt]["pairs"].append((n_responses, n_responses + 1)) + data[prompt]["responses"].extend(responses) + data[prompt]["sft_target"] = chosen + + def gen(): + for prompt, values in data.items(): + yield { + "prompt": prompt, + "responses": values["responses"], + "pairs": values["pairs"], + } + + return Dataset.from_generator(gen) + + +def preprocess(sample): + pass + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess) + + trlx.train( + config=config, + samples=dataset["train"], + eval_prompts=dataset["test"]["prompt"][:280], + # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, + stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/pipeline/__init__.py b/trlx/pipeline/__init__.py index c7dba9e97..7e927cd49 100644 --- a/trlx/pipeline/__init__.py +++ b/trlx/pipeline/__init__.py @@ -166,8 +166,8 @@ def __next__(self): # noqa: C901 minibatch = BatchEncoding(sliced_data) elif is_dataclass(batch): minibatch = batch.__class__(**sliced_data) - # else: - # minibatch = sliced_data + else: + minibatch = sliced_data minibatches.append(minibatch) diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index 1974dacde..0d0256ffd 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -288,9 +288,17 @@ class DPOPreferences: class DPOStore(BaseRolloutStore): # Adapted from TRL - def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedTokenizer): + def __init__( + self, + preferences: List[DPOPreferences], + tokenizer: PreTrainedTokenizer, + label_pad_token_id: int, + padding_value: int, + ): super().__init__() self.tokenizer = tokenizer + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value self.history = [ self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences @@ -298,9 +306,9 @@ def __init__(self, preferences: List[DPOPreferences], tokenizer: PreTrainedToken @staticmethod def tokenize_preferences(samples, tokenizer, max_length=2048): - chosen_tokens = tokenizer(samples[0], add_special_tokens=False) - rejected_tokens = tokenizer(samples[1], add_special_tokens=False) - prompt_tokens = tokenizer(samples[2], add_special_tokens=False) + chosen_tokens = tokenizer(samples["chosen"], add_special_tokens=False) + rejected_tokens = tokenizer(samples["rejected"], add_special_tokens=False) + prompt_tokens = tokenizer(samples["prompt"], add_special_tokens=False) chosen_tokens["input_ids"].append(tokenizer.eos_token_id) chosen_tokens["attention_mask"].append(1) @@ -313,14 +321,14 @@ def tokenize_preferences(samples, tokenizer, max_length=2048): # if combined sequence is too long, truncate the prompt only if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: if tokenizer.truncation_side == "right": - prompt_tokens = {k: v[: self.max_prompt_length] for k, v in prompt_tokens.items()} + prompt_tokens = {k: 
v[:max_length] for k, v in prompt_tokens.items()} elif tokenizer.truncation_side == "left": - prompt_tokens = {k: v[-self.max_prompt_length :] for k, v in prompt_tokens.items()} + prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()} # if that's still too long, truncate the response if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: - chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()} - rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()} + chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()} + rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()} return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index 6e895ab18..2ae2b1750 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -25,7 +25,9 @@ class DPOConfig(MethodConfig): """ gen_kwargs: dict - beta: float = 0.1 + beta: float = 0.1 # Beta value for DPO loss calculation + label_pad_token_id: int = -100 # -100 is ignore token for CELoss + padding_value: int = 0 @register_trainer @@ -33,11 +35,10 @@ class AccelerateDPOTrainer(AccelerateRLTrainer): def __init__(self, config: TRLConfig, **kwargs): super().__init__(config, **kwargs) - # Set up a reference model when hydra heads are not used - if not hasattr(self.model, "frozen_head") and not self.model.peft_type: - self.ref_model = self.get_arch(self.config) - self.ref_model.to(self.accelerator.device) - self.ref_model.eval() + # TODO: Avoid setting up a reference model when hydra heads are used + self.ref_model = self.get_arch(self.config) + self.ref_model.to(self.accelerator.device) + self.ref_model.eval() self.generate_kwargs = dict( config.method.gen_kwargs, @@ -47,6 +48,8 @@ def __init__(self, config: TRLConfig, **kwargs): # `beta` corresponding to the DPO hyperparameter self.beta = config.method.beta + self.label_pad_token_id = config.method.label_pad_token_id + self.padding_value = config.method.padding_value def get_arch(self, config): from_fn = AutoModelForCausalLM.from_pretrained @@ -250,4 +253,4 @@ def prepare_learning(self): def make_experience(self, samples, seq_length): preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples] - self.store = DPOStore(preferences, self.tokenizer) + self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value) diff --git a/trlx/trlx.py b/trlx/trlx.py index 7fbce94f4..6d98019fd 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -64,7 +64,7 @@ def train( # noqa: C901 config = default_ppo_config() elif rewards: config = default_ilql_config() - else: + else: # Alternatively, could be DPO. But, ignoring since passing `config` implicitly is deprecated config = default_sft_config() set_seed(config.train.seed) @@ -102,7 +102,7 @@ def train( # noqa: C901 if eval_prompts is None: eval_prompts = prompts[:batch_size] - # Offline training from the collected samples (e.g. SFT, ILQL) + # Offline training from the collected samples (e.g. 
SFT, ILQL, DPO) elif samples: if rewards is not None: if len(samples) != len(rewards): diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 9c7dccf76..3dc5a4c52 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -6,6 +6,7 @@ # Register load trainers via module import from trlx.trainer import _TRAINERS, register_trainer +from trlx.trainer.accelerate_dpo_trainer import AccelerateDPOTrainer from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer From ca7a828e337cdb6d42e6ecb45d6649920dc5084d Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Wed, 13 Sep 2023 08:33:18 +0000 Subject: [PATCH 04/11] Update .gitignore and minor refactor --- .gitignore | 3 ++- examples/hh/dpo_hh.py | 48 ++++++------------------------------------- 2 files changed, 8 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index 45e03b9e5..032c30f76 100644 --- a/.gitignore +++ b/.gitignore @@ -150,4 +150,5 @@ OUT/ examples/experiments/grounded_program_synthesis/dataset ckpts/ -ray_results/ +ray_result/ +examples/checkpoints/ diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py index acddecc82..7d202215f 100644 --- a/examples/hh/dpo_hh.py +++ b/examples/hh/dpo_hh.py @@ -1,9 +1,7 @@ import json import sys -from collections import defaultdict -import tqdm -from datasets import Dataset, load_dataset +from datasets import load_dataset import trlx from trlx.data.default_configs import ( @@ -33,49 +31,15 @@ optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps method=DPOConfig( - name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), beta=0.1 + name="DPOConfig", + gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), + beta=0.1, + label_pad_token_id=-100, + padding_value=0, ), ) -def get_hh(split: str, sanity_check=False, silent=False): - dataset = load_dataset("Anthropic/hh-rlhf", split=split) - if sanity_check: - dataset = dataset.select(range(min(len(dataset), 1000))) - - def extract_anthropic_prompt(prompt_and_response): - """Extract the anthropic prompt from a prompt and response pair.""" - search_term = "\n\nAssistant:" - search_term_idx = prompt_and_response.rfind(search_term) - assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'" - return prompt_and_response[: search_term_idx + len(search_term)] - - def split_prompt_and_responses(ex): - prompt = extract_anthropic_prompt(ex["chosen"]) - chosen_response = ex["chosen"][len(prompt) :] - rejected_response = ex["rejected"][len(prompt) :] - return prompt, chosen_response, rejected_response - - data = defaultdict(lambda: defaultdict(list)) - for row in tqdm.tqdm(dataset, desc="Processing HH", disable=silent): - prompt, chosen, rejected = split_prompt_and_responses(row) - responses = [chosen, rejected] - n_responses = len(data[prompt]["responses"]) - data[prompt]["pairs"].append((n_responses, n_responses + 1)) - data[prompt]["responses"].extend(responses) - data[prompt]["sft_target"] = chosen - - def gen(): - for prompt, values in data.items(): - yield { - "prompt": prompt, - "responses": values["responses"], - "pairs": values["pairs"], - } - - return Dataset.from_generator(gen) - - def preprocess(sample): pass From 
02bd944abed05d029a079c8f40dcf1568b733833 Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Wed, 13 Sep 2023 13:25:53 +0000 Subject: [PATCH 05/11] Add type hinting and minor refactor --- examples/hh/dpo_hh.py | 5 ++-- trlx/data/dpo_types.py | 13 ++++++++++ trlx/pipeline/offline_pipeline.py | 35 ++++++++++++++------------ trlx/trainer/accelerate_dpo_trainer.py | 19 ++++++++------ 4 files changed, 46 insertions(+), 26 deletions(-) create mode 100644 trlx/data/dpo_types.py diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py index 7d202215f..48714608a 100644 --- a/examples/hh/dpo_hh.py +++ b/examples/hh/dpo_hh.py @@ -41,7 +41,8 @@ def preprocess(sample): - pass + sample["dpo"] = [sample["prompt"], sample["chosen"], sample["rejected"]] + return sample def main(hparams={}): @@ -51,7 +52,7 @@ def main(hparams={}): trlx.train( config=config, - samples=dataset["train"], + samples=dataset["train"]["dpo"], eval_prompts=dataset["test"]["prompt"][:280], # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], diff --git a/trlx/data/dpo_types.py b/trlx/data/dpo_types.py new file mode 100644 index 000000000..fe7eec917 --- /dev/null +++ b/trlx/data/dpo_types.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from transformers import BatchEncoding + + +@dataclass +class DPOElement: + prompt_tokens: BatchEncoding + chosen_tokens: BatchEncoding + rejected_tokens: BatchEncoding + + +# TODO: Extend to include a concrete class for DPOPreferenceBatch diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index 0d0256ffd..a824ffb98 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -10,6 +10,7 @@ PreTrainedTokenizerFast, ) +from trlx.data.dpo_types import DPOElement from trlx.data.ilql_types import ( ILQLBatch, ILQLElement, @@ -279,19 +280,12 @@ def create_loader(self, batch_size: int): ) -@dataclass -class DPOPreferences: - prompt_tokens: Tuple - chosen_tokens: Tuple - rejected_tokens: Tuple - - class DPOStore(BaseRolloutStore): # Adapted from TRL def __init__( self, - preferences: List[DPOPreferences], - tokenizer: PreTrainedTokenizer, + preferences: List[DPOElement], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], label_pad_token_id: int, padding_value: int, ): @@ -305,10 +299,19 @@ def __init__( ] @staticmethod - def tokenize_preferences(samples, tokenizer, max_length=2048): - chosen_tokens = tokenizer(samples["chosen"], add_special_tokens=False) - rejected_tokens = tokenizer(samples["rejected"], add_special_tokens=False) - prompt_tokens = tokenizer(samples["prompt"], add_special_tokens=False) + def tokenize_preferences( + sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048 + ) -> DPOElement: + if isinstance(sample, Iterable): + if len(sample) != 3: + raise ValueError( + f"Expected iterable of length 3 (prompt, chosen response, rejected response). 
Got {len(sample)}" + ) + prompt_tokens = tokenizer(sample[0], add_special_tokens=False) + chosen_tokens = tokenizer(sample[1], add_special_tokens=False) + rejected_tokens = tokenizer(sample[2], add_special_tokens=False) + else: + raise ValueError(f"{sample} is not an iterable") chosen_tokens["input_ids"].append(tokenizer.eos_token_id) chosen_tokens["attention_mask"].append(1) @@ -330,9 +333,9 @@ def tokenize_preferences(samples, tokenizer, max_length=2048): chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()} rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()} - return DPOPreferences(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) + return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) - def _build_batch_from_preference_tokens(self, preference_tokens: DPOPreferences): + def _build_batch_from_preference_tokens(self, preference_tokens: DPOElement) -> Dict: # Create labels chosen_sequence_tokens = { k: preference_tokens.prompt_tokens[k] + preference_tokens.chosen_tokens[k] @@ -366,7 +369,7 @@ def _build_batch_from_preference_tokens(self, preference_tokens: DPOPreferences) return batch def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: - def collate_fn(batch): + def collate_fn(batch: Iterable[dict]): # first, pad everything to the same length padded_batch = {} for k in batch[0].keys(): diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index 2ae2b1750..854c3152d 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import Dict, Iterable, List, Tuple, Union import torch import torch.nn as nn @@ -32,6 +32,8 @@ class DPOConfig(MethodConfig): @register_trainer class AccelerateDPOTrainer(AccelerateRLTrainer): + """DPO Accelerate Trainer""" + def __init__(self, config: TRLConfig, **kwargs): super().__init__(config, **kwargs) @@ -47,9 +49,9 @@ def __init__(self, config: TRLConfig, **kwargs): ) # `beta` corresponding to the DPO hyperparameter - self.beta = config.method.beta - self.label_pad_token_id = config.method.label_pad_token_id - self.padding_value = config.method.padding_value + self.beta: float = config.method.beta + self.label_pad_token_id: int = config.method.label_pad_token_id + self.padding_value: int = config.method.padding_value def get_arch(self, config): from_fn = AutoModelForCausalLM.from_pretrained @@ -177,8 +179,9 @@ def _get_batch_logps( def concatenated_forward( self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: - """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. - We do this to avoid doing two forward passes, because it's faster for FSDP. + """ + Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + This is faster and avoids two forward passes. 
""" concatenated_batch = self.concatenated_inputs(batch) all_logits = model( @@ -197,7 +200,7 @@ def concatenated_forward( rejected_logits = all_logits[batch["chosen_input_ids"].shape[0] :] return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) - def loss(self, batch): + def loss(self, batch: Dict[str, Union[List, torch.LongTensor]]): stats = {} ( @@ -251,6 +254,6 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - def make_experience(self, samples, seq_length): + def make_experience(self, samples: Iterable[Iterable], seq_length: int): preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples] self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value) From 00d0f2c77ce5c0d3804f576fdbda5237b6d8e9af Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Thu, 14 Sep 2023 08:50:37 +0000 Subject: [PATCH 06/11] Update docstrings --- examples/hh/dpo_hh.py | 2 +- trlx/trainer/accelerate_dpo_trainer.py | 28 ++++++++++++++++++++------ trlx/trlx.py | 2 +- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py index 48714608a..d102c33da 100644 --- a/examples/hh/dpo_hh.py +++ b/examples/hh/dpo_hh.py @@ -19,7 +19,7 @@ seq_length=1024, epochs=100, total_steps=1000, - batch_size=1, + batch_size=4, checkpoint_interval=10000, eval_interval=100, pipeline="PromptPipeline", diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index 854c3152d..ab4fb5eca 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -20,12 +20,16 @@ class DPOConfig(MethodConfig): """ Config for DPO training - :param gen_kwargs: kwargs for generation - :type gen_kwargs: Dict[str, Any] + Args: + gen_kwargs (Dict[str, Any]) : kwargs for generation + beta (float) : Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + label_pad_token_id (int) : token to pad labels with. -100 token is ignored + for CELoss + padding_value (int) : token to pad input sequence with """ gen_kwargs: dict - beta: float = 0.1 # Beta value for DPO loss calculation + beta: float = 0.1 label_pad_token_id: int = -100 # -100 is ignore token for CELoss padding_value: int = 0 @@ -48,7 +52,6 @@ def __init__(self, config: TRLConfig, **kwargs): pad_token_id=self.tokenizer.pad_token_id, ) - # `beta` corresponding to the DPO hyperparameter self.beta: float = config.method.beta self.label_pad_token_id: int = config.method.label_pad_token_id self.padding_value: int = config.method.padding_value @@ -79,8 +82,8 @@ def get_arch(self, config): def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: """Concatenate the chosen and rejected inputs into a single tensor. Args: - batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of - shape (batch_size, sequence_length). + batch (Dict): A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', + which are tensors of shape (batch_size, sequence_length). Returns: A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. """ @@ -182,6 +185,19 @@ def concatenated_forward( """ Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. 
This is faster and avoids two forward passes. + Args: + model: Base model being trained + batch(Dict): : A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', + which are tensors of shape (batch_size, sequence_length). + + Returns: + A tuple containing 4 tensors : (chosen_log_probabilities, + rejected_log_probabilities, + chosen_logits, + rejected_logits) + The 2 {chosen, rejected}_logp tensors contains the per-token chosen and rejected log probabilities respectively. + The 2 {chosen, rejected}_logits tensors contains the raw logits for chosen and rejected responses from the model + forward pass. """ concatenated_batch = self.concatenated_inputs(batch) all_logits = model( diff --git a/trlx/trlx.py b/trlx/trlx.py index 6d98019fd..7f2aef9c0 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -64,7 +64,7 @@ def train( # noqa: C901 config = default_ppo_config() elif rewards: config = default_ilql_config() - else: # Alternatively, could be DPO. But, ignoring since passing `config` implicitly is deprecated + else: # Alternatively, could be `default_dpo_config()`. But, ignoring since passing `config` implicitly is deprecated config = default_sft_config() set_seed(config.train.seed) From 9b43dd50c3d8628e62e6663d47d434fc98d63a91 Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Wed, 27 Sep 2023 08:43:36 +0000 Subject: [PATCH 07/11] Add Deepspeed init support when using stage 3 --- trlx/trainer/accelerate_dpo_trainer.py | 43 +++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index ab4fb5eca..c6f91dd76 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -4,8 +4,12 @@ import torch import torch.nn as nn import torch.nn.functional as F +from accelerate.utils import is_deepspeed_available from transformers import AutoModelForCausalLM, PretrainedConfig +if is_deepspeed_available(): + import deepspeed + from trlx.data.configs import TRLConfig from trlx.data.method_configs import MethodConfig, register_method from trlx.pipeline.offline_pipeline import DPOStore @@ -43,7 +47,10 @@ def __init__(self, config: TRLConfig, **kwargs): # TODO: Avoid setting up a reference model when hydra heads are used self.ref_model = self.get_arch(self.config) - self.ref_model.to(self.accelerator.device) + if self.accelerator.state.deepspeed_plugin.zero_stage == 3: + self.ref_model = self._prepare_deepspeed_zero3(self.ref_model) + else: + self.ref_model.to(self.accelerator.device) self.ref_model.eval() self.generate_kwargs = dict( @@ -255,6 +262,40 @@ def loss(self, batch: Dict[str, Union[List, torch.LongTensor]]): return losses.mean(), stats + def _prepare_deepspeed_zero3(self, model: nn.Module): + # Adapted from accelerate: + # https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + # TODO: figure out if any other parameters are needed to optimize inference + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + + # See DeepSpeed docs for definition of these parameters: https://deepspeed.readthedocs.io/en/latest/zero3.html + config_kwargs = { + "train_micro_batch_size_per_gpu": self.config.train.batch_size, + "train_batch_size": self.config.train.batch_size + * self.accelerator.state.num_processes + * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"] + * self.accelerator.num_processes, + "zero_optimization": {"stage": 3, "offload_param": 
{"device": deepspeed_plugin.offload_param_device}}, + } + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + def prepare_learning(self): train_dataloader = self.store.create_loader(self.config.train.batch_size) eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) From 737a49655f936ae7d5c3088206d32011e63c7b84 Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Tue, 14 Nov 2023 14:04:19 +0000 Subject: [PATCH 08/11] Add dpo example --- examples/dpo_ultrafeedback.py | 71 +++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 examples/dpo_ultrafeedback.py diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py new file mode 100644 index 000000000..a0e6cfb51 --- /dev/null +++ b/examples/dpo_ultrafeedback.py @@ -0,0 +1,71 @@ +import itertools +import json +import sys + +from datasets import load_dataset + +import trlx +from trlx.data.default_configs import ( + DPOConfig, + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + +default_config = TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=1, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateDPOTrainer", + checkpoint_dir="checkpoints/dpo_ultrafeedback", + ), + model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps + method=DPOConfig( + name="DPOConfig", + gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), + beta=0.1, + label_pad_token_id=-100, + padding_value=0, + ), +) + + +def preprocess(sample): + """ + Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2 + """ + assert len(sample["chosen"]) == len(sample["rejected"]) == 2 + + sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]] + return sample + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess) + + trlx.train( + config=config, + samples=dataset["train_prefs"]["dpo"], + eval_prompts=dataset["test_prefs"]["prompt"][:8], + # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, + stop_sequences=["User:", "user:", "Assistant:", "assistant:"] + + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])], + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) From dfa814d834eeaa7e4ecec14747549da076cad08f Mon Sep 17 00:00:00 2001 
From: sandeepchittilla Date: Wed, 15 Nov 2023 14:59:11 +0000 Subject: [PATCH 09/11] Update hyperparementers --- examples/dpo_ultrafeedback.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py index a0e6cfb51..c97a4ce5e 100644 --- a/examples/dpo_ultrafeedback.py +++ b/examples/dpo_ultrafeedback.py @@ -29,11 +29,11 @@ ), model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1), tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"), - optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)), scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps method=DPOConfig( name="DPOConfig", - gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), + gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True), beta=0.1, label_pad_token_id=-100, padding_value=0, @@ -59,7 +59,7 @@ def main(hparams={}): trlx.train( config=config, samples=dataset["train_prefs"]["dpo"], - eval_prompts=dataset["test_prefs"]["prompt"][:8], + eval_prompts=dataset["test_prefs"]["prompt"][:128], # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, stop_sequences=["User:", "user:", "Assistant:", "assistant:"] + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])], From 6d63004ee065c1f7c72a386ea4b290339e2a3707 Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Mon, 20 Nov 2023 17:01:46 +0100 Subject: [PATCH 10/11] Fix prompt truncation bug and handle deepspeed preparation --- examples/dpo_ultrafeedback.py | 68 ++++++++++++++++++-------- trlx/pipeline/offline_pipeline.py | 19 ++++--- trlx/trainer/accelerate_dpo_trainer.py | 17 +++++-- trlx/trainer/accelerate_sft_trainer.py | 2 +- trlx/trainer/nemo_sft_trainer.py | 2 +- trlx/trlx.py | 3 +- 6 files changed, 75 insertions(+), 36 deletions(-) diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py index c97a4ce5e..f02a608c5 100644 --- a/examples/dpo_ultrafeedback.py +++ b/examples/dpo_ultrafeedback.py @@ -1,8 +1,9 @@ -import itertools -import json import sys +import json +from functools import partial from datasets import load_dataset +from transformers import AutoTokenizer import trlx from trlx.data.default_configs import ( @@ -15,25 +16,30 @@ TRLConfig, ) +model_path = "HuggingFaceH4/mistral-7b-sft-beta" +wandb_project = "trlx" + default_config = TRLConfig( train=TrainConfig( seq_length=1024, - epochs=100, - total_steps=1000, + epochs=2, + total_steps=1000000, batch_size=1, - checkpoint_interval=10000, - eval_interval=100, + checkpoint_interval=100000, + eval_interval=1000, + seed=42, + project_name=wandb_project, pipeline="PromptPipeline", trainer="AccelerateDPOTrainer", checkpoint_dir="checkpoints/dpo_ultrafeedback", ), - model=ModelConfig(model_path="HuggingFaceH4/mistral-7b-sft-beta", num_layers_unfrozen=-1), - tokenizer=TokenizerConfig(tokenizer_path="HuggingFaceH4/mistral-7b-sft-beta", truncation_side="right"), - optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)), + model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path=model_path, 
truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)), scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps method=DPOConfig( name="DPOConfig", - gen_kwargs=dict(max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, do_sample=True), + gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True), beta=0.1, label_pad_token_id=-100, padding_value=0, @@ -41,28 +47,50 @@ ) -def preprocess(sample): +def preprocess(sample, tokenizer, test=False): """ - Return list of lists with Context/Prompt at index 0, Chosen at index 1 and rejected at index 2 + Formats the input to the same training style used for mistral-7b-v0.1 + When fine-tuning, modify your pre-processing to match the prompt template used during pretraining. """ assert len(sample["chosen"]) == len(sample["rejected"]) == 2 - sample["dpo"] = [sample["prompt"], sample["chosen"][1]["content"], sample["rejected"][1]["content"]] - return sample + assistant_prompt = "<|assistant|>" + + prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt) + rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1] + + return { + "prompt": prompt if not test else prompt + assistant_prompt, + "chosen": assistant_prompt + chosen, + "rejected": assistant_prompt + rejected, + } def main(hparams={}): config = TRLConfig.update(default_config, hparams) - dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized").map(preprocess) + tokenizer = AutoTokenizer.from_pretrained(model_path) + dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized") + + dataset["dpo_train"] = dataset["train_prefs"].map( + partial(preprocess, tokenizer=tokenizer, test=False), + remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"], + ) + dataset["dpo_test"] = dataset["test_prefs"].map( + partial(preprocess, tokenizer=tokenizer, test=True), + remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"], + ) + + print( + f"Length of training dataset : {len(dataset['dpo_train'])} \ + Length of test dataset : {len(dataset['dpo_test'])}" + ) trlx.train( config=config, - samples=dataset["train_prefs"]["dpo"], - eval_prompts=dataset["test_prefs"]["prompt"][:128], - # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, - stop_sequences=["User:", "user:", "Assistant:", "assistant:"] - + ["{e}x {i}put:".format(e=e, i=i) for e, i in itertools.product(["e", "E"], ["in", "In", "out", "Out"])], + samples=dataset["dpo_train"], + eval_prompts=dataset["dpo_test"]["prompt"][:8], # running eval on subset only + stop_sequences=["<|user|>", "<|User|>"], ) diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index a824ffb98..6bab1db73 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -300,16 +300,19 @@ def __init__( @staticmethod def tokenize_preferences( - sample: Iterable[str], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], max_length=2048 + sample: Iterable[str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + max_length=2048, + max_prompt_length=256, ) -> DPOElement: if isinstance(sample, Iterable): if len(sample) != 3: raise ValueError( f"Expected iterable of length 3 (prompt, chosen response, rejected response). 
Got {len(sample)}" ) - prompt_tokens = tokenizer(sample[0], add_special_tokens=False) - chosen_tokens = tokenizer(sample[1], add_special_tokens=False) - rejected_tokens = tokenizer(sample[2], add_special_tokens=False) + prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False) + chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False) + rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False) else: raise ValueError(f"{sample} is not an iterable") @@ -324,14 +327,14 @@ def tokenize_preferences( # if combined sequence is too long, truncate the prompt only if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: if tokenizer.truncation_side == "right": - prompt_tokens = {k: v[:max_length] for k, v in prompt_tokens.items()} + prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()} elif tokenizer.truncation_side == "left": - prompt_tokens = {k: v[-max_length:] for k, v in prompt_tokens.items()} + prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()} # if that's still too long, truncate the response if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: - chosen_tokens = {k: v[: max_length - max_length] for k, v in chosen_tokens.items()} - rejected_tokens = {k: v[: max_length - max_length] for k, v in rejected_tokens.items()} + chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()} + rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()} return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py index c6f91dd76..e01b8e684 100644 --- a/trlx/trainer/accelerate_dpo_trainer.py +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -10,6 +10,7 @@ if is_deepspeed_available(): import deepspeed +import trlx.utils.logging as logging from trlx.data.configs import TRLConfig from trlx.data.method_configs import MethodConfig, register_method from trlx.pipeline.offline_pipeline import DPOStore @@ -18,6 +19,9 @@ from trlx.utils.modeling import pad_to_length +logger = logging.get_logger(__name__) + + @dataclass @register_method class DPOConfig(MethodConfig): @@ -47,9 +51,10 @@ def __init__(self, config: TRLConfig, **kwargs): # TODO: Avoid setting up a reference model when hydra heads are used self.ref_model = self.get_arch(self.config) - if self.accelerator.state.deepspeed_plugin.zero_stage == 3: - self.ref_model = self._prepare_deepspeed_zero3(self.ref_model) - else: + try: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3: + self.ref_model = self._prepare_deepspeed_zero3(self.ref_model) + except: self.ref_model.to(self.accelerator.device) self.ref_model.eval() @@ -311,6 +316,8 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - def make_experience(self, samples: Iterable[Iterable], seq_length: int): - preferences = [DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length) for sample in samples] + def make_experience(self, samples: Iterable[Iterable], seq_length: int, max_prompt_length: int): + preferences = [ + DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length, max_prompt_length) for sample in samples + ] self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value) diff --git 
a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index d5cbe3ea5..b76f50f14 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -87,7 +87,7 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - def make_experience(self, samples, seq_length): + def make_experience(self, samples, seq_length, **kwargs): if isinstance(samples[0], str): self.store = PromptPipeline(samples, seq_length, self.tokenizer) else: diff --git a/trlx/trainer/nemo_sft_trainer.py b/trlx/trainer/nemo_sft_trainer.py index 7f25254f1..d7410b35c 100644 --- a/trlx/trainer/nemo_sft_trainer.py +++ b/trlx/trainer/nemo_sft_trainer.py @@ -133,7 +133,7 @@ def eval_collate(elems): torch.set_float32_matmul_precision("medium") self.trainer.fit(self.model) - def make_experience(self, samples, seq_length): + def make_experience(self, samples, seq_length, **kwargs): if isinstance(samples[0], str): self.store = PromptPipeline(samples, seq_length, self.tokenizer) else: diff --git a/trlx/trlx.py b/trlx/trlx.py index 7f2aef9c0..ba75b5f54 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -114,7 +114,8 @@ def train( # noqa: C901 if rewards is not None: trainer.make_experience(samples, rewards, config.train.seq_length) else: - trainer.make_experience(samples, config.train.seq_length) + # this should be abstracted for all trainers with **kwargs + trainer.make_experience(samples, config.train.seq_length, max_prompt_length) else: raise ValueError("Either `samples` or `reward_fn` should be given for training") From 4a246038cd43fa44538404c84ca768d0bb89547a Mon Sep 17 00:00:00 2001 From: sandeepchittilla Date: Thu, 23 Nov 2023 16:12:56 +0100 Subject: [PATCH 11/11] Use slower training parameters --- examples/dpo_ultrafeedback.py | 12 ++++++------ trlx/pipeline/offline_pipeline.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py index f02a608c5..4d8dd5b27 100644 --- a/examples/dpo_ultrafeedback.py +++ b/examples/dpo_ultrafeedback.py @@ -22,11 +22,11 @@ default_config = TRLConfig( train=TrainConfig( seq_length=1024, - epochs=2, - total_steps=1000000, + epochs=1, + total_steps=70000, batch_size=1, checkpoint_interval=100000, - eval_interval=1000, + eval_interval=5000, seed=42, project_name=wandb_project, pipeline="PromptPipeline", @@ -35,11 +35,11 @@ ), model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1), tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"), - optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=5e-7, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=1.0e-5)), scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps method=DPOConfig( name="DPOConfig", - gen_kwargs=dict(max_new_tokens=512, temperature=0.7, top_k=50, top_p=0.95, do_sample=True), + gen_kwargs=dict(max_new_tokens=768, temperature=0.7, top_k=50, top_p=0.95, do_sample=True), beta=0.1, label_pad_token_id=-100, padding_value=0, @@ -89,7 +89,7 @@ def main(hparams={}): trlx.train( config=config, samples=dataset["dpo_train"], - eval_prompts=dataset["dpo_test"]["prompt"][:8], # running eval on subset only + eval_prompts=dataset["dpo_test"]["prompt"][:2], # running eval on subset only 
stop_sequences=["<|user|>", "<|User|>"], ) diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index 6bab1db73..21e31d775 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -400,4 +400,4 @@ def collate_fn(batch: Iterable[dict]): return padded_batch - return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) + return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle, pin_memory=True)
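
A minimal usage sketch for the DPO trainer introduced by this patch series, illustrative only. The toy preference records, the batch-size/step overrides, and pointing `config.train.trainer` at the registered class name "AccelerateDPOTrainer" (the default config in PATCH 02 uses "DPOTrainer") are assumptions rather than part of the patches, and the sketch presumes the `max_prompt_length` value forwarded in PATCH 10's trlx.py change resolves at runtime.

import trlx
from trlx.data.default_configs import default_dpo_config

# Start from the default DPO template and override a few fields for a quick run.
config = default_dpo_config()
config.train.trainer = "AccelerateDPOTrainer"  # trainer is registered under its class name
config.train.batch_size = 2
config.train.total_steps = 10

# Toy preference records in the {"prompt", "chosen", "rejected"} form expected by
# DPOStore.tokenize_preferences after PATCH 10 (placeholder data, not real examples).
samples = [
    {
        "prompt": "Human: Name the capital of France.\n\nAssistant:",
        "chosen": " The capital of France is Paris.",
        "rejected": " I am not sure.",
    },
    {
        "prompt": "Human: What is 2 + 2?\n\nAssistant:",
        "chosen": " 2 + 2 equals 4.",
        "rejected": " Five.",
    },
]

trlx.train(
    config=config,
    samples=samples,
    eval_prompts=[sample["prompt"] for sample in samples],
)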