
PreTrainedTokenizerBase.pad() broken on tensordict.TensorDict #42370

@ligz08

Description


System Info

In src/transformers/tokenization_utils_base.py on line 3370:

encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}

The for key in encoded_inputs[0] part raises RuntimeError: generator raised StopIteration when encoded_inputs is a list of tensordict.TensorDict objects.

TensorDict.__iter__() is designed to iterate over the first batch dimension of its value tensors rather than over dict keys; iterating over the keys requires calling TensorDict.keys() explicitly.
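
A minimal sketch of the difference (assuming tensordict is installed; batch_size=[3] is only passed here to make the iteration visible):

import torch
from tensordict import TensorDict

td = TensorDict(
    {'input_ids': torch.tensor([9, 8, 7]), 'attention_mask': torch.tensor([1, 1, 1])},
    batch_size=[3],
)

# Iterating the TensorDict itself walks the first batch dimension and yields
# one sub-TensorDict per element, not key names.
print([type(item).__name__ for item in td])  # ['TensorDict', 'TensorDict', 'TensorDict']

# Key names are only exposed through .keys().
print(sorted(td.keys()))                     # ['attention_mask', 'input_ids']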

This bug was not present as of transformers 4.51.3, where the same line in tokenization_utils_base.py read:

encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}

so I believe the breaking change was introduced in #37283, where many explicit .keys() calls were removed.

It's up for debate whether TensorDict.__iter__() should behave more like a batch or like a dict, but from the transformers point of view it seems better to be safe -- to always call .keys() explicitly -- than sorry.
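
In the meantime, one possible user-side workaround (just a sketch, not an official transformers API; the helper name is made up) is to hand the collator plain dicts instead of TensorDicts:

def as_plain_dicts(batch):
    # Convert each TensorDict example into a regular dict so that the key
    # iteration inside tokenizer.pad() behaves the same as for plain dicts.
    return [{k: v for k, v in example.items()} for example in batch]

# e.g. collator(as_plain_dicts(batch)) with the reproduction below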

Who can help?

@cyyever @Cyrilvallez @ArthurZucker @Rocketknight1

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import AutoTokenizer, DataCollatorWithPadding
from tensordict import TensorDict
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = [
    TensorDict({'input_ids': torch.tensor([9,8,7]), 'attention_mask': torch.tensor([1,1,1])}),
    TensorDict({'input_ids': torch.tensor([6,5]), 'attention_mask': torch.tensor([1,1])}),
]
collator(batch)

Expected behavior

Expected:

{'input_ids': tensor([[9, 8, 7],
        [6, 5, 1]]), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 0]])}

Actual:

  File "/.../.venv/lib/python3.12/site-packages/transformers/data/data_collator.py", line 271, in __call__
    batch = pad_without_fast_tokenizer_warning(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../.venv/lib/python3.12/site-packages/transformers/data/data_collator.py", line 66, in pad_without_fast_tokenizer_warning
    padded = tokenizer.pad(*pad_args, **pad_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.../.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py", line 3370, in pad
    encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
                                                                                   ~~~~~~~~~~~~~~^^^
RuntimeError: generator raised StopIteration
