70 changes: 68 additions & 2 deletions src/transformers/trainer.py
@@ -2517,9 +2517,45 @@ def _inner_training_loop(
    and self.accelerator.distributed_type != DistributedType.DEEPSPEED
    else contextlib.nullcontext
)
with context():
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)

cp_context = contextlib.nullcontext()
if self.accelerator.parallelism_config and self.accelerator.parallelism_config.cp_enabled:
    cp_size_multiple = self.accelerator.parallelism_config.cp_size * 2
    max_length = (
        (inputs["input_ids"].shape[1] + (cp_size_multiple - 1)) // cp_size_multiple
    ) * cp_size_multiple
    inputs = self._pad_batch(inputs, max_length=max_length)
    cp_context = self.accelerator.maybe_context_parallel(
        buffers=[
            inputs["input_ids"],
            inputs["shift_labels"],
            inputs["labels"],
            inputs["position_ids"],
        ],
        buffer_seq_dims=[1, 1, 1, 1],
        no_restore_buffers={
            inputs["input_ids"],
            inputs["shift_labels"],
            inputs["labels"],
            inputs["position_ids"],
        },
    )
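For reference, a minimal standalone sketch of the max_length rounding above: the sequence length is rounded up to the next multiple of cp_size * 2, presumably so the load-balanced context-parallel sharding can split every sequence into 2 * cp_size equal chunks. The helper name round_up_to_cp_multiple and the sample lengths below are illustrative only, not part of this diff.

def round_up_to_cp_multiple(seq_len: int, cp_size: int) -> int:
    # Round seq_len up to the next multiple of cp_size * 2.
    multiple = cp_size * 2
    return ((seq_len + multiple - 1) // multiple) * multiple

assert round_up_to_cp_multiple(1021, cp_size=4) == 1024  # padded up to the next multiple of 8
assert round_up_to_cp_multiple(1024, cp_size=4) == 1024  # already aligned, left unchanged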
Comment on lines +2521 to +2542 (Member):

We have _prepare_context_parallel_inputs in training_step! If things need to be changed, it's better to do it there.

with context(), cp_context:
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
if self.accelerator.parallelism_config and self.accelerator.parallelism_config.cp_enabled:
    loss_reduce_grp = (
        self.accelerator.torch_device_mesh["dp_cp"].get_group()
        if self.accelerator.parallelism_config.dp_cp_dim_names
        else None
    )
    if num_items_in_batch:
        dist.all_reduce(tr_loss_step, op=dist.ReduceOp.SUM, group=loss_reduce_grp)
        tr_loss_step = tr_loss_step / (
            getattr(self.accelerator.parallelism_config, "dp_replicate_size", 1)
            * getattr(self.accelerator.parallelism_config, "dp_shard_size", 1)
        )
    else:
        dist.all_reduce(tr_loss_step, op=dist.ReduceOp.AVG, group=loss_reduce_grp)
if (
    args.logging_nan_inf_filter
    and not is_torch_xla_available()
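A toy arithmetic sketch of the SUM branch above, under the assumption that each CP rank returns the token-loss sum of its own sequence shard already divided by the global num_items_in_batch; all numbers below are made up for illustration.

cp_size, dp_replicate_size, dp_shard_size, num_items_in_batch = 2, 2, 1, 8
# One token-loss sum per CP shard, grouped by data-parallel replica (made-up values).
shard_token_loss_sums = [[1.0, 3.0], [2.0, 2.0]]
per_rank_loss = [s / num_items_in_batch for shards in shard_token_loss_sums for s in shards]
# dist.all_reduce(..., op=ReduceOp.SUM) over the dp_cp group adds every rank's value.
summed = sum(per_rank_loss)
tr_loss_step = summed / (dp_replicate_size * dp_shard_size)
print(tr_loss_step)  # 0.5 == mean over the 2 replicas of their full-batch losses (4/8 each)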
@@ -5310,3 +5346,33 @@ def set_initial_training_values(
        len_dataloader,
        max_steps,
    )

def _pad_batch(self, batch, pad_token_id=0, label_pad_token_id=-100, max_length=0):
    # helper function for CP to pad batches to sequence length multiple of cp_degree * 2
    def pad_and_truncate(sequence, pad_value, max_length):
        sequence = sequence[:max_length]
        # pad on the left, keeping the original tokens at the end of the sequence
        return torch.cat(
            [torch.full((max_length - len(sequence),), pad_value, dtype=torch.long).to(sequence.device), sequence]
        )

    input_ids = torch.stack(
        [pad_and_truncate(example, pad_token_id, max_length) for example in batch["input_ids"]]
    )
    attention_mask = torch.stack([pad_and_truncate(example, 0, max_length) for example in batch["attention_mask"]])
    labels = torch.stack(
        [pad_and_truncate(example, label_pad_token_id, max_length) for example in batch["labels"]]
    )

    # prepare shift_labels for loss computation
    shift_labels = torch.nn.functional.pad(labels, (0, 1), value=label_pad_token_id)
    shift_labels = shift_labels[..., 1:].contiguous()
    position_ids = (
        torch.cumsum(torch.ones(size=input_ids.size(), dtype=input_ids.dtype, device=input_ids.device), dim=1) - 1
    )
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "shift_labels": shift_labels,
        "position_ids": position_ids,
    }
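A minimal usage sketch of _pad_batch, assuming an already-constructed Trainer instance named trainer and a toy single-example batch; with cp_size=4 the 5-token sequence is left-padded up to max_length=8, the next multiple of cp_size * 2.

import torch

# `trainer` is assumed to be an already-initialized transformers.Trainer.
batch = {
    "input_ids": torch.tensor([[11, 12, 13, 14, 15]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1]]),
    "labels": torch.tensor([[11, 12, 13, 14, 15]]),
}
padded = trainer._pad_batch(batch, max_length=8)
print(padded["input_ids"])     # [[0, 0, 0, 11, 12, 13, 14, 15]] -- left-padded with pad_token_id=0
print(padded["shift_labels"])  # [[-100, -100, 11, 12, 13, 14, 15, -100]] -- next-token targets
print(padded["position_ids"])  # [[0, 1, 2, 3, 4, 5, 6, 7]]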