From 05f48cb597f30d3868e233b6a351d53b398aada7 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Sun, 30 Nov 2025 00:45:20 +0530
Subject: [PATCH] feat: allow cp with trainer

Signed-off-by: Mehant Kammakomati
---
 src/transformers/trainer.py | 70 +++++++++++++++++++++++++++++++++++--
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 65331ba8322c..ea0092c44061 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2517,9 +2517,45 @@ def _inner_training_loop(
                     and self.accelerator.distributed_type != DistributedType.DEEPSPEED
                     else contextlib.nullcontext
                 )
-                with context():
-                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+                cp_context = contextlib.nullcontext()
+                if self.accelerator.parallelism_config and self.accelerator.parallelism_config.cp_enabled:
+                    cp_size_multiple = self.accelerator.parallelism_config.cp_size * 2
+                    max_length = (
+                        (inputs["input_ids"].shape[1] + (cp_size_multiple - 1)) // cp_size_multiple
+                    ) * cp_size_multiple
+                    inputs = self._pad_batch(inputs, max_length=max_length)
+                    cp_context = self.accelerator.maybe_context_parallel(
+                        buffers=[
+                            inputs["input_ids"],
+                            inputs["shift_labels"],
+                            inputs["labels"],
+                            inputs["position_ids"],
+                        ],
+                        buffer_seq_dims=[1, 1, 1, 1],
+                        no_restore_buffers={
+                            inputs["input_ids"],
+                            inputs["shift_labels"],
+                            inputs["labels"],
+                            inputs["position_ids"],
+                        },
+                    )
+                with context(), cp_context:
+                    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
+                if self.accelerator.parallelism_config and self.accelerator.parallelism_config.cp_enabled:
+                    loss_reduce_grp = (
+                        self.accelerator.torch_device_mesh["dp_cp"].get_group()
+                        if self.accelerator.parallelism_config.dp_cp_dim_names
+                        else None
+                    )
+                    if num_items_in_batch:
+                        dist.all_reduce(tr_loss_step, op=dist.ReduceOp.SUM, group=loss_reduce_grp)
+                        tr_loss_step = tr_loss_step / (
+                            getattr(self.accelerator.parallelism_config, "dp_replicate_size", 1)
+                            * getattr(self.accelerator.parallelism_config, "dp_shard_size", 1)
+                        )
+                    else:
+                        dist.all_reduce(tr_loss_step, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
 
                 if (
                     args.logging_nan_inf_filter
                     and not is_torch_xla_available()
@@ -5310,3 +5346,33 @@ def set_initial_training_values(
             len_dataloader,
             max_steps,
         )
+
+    def _pad_batch(self, batch, pad_token_id=0, label_pad_token_id=-100, max_length=0):
+        # helper function for CP to pad batches to a sequence length that is a multiple of cp_size * 2
+        def pad_and_truncate(sequence, pad_value, max_length):
+            sequence = sequence[:max_length]
+            return torch.cat(
+                [torch.full((max_length - len(sequence),), pad_value, dtype=torch.long).to(sequence.device), sequence]
+            )
+
+        input_ids = torch.stack(
+            [pad_and_truncate(example, pad_token_id, max_length) for example in batch["input_ids"]]
+        )
+        attention_mask = torch.stack([pad_and_truncate(example, 0, max_length) for example in batch["attention_mask"]])
+        labels = torch.stack(
+            [pad_and_truncate(example, label_pad_token_id, max_length) for example in batch["labels"]]
+        )
+
+        # prepare shift_labels for loss computation
+        shift_labels = torch.nn.functional.pad(labels, (0, 1), value=label_pad_token_id)
+        shift_labels = shift_labels[..., 1:].contiguous()
+        position_ids = (
+            torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
+        )
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels,
+            "shift_labels": shift_labels,
+            "position_ids": position_ids,
+        }
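
Reviewer notes (not part of the patch, do not feed to git am):

The rounding in the first hunk pads the sequence length up to the nearest
multiple of cp_size * 2; as I understand it, PyTorch's load-balanced
context-parallel attention shards each buffer into two chunks per CP rank
along the sequence dimension, hence the factor of 2. A minimal, standalone
sketch of that arithmetic (the function name round_up_to_cp_multiple is
mine, not from the patch):

    def round_up_to_cp_multiple(seq_len: int, cp_size: int) -> int:
        # Same expression as the patch: ceil-divide, then scale back up.
        multiple = cp_size * 2
        return ((seq_len + multiple - 1) // multiple) * multiple

    assert round_up_to_cp_multiple(1000, cp_size=4) == 1000  # already a multiple of 8
    assert round_up_to_cp_multiple(1001, cp_size=4) == 1008  # padded up by 7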
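
The loss handling after training_step compensates for each CP rank seeing
only a shard of the sequence. Below is a single-process sketch of the two
normalization paths; the real code performs this with dist.all_reduce over
the dp_cp group, and reduce_cp_loss / per_rank_losses are hypothetical
names standing in for that reduction:

    import torch

    def reduce_cp_loss(per_rank_losses, num_items_in_batch, dp_replicate_size=1, dp_shard_size=1):
        # Emulates the patch's all_reduce over the dp_cp group on one process.
        stacked = torch.stack(per_rank_losses)
        if num_items_in_batch:
            # Each per-rank loss is already normalized by the global token
            # count, so SUM recovers the full-batch loss; dividing by the
            # number of data-parallel workers undoes their double counting,
            # while CP ranks (shards of the same batch) are left summed.
            return stacked.sum() / (dp_replicate_size * dp_shard_size)
        # No token count available: plain average, mirroring ReduceOp.AVG.
        return stacked.mean()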
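
Finally, a toy batch through _pad_batch to make the helper's behavior
concrete (here trainer stands for any Trainer instance carrying this
patch; the expected values in the comments are worked out by hand from the
helper's definition):

    import torch

    batch = {
        "input_ids": torch.tensor([[5, 6, 7]]),
        "attention_mask": torch.tensor([[1, 1, 1]]),
        "labels": torch.tensor([[5, 6, 7]]),
    }
    padded = trainer._pad_batch(batch, max_length=8)  # e.g. cp_size=4 -> multiple of 8
    # input_ids      -> [[0, 0, 0, 0, 0, 5, 6, 7]]                  (left-padded)
    # attention_mask -> [[0, 0, 0, 0, 0, 1, 1, 1]]
    # labels         -> [[-100, -100, -100, -100, -100, 5, 6, 7]]
    # shift_labels   -> [[-100, -100, -100, -100, 5, 6, 7, -100]]   (labels shifted left by one)
    # position_ids   -> [[0, 1, 2, 3, 4, 5, 6, 7]]

Note that position_ids are assigned to padding positions as well, so with
left padding the first real token does not sit at position 0; whether that
is intended may be worth confirming in review.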