diff --git a/src/transformers/models/clap/modeling_clap.py b/src/transformers/models/clap/modeling_clap.py index 89ad2ec26a61..53d08089e869 100644 --- a/src/transformers/models/clap/modeling_clap.py +++ b/src/transformers/models/clap/modeling_clap.py @@ -1043,8 +1043,11 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l Returns: torch.Tensor """ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. - mask = input_ids.ne(padding_idx).int() - incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask + mask = input_ids.ne(padding_idx) + incremental_indices = torch.cumsum(mask, dim=1) + if past_key_values_length != 0: + incremental_indices = incremental_indices + past_key_values_length + incremental_indices = incremental_indices * mask return incremental_indices.long() + padding_idx