Description
System Info
In `src/transformers/tokenization_utils_base.py` on line 3370:

```python
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
```

`for key in encoded_inputs[0]` raises `RuntimeError: generator raised StopIteration` when `encoded_inputs` is a list of `tensordict.TensorDict` objects.
`TensorDict.__iter__()` is designed to iterate over the first batch dimension of its value tensors, rather than over dict keys. To iterate over keys, you need `TensorDict.keys()`.
This bug was not present as of transformers 4.51.3, where the same line in `tokenization_utils_base.py` reads:

```python
encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
```

so I believe the breaking change happened in #37283, where many `.keys()` calls were removed.
It's up for debate whether `TensorDict.__iter__()` should behave more like a batch or more like a dict, but from the transformers point of view it's better to be safe (always call `.keys()` explicitly) than sorry.
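The failure mode can be reproduced without transformers or tensordict. Below is a minimal sketch using a hypothetical stand-in class (`FakeTensorDict`, not the real tensordict API), assuming only that `TensorDict.__iter__` is a generator function that iterates batch dimensions rather than keys; raising `StopIteration` inside that generator is what PEP 479 converts to the observed `RuntimeError`:

```python
# FakeTensorDict is a hypothetical stand-in for tensordict.TensorDict,
# used only to illustrate why iterating it directly breaks.
class FakeTensorDict(dict):
    def __iter__(self):
        # Generator-style __iter__: a StopIteration raised inside a
        # generator body surfaces as RuntimeError (PEP 479).
        raise StopIteration("no batch dimension to iterate over")
        yield  # unreachable; its presence makes this a generator function

encoded_inputs = [
    FakeTensorDict(input_ids=[9, 8, 7], attention_mask=[1, 1, 1]),
    FakeTensorDict(input_ids=[6, 5], attention_mask=[1, 1]),
]

# Safe form (transformers 4.51.3): iterate keys explicitly.
safe = {key: [ex[key] for ex in encoded_inputs]
        for key in encoded_inputs[0].keys()}
print(safe["input_ids"])  # [[9, 8, 7], [6, 5]]

# Current form: falls back to __iter__ and fails.
try:
    {key: [ex[key] for ex in encoded_inputs] for key in encoded_inputs[0]}
except RuntimeError as e:
    print(e)  # generator raised StopIteration
```

Calling `.keys()` sidesteps `__iter__` entirely, which is why the pre-4.52 line worked for any mapping-like input.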
Who can help?
@cyyever @Cyrilvallez @ArthurZucker @Rocketknight1
Information
- The official example scripts
- My own modified scripts

Tasks

- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
```python
import torch
from transformers import AutoTokenizer, DataCollatorWithPadding
from tensordict import TensorDict

tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
collator = DataCollatorWithPadding(tokenizer=tokenizer)
batch = [
    TensorDict({'input_ids': torch.tensor([9, 8, 7]), 'attention_mask': torch.tensor([1, 1, 1])}),
    TensorDict({'input_ids': torch.tensor([6, 5]), 'attention_mask': torch.tensor([1, 1])}),
]
collator(batch)
```

Expected behavior
Expected:

```python
{'input_ids': tensor([[9, 8, 7],
        [6, 5, 1]]), 'attention_mask': tensor([[1, 1, 1],
        [1, 1, 0]])}
```
Actual:

```
File "/.../.venv/lib/python3.12/site-packages/transformers/data/data_collator.py", line 271, in __call__
    batch = pad_without_fast_tokenizer_warning(
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.../.venv/lib/python3.12/site-packages/transformers/data/data_collator.py", line 66, in pad_without_fast_tokenizer_warning
    padded = tokenizer.pad(*pad_args, **pad_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/.../.venv/lib/python3.12/site-packages/transformers/tokenization_utils_base.py", line 3370, in pad
    encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
                                                                                   ~~~~~~~~~~~~~~^^^
RuntimeError: generator raised StopIteration
```