diff --git a/.gitignore b/.gitignore
index 45e03b9e5..032c30f76 100644
--- a/.gitignore
+++ b/.gitignore
@@ -150,4 +150,5 @@
 OUT/
 examples/experiments/grounded_program_synthesis/dataset
 ckpts/
-ray_results/
+ray_result/
+examples/checkpoints/
diff --git a/examples/dpo_ultrafeedback.py b/examples/dpo_ultrafeedback.py
new file mode 100644
index 000000000..4d8dd5b27
--- /dev/null
+++ b/examples/dpo_ultrafeedback.py
@@ -0,0 +1,99 @@
+import sys
+import json
+from functools import partial
+
+from datasets import load_dataset
+from transformers import AutoTokenizer
+
+import trlx
+from trlx.data.default_configs import (
+    DPOConfig,
+    ModelConfig,
+    OptimizerConfig,
+    SchedulerConfig,
+    TokenizerConfig,
+    TrainConfig,
+    TRLConfig,
+)
+
+model_path = "HuggingFaceH4/mistral-7b-sft-beta"
+wandb_project = "trlx"
+
+default_config = TRLConfig(
+    train=TrainConfig(
+        seq_length=1024,
+        epochs=1,
+        total_steps=70000,
+        batch_size=1,
+        checkpoint_interval=100000,
+        eval_interval=5000,
+        seed=42,
+        project_name=wandb_project,
+        pipeline="PromptPipeline",
+        trainer="AccelerateDPOTrainer",
+        checkpoint_dir="checkpoints/dpo_ultrafeedback",
+    ),
+    model=ModelConfig(model_path=model_path, num_layers_unfrozen=-1),
+    tokenizer=TokenizerConfig(tokenizer_path=model_path, truncation_side="right"),
+    optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=5e-7, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=1.0e-5)),
+    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)),  # train.total_steps
+    method=DPOConfig(
+        name="DPOConfig",
+        gen_kwargs=dict(max_new_tokens=768, temperature=0.7, top_k=50, top_p=0.95, do_sample=True),
+        beta=0.1,
+        label_pad_token_id=-100,
+        padding_value=0,
+    ),
+)
+
+
+def preprocess(sample, tokenizer, test=False):
+    """
+    Formats the input to the same training style used for mistral-7b-v0.1
+    When fine-tuning, modify your pre-processing to match the prompt template used during pretraining.
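+
+    For example, with the Zephyr-style chat template used by HuggingFaceH4/mistral-7b-sft-beta,
+    a user/assistant preference pair is rendered roughly as (illustrative text, not taken from
+    the dataset):
+
+        <|user|>
+        How do I bake bread?</s>
+        <|assistant|>
+        Mix flour, water, yeast and salt, knead, proof, then bake ...</s>
+
+    Splitting on "<|assistant|>" separates the prompt (everything up to and including the user
+    turn) from the response text. The marker is prepended back onto the chosen/rejected
+    responses, and appended to the prompt only at test time so that generation starts inside
+    the assistant turn.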
+ """ + assert len(sample["chosen"]) == len(sample["rejected"]) == 2 + + assistant_prompt = "<|assistant|>" + + prompt, chosen = tokenizer.apply_chat_template(sample["chosen"], tokenize=False).split(assistant_prompt) + rejected = tokenizer.apply_chat_template(sample["rejected"], tokenize=False).split(assistant_prompt)[-1] + + return { + "prompt": prompt if not test else prompt + assistant_prompt, + "chosen": assistant_prompt + chosen, + "rejected": assistant_prompt + rejected, + } + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + dataset = load_dataset("HuggingFaceH4/ultrafeedback_binarized") + + dataset["dpo_train"] = dataset["train_prefs"].map( + partial(preprocess, tokenizer=tokenizer, test=False), + remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"], + ) + dataset["dpo_test"] = dataset["test_prefs"].map( + partial(preprocess, tokenizer=tokenizer, test=True), + remove_columns=["prompt_id", "score_chosen", "score_rejected", "messages"], + ) + + print( + f"Length of training dataset : {len(dataset['dpo_train'])} \ + Length of test dataset : {len(dataset['dpo_test'])}" + ) + + trlx.train( + config=config, + samples=dataset["dpo_train"], + eval_prompts=dataset["dpo_test"]["prompt"][:2], # running eval on subset only + stop_sequences=["<|user|>", "<|User|>"], + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/examples/hh/dpo_hh.py b/examples/hh/dpo_hh.py new file mode 100644 index 000000000..d102c33da --- /dev/null +++ b/examples/hh/dpo_hh.py @@ -0,0 +1,64 @@ +import json +import sys + +from datasets import load_dataset + +import trlx +from trlx.data.default_configs import ( + DPOConfig, + ModelConfig, + OptimizerConfig, + SchedulerConfig, + TokenizerConfig, + TrainConfig, + TRLConfig, +) + +default_config = TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=4, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="AccelerateDPOTrainer", + checkpoint_dir="checkpoints/dpo_hh", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), + scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4)), # train.total_steps + method=DPOConfig( + name="DPOConfig", + gen_kwargs=dict(max_new_tokens=40, top_k=20, top_p=1.0, do_sample=True), + beta=0.1, + label_pad_token_id=-100, + padding_value=0, + ), +) + + +def preprocess(sample): + sample["dpo"] = [sample["prompt"], sample["chosen"], sample["rejected"]] + return sample + + +def main(hparams={}): + config = TRLConfig.update(default_config, hparams) + + dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess) + + trlx.train( + config=config, + samples=dataset["train"]["dpo"], + eval_prompts=dataset["test"]["prompt"][:280], + # metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, + stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], + ) + + +if __name__ == "__main__": + hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) + main(hparams) diff --git a/trlx/data/default_configs.py b/trlx/data/default_configs.py index 5277d7010..22255d4b8 100644 --- a/trlx/data/default_configs.py +++ b/trlx/data/default_configs.py @@ 
-2,6 +2,7 @@ from trlx.models.modeling_ilql import ILQLConfig from trlx.models.modeling_ppo import PPOConfig +from trlx.trainer.accelerate_dpo_trainer import DPOConfig from trlx.trainer.accelerate_sft_trainer import SFTConfig from .configs import ( @@ -146,3 +147,29 @@ def default_nemo_1_3b_config(): here = Path(__file__).parent return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_1.3b.yaml") + + +def default_dpo_config(): + return TRLConfig( + train=TrainConfig( + seq_length=1024, + epochs=100, + total_steps=1000, + batch_size=8, + checkpoint_interval=10000, + eval_interval=100, + pipeline="PromptPipeline", + trainer="DPOTrainer", + ), + model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), + tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), + optimizer=OptimizerConfig( + name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) + ), + scheduler=SchedulerConfig( + name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps + ), + method=DPOConfig( + name="DPOConfig", gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), beta=0.1 + ), + ) diff --git a/trlx/data/dpo_types.py b/trlx/data/dpo_types.py new file mode 100644 index 000000000..fe7eec917 --- /dev/null +++ b/trlx/data/dpo_types.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from transformers import BatchEncoding + + +@dataclass +class DPOElement: + prompt_tokens: BatchEncoding + chosen_tokens: BatchEncoding + rejected_tokens: BatchEncoding + + +# TODO: Extend to include a concrete class for DPOPreferenceBatch diff --git a/trlx/pipeline/__init__.py b/trlx/pipeline/__init__.py index c7dba9e97..7e927cd49 100644 --- a/trlx/pipeline/__init__.py +++ b/trlx/pipeline/__init__.py @@ -166,8 +166,8 @@ def __next__(self): # noqa: C901 minibatch = BatchEncoding(sliced_data) elif is_dataclass(batch): minibatch = batch.__class__(**sliced_data) - # else: - # minibatch = sliced_data + else: + minibatch = sliced_data minibatches.append(minibatch) diff --git a/trlx/pipeline/offline_pipeline.py b/trlx/pipeline/offline_pipeline.py index cee900cfc..21e31d775 100644 --- a/trlx/pipeline/offline_pipeline.py +++ b/trlx/pipeline/offline_pipeline.py @@ -10,6 +10,7 @@ PreTrainedTokenizerFast, ) +from trlx.data.dpo_types import DPOElement from trlx.data.ilql_types import ( ILQLBatch, ILQLElement, @@ -277,3 +278,126 @@ def create_loader(self, batch_size: int): collate_fn=ilql_seq2seq_collate_fn, drop_last=torch.distributed.is_initialized(), ) + + +class DPOStore(BaseRolloutStore): + # Adapted from TRL + def __init__( + self, + preferences: List[DPOElement], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + label_pad_token_id: int, + padding_value: int, + ): + super().__init__() + self.tokenizer = tokenizer + self.label_pad_token_id = label_pad_token_id + self.padding_value = padding_value + + self.history = [ + self._build_batch_from_preference_tokens(preference_element) for preference_element in preferences + ] + + @staticmethod + def tokenize_preferences( + sample: Iterable[str], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + max_length=2048, + max_prompt_length=256, + ) -> DPOElement: + if isinstance(sample, Iterable): + if len(sample) != 3: + raise ValueError( + f"Expected iterable of length 3 (prompt, chosen response, rejected response). 
Got {len(sample)}" + ) + prompt_tokens = tokenizer(sample["prompt"], add_special_tokens=False) + chosen_tokens = tokenizer(sample["chosen"], add_special_tokens=False) + rejected_tokens = tokenizer(sample["rejected"], add_special_tokens=False) + else: + raise ValueError(f"{sample} is not an iterable") + + chosen_tokens["input_ids"].append(tokenizer.eos_token_id) + chosen_tokens["attention_mask"].append(1) + + rejected_tokens["input_ids"].append(tokenizer.eos_token_id) + rejected_tokens["attention_mask"].append(1) + + longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"])) + + # if combined sequence is too long, truncate the prompt only + if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: + if tokenizer.truncation_side == "right": + prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()} + elif tokenizer.truncation_side == "left": + prompt_tokens = {k: v[-max_prompt_length:] for k, v in prompt_tokens.items()} + + # if that's still too long, truncate the response + if len(prompt_tokens["input_ids"]) + longer_response_length > max_length: + chosen_tokens = {k: v[: max_length - max_prompt_length] for k, v in chosen_tokens.items()} + rejected_tokens = {k: v[: max_length - max_prompt_length] for k, v in rejected_tokens.items()} + + return DPOElement(prompt_tokens=prompt_tokens, chosen_tokens=chosen_tokens, rejected_tokens=rejected_tokens) + + def _build_batch_from_preference_tokens(self, preference_tokens: DPOElement) -> Dict: + # Create labels + chosen_sequence_tokens = { + k: preference_tokens.prompt_tokens[k] + preference_tokens.chosen_tokens[k] + for k in preference_tokens.chosen_tokens + } + rejected_sequence_tokens = { + k: preference_tokens.prompt_tokens[k] + preference_tokens.rejected_tokens[k] + for k in preference_tokens.rejected_tokens + } + chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:] + chosen_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [ + self.label_pad_token_id + ] * len(preference_tokens.prompt_tokens["input_ids"]) + rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:] + rejected_sequence_tokens["labels"][: len(preference_tokens.prompt_tokens["input_ids"])] = [ + self.label_pad_token_id + ] * len(preference_tokens.prompt_tokens["input_ids"]) + + batch = {} + + for k, toks in { + "chosen": chosen_sequence_tokens, + "rejected": rejected_sequence_tokens, + "prompt": preference_tokens.prompt_tokens, + }.items(): + for type_key, tokens in toks.items(): + if type_key == "token_type_ids": + continue + batch[f"{k}_{type_key}"] = tokens + + return batch + + def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: + def collate_fn(batch: Iterable[dict]): + # first, pad everything to the same length + padded_batch = {} + for k in batch[0].keys(): + if k.endswith("_input_ids") or k.endswith("_attention_mask") or k.endswith("_labels"): + # adapted from https://stackoverflow.com/questions/73256206 + if "prompt" in k: + to_pad = [torch.LongTensor(ex[k][::-1]) for ex in batch] + else: + to_pad = [torch.LongTensor(ex[k]) for ex in batch] + if k.endswith("_input_ids"): + padding_value = self.tokenizer.pad_token_id + elif k.endswith("_labels"): + padding_value = self.label_pad_token_id + elif k.endswith("_attention_mask"): + padding_value = self.padding_value + else: + raise ValueError(f"Unexpected key in batch '{k}'") + + padded_batch[k] = pad_sequence(to_pad, batch_first=True, padding_value=padding_value) + # 
for the prompt, flip back so padding is on left side + if "prompt" in k: + padded_batch[k] = padded_batch[k].flip(dims=[1]) + else: + padded_batch[k] = [ex[k] for ex in batch] + + return padded_batch + + return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle, pin_memory=True) diff --git a/trlx/trainer/accelerate_dpo_trainer.py b/trlx/trainer/accelerate_dpo_trainer.py new file mode 100644 index 000000000..e01b8e684 --- /dev/null +++ b/trlx/trainer/accelerate_dpo_trainer.py @@ -0,0 +1,323 @@ +from dataclasses import dataclass +from typing import Dict, Iterable, List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from accelerate.utils import is_deepspeed_available +from transformers import AutoModelForCausalLM, PretrainedConfig + +if is_deepspeed_available(): + import deepspeed + +import trlx.utils.logging as logging +from trlx.data.configs import TRLConfig +from trlx.data.method_configs import MethodConfig, register_method +from trlx.pipeline.offline_pipeline import DPOStore +from trlx.trainer import register_trainer +from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer +from trlx.utils.modeling import pad_to_length + + +logger = logging.get_logger(__name__) + + +@dataclass +@register_method +class DPOConfig(MethodConfig): + """ + Config for DPO training + + Args: + gen_kwargs (Dict[str, Any]) : kwargs for generation + beta (float) : Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. + label_pad_token_id (int) : token to pad labels with. -100 token is ignored + for CELoss + padding_value (int) : token to pad input sequence with + """ + + gen_kwargs: dict + beta: float = 0.1 + label_pad_token_id: int = -100 # -100 is ignore token for CELoss + padding_value: int = 0 + + +@register_trainer +class AccelerateDPOTrainer(AccelerateRLTrainer): + """DPO Accelerate Trainer""" + + def __init__(self, config: TRLConfig, **kwargs): + super().__init__(config, **kwargs) + + # TODO: Avoid setting up a reference model when hydra heads are used + self.ref_model = self.get_arch(self.config) + try: + if self.accelerator.state.deepspeed_plugin.zero_stage == 3: + self.ref_model = self._prepare_deepspeed_zero3(self.ref_model) + except: + self.ref_model.to(self.accelerator.device) + self.ref_model.eval() + + self.generate_kwargs = dict( + config.method.gen_kwargs, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + + self.beta: float = config.method.beta + self.label_pad_token_id: int = config.method.label_pad_token_id + self.padding_value: int = config.method.padding_value + + def get_arch(self, config): + from_fn = AutoModelForCausalLM.from_pretrained + if issubclass(type(config.model.model_path), PretrainedConfig): + from_fn = AutoModelForCausalLM.from_config + + model = from_fn(config.model.model_path) + + if config.model.peft_config is not None: + # Initialize the peft adapter + import peft + + peft_config = config.model.peft_config + if not isinstance(peft_config, peft.PeftConfig): + if isinstance(peft_config, dict): + peft_config = peft.get_peft_config(peft_config) + else: + raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") + model = peft.get_peft_model(model, peft_config) + if self.accelerator.is_main_process: + model.print_trainable_parameters() + + return model + + def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]: + """Concatenate the chosen and 
rejected inputs into a single tensor. + Args: + batch (Dict): A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', + which are tensors of shape (batch_size, sequence_length). + Returns: + A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'. + """ + max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1]) + concatenated_batch = {} + for k in batch: + if k.startswith("chosen") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k else self.padding_value + concatenated_key = k.replace("chosen", "concatenated") + concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value) + for k in batch: + if k.startswith("rejected") and isinstance(batch[k], torch.Tensor): + pad_value = self.label_pad_token_id if "labels" in k else self.padding_value + concatenated_key = k.replace("rejected", "concatenated") + concatenated_batch[concatenated_key] = torch.cat( + ( + concatenated_batch[concatenated_key], + pad_to_length(batch[k], max_length, pad_value=pad_value), + ), + dim=0, + ) + return concatenated_batch + + def dpo_loss( + self, + policy_chosen_logps: torch.FloatTensor, + policy_rejected_logps: torch.FloatTensor, + reference_chosen_logps: torch.FloatTensor, + reference_rejected_logps: torch.FloatTensor, + reference_free: bool = False, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Compute the DPO loss for a batch of policy and reference model log probabilities. + Args: + policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,) + policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,) + reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,) + reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape:(batch_size,) + beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the + reference model as beta -> 0. + reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns + equal probability to all responses. + Returns: + A tuple of three tensors: (losses, chosen_rewards, rejected_rewards). + The losses tensor contains the DPO loss for each example in the batch. + The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, + respectively. + """ + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + + if reference_free: + ref_logratios = 0 + + logits = pi_logratios - ref_logratios + + losses = -F.logsigmoid(self.beta * logits) + chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() + rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() + + return losses, chosen_rewards, rejected_rewards + + def _get_batch_logps( + self, + logits: torch.FloatTensor, + labels: torch.LongTensor, + average_log_prob: bool = False, + ) -> torch.FloatTensor: + """Compute the log probabilities of the given labels under the given logits. + Args: + logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size) + labels: Labels for which to compute the log probabilities. 
Label tokens with a value of label_pad_token_id are + ignored. Shape: (batch_size, sequence_length) + average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum + of the log probabilities of the (non-masked) tokens. + Returns: + A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given + logits. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != self.label_pad_token_id + + # dummy token; we'll ignore the losses on these tokens later + labels[labels == self.label_pad_token_id] = 0 + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + return (per_token_logps * loss_mask).sum(-1) + + def concatenated_forward( + self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]] + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """ + Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. + This is faster and avoids two forward passes. + Args: + model: Base model being trained + batch(Dict): : A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', + which are tensors of shape (batch_size, sequence_length). + + Returns: + A tuple containing 4 tensors : (chosen_log_probabilities, + rejected_log_probabilities, + chosen_logits, + rejected_logits) + The 2 {chosen, rejected}_logp tensors contains the per-token chosen and rejected log probabilities respectively. + The 2 {chosen, rejected}_logits tensors contains the raw logits for chosen and rejected responses from the model + forward pass. 
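+
+        For a batch of size B, concatenated_inputs stacks the chosen rows on top of the rejected
+        rows, so the model runs once over 2*B sequences; rows [0, B) are split back out as the
+        chosen outputs and rows [B, 2*B) as the rejected ones. The returned log-probability
+        tensors are summed over the non-masked response tokens of each sequence (shape (B,)) and
+        are consumed by dpo_loss, which, ignoring the reference_free branch, computes roughly:
+
+            pi_logratios = policy_chosen_logps - policy_rejected_logps
+            ref_logratios = reference_chosen_logps - reference_rejected_logps
+            losses = -F.logsigmoid(self.beta * (pi_logratios - ref_logratios))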
+ """ + concatenated_batch = self.concatenated_inputs(batch) + all_logits = model( + concatenated_batch["concatenated_input_ids"], + attention_mask=concatenated_batch["concatenated_attention_mask"], + ).logits.to(torch.float32) + all_logps = self._get_batch_logps( + all_logits, + concatenated_batch["concatenated_labels"], + average_log_prob=False, + ) + chosen_logps = all_logps[: batch["chosen_input_ids"].shape[0]] + rejected_logps = all_logps[batch["chosen_input_ids"].shape[0] :] + + chosen_logits = all_logits[: batch["chosen_input_ids"].shape[0]] + rejected_logits = all_logits[batch["chosen_input_ids"].shape[0] :] + return (chosen_logps, rejected_logps, chosen_logits, rejected_logits) + + def loss(self, batch: Dict[str, Union[List, torch.LongTensor]]): + stats = {} + + ( + policy_chosen_logps, + policy_rejected_logps, + policy_chosen_logits, + policy_rejected_logits, + ) = self.concatenated_forward(self.model, batch) + with torch.no_grad(): + ( + reference_chosen_logps, + reference_rejected_logps, + _, + _, + ) = self.concatenated_forward(self.ref_model, batch) + + losses, chosen_rewards, rejected_rewards = self.dpo_loss( + policy_chosen_logps, + policy_rejected_logps, + reference_chosen_logps, + reference_rejected_logps, + ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() + + stats["rewards/chosen"] = chosen_rewards.cpu().numpy().mean() + stats["rewards/rejected"] = rejected_rewards.cpu().numpy().mean() + stats["rewards/accuracies"] = reward_accuracies.cpu().numpy().mean() + stats["rewards/margins"] = (chosen_rewards - rejected_rewards).cpu().numpy().mean() + stats["logps/rejected"] = policy_rejected_logps.detach().cpu().numpy().mean() + stats["logps/chosen"] = policy_chosen_logps.detach().cpu().numpy().mean() + + stats["logits/rejected"] = policy_rejected_logits.detach().cpu().numpy().mean() + stats["logits/chosen"] = policy_chosen_logits.detach().cpu().numpy().mean() + + stats["loss"] = losses.detach().cpu().numpy().mean() + + return losses.mean(), stats + + def _prepare_deepspeed_zero3(self, model: nn.Module): + # Adapted from accelerate: + # https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473 + # TODO: figure out if any other parameters are needed to optimize inference + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + + # See DeepSpeed docs for definition of these parameters: https://deepspeed.readthedocs.io/en/latest/zero3.html + config_kwargs = { + "train_micro_batch_size_per_gpu": self.config.train.batch_size, + "train_batch_size": self.config.train.batch_size + * self.accelerator.state.num_processes + * deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"] + * self.accelerator.num_processes, + "zero_optimization": {"stage": 3, "offload_param": {"device": deepspeed_plugin.offload_param_device}}, + } + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model + + def prepare_learning(self): + train_dataloader = 
self.store.create_loader(self.config.train.batch_size) + eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) + + ( + self.model, + self.opt, + self.train_dataloader, + self.eval_dataloader, + ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) + + self.n_updates_per_batch = 1 + self.total_steps = self.config.train.epochs * len(self.train_dataloader) + self.total_steps = min(self.total_steps, self.config.train.total_steps) + + def make_experience(self, samples: Iterable[Iterable], seq_length: int, max_prompt_length: int): + preferences = [ + DPOStore.tokenize_preferences(sample, self.tokenizer, seq_length, max_prompt_length) for sample in samples + ] + self.store = DPOStore(preferences, self.tokenizer, self.label_pad_token_id, self.padding_value) diff --git a/trlx/trainer/accelerate_sft_trainer.py b/trlx/trainer/accelerate_sft_trainer.py index d5cbe3ea5..b76f50f14 100644 --- a/trlx/trainer/accelerate_sft_trainer.py +++ b/trlx/trainer/accelerate_sft_trainer.py @@ -87,7 +87,7 @@ def prepare_learning(self): self.total_steps = self.config.train.epochs * len(self.train_dataloader) self.total_steps = min(self.total_steps, self.config.train.total_steps) - def make_experience(self, samples, seq_length): + def make_experience(self, samples, seq_length, **kwargs): if isinstance(samples[0], str): self.store = PromptPipeline(samples, seq_length, self.tokenizer) else: diff --git a/trlx/trainer/nemo_sft_trainer.py b/trlx/trainer/nemo_sft_trainer.py index 7f25254f1..d7410b35c 100644 --- a/trlx/trainer/nemo_sft_trainer.py +++ b/trlx/trainer/nemo_sft_trainer.py @@ -133,7 +133,7 @@ def eval_collate(elems): torch.set_float32_matmul_precision("medium") self.trainer.fit(self.model) - def make_experience(self, samples, seq_length): + def make_experience(self, samples, seq_length, **kwargs): if isinstance(samples[0], str): self.store = PromptPipeline(samples, seq_length, self.tokenizer) else: diff --git a/trlx/trlx.py b/trlx/trlx.py index 7fbce94f4..ba75b5f54 100644 --- a/trlx/trlx.py +++ b/trlx/trlx.py @@ -64,7 +64,7 @@ def train( # noqa: C901 config = default_ppo_config() elif rewards: config = default_ilql_config() - else: + else: # Alternatively, could be `default_dpo_config()`. But, ignoring since passing `config` implicitly is deprecated config = default_sft_config() set_seed(config.train.seed) @@ -102,7 +102,7 @@ def train( # noqa: C901 if eval_prompts is None: eval_prompts = prompts[:batch_size] - # Offline training from the collected samples (e.g. SFT, ILQL) + # Offline training from the collected samples (e.g. 
SFT, ILQL, DPO) elif samples: if rewards is not None: if len(samples) != len(rewards): @@ -114,7 +114,8 @@ def train( # noqa: C901 if rewards is not None: trainer.make_experience(samples, rewards, config.train.seq_length) else: - trainer.make_experience(samples, config.train.seq_length) + # this should be abstracted for all trainers with **kwargs + trainer.make_experience(samples, config.train.seq_length, max_prompt_length) else: raise ValueError("Either `samples` or `reward_fn` should be given for training") diff --git a/trlx/utils/loading.py b/trlx/utils/loading.py index 9c7dccf76..3dc5a4c52 100644 --- a/trlx/utils/loading.py +++ b/trlx/utils/loading.py @@ -6,6 +6,7 @@ # Register load trainers via module import from trlx.trainer import _TRAINERS, register_trainer +from trlx.trainer.accelerate_dpo_trainer import AccelerateDPOTrainer from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py index 47688f553..e5de3d990 100644 --- a/trlx/utils/modeling.py +++ b/trlx/utils/modeling.py @@ -265,6 +265,22 @@ def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int): ) +def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float], dim: int = -1) -> torch.Tensor: + # From original TRL + if tensor.size(dim) >= length: + return tensor + else: + pad_size = list(tensor.shape) + pad_size[dim] = length - tensor.size(dim) + return torch.cat( + [ + tensor, + pad_value * torch.ones(*pad_size, dtype=tensor.dtype, device=tensor.device), + ], + dim=dim, + ) + + class RunningMoments: def __init__(self): """