4 changes: 3 additions & 1 deletion src/transformers/feature_extraction_sequence_utils.py
@@ -122,8 +122,10 @@ def pad(
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
         if isinstance(processed_features, (list, tuple)) and isinstance(processed_features[0], (dict, BatchFeature)):
+            # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement
+            # __iter__ differently than standard dicts (e.g., iterating over batch dimensions)
             processed_features = {
-                key: [example[key] for example in processed_features] for key in processed_features[0]
+                key: [example[key] for example in processed_features] for key in processed_features[0].keys()
             }
 
         # The model's main input name, usually `input_values`, has be passed for padding
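The motivation behind this `.keys()` call (and the identical changes below): a TensorDict with a non-empty batch size iterates over its batch dimension, not its keys, so `for key in mapping` and `for key in mapping.keys()` are not interchangeable. A minimal sketch, assuming `tensordict` is installed (the variable name `td` is illustrative only):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"input_ids": torch.tensor([[1, 2], [3, 4]]), "attention_mask": torch.ones(2, 2)},
    batch_size=[2],
)

print(list(td.keys()))  # ['input_ids', 'attention_mask']
print(next(iter(td)))   # a sub-TensorDict for the first batch element, not a key string

# A plain dict yields its keys when iterated, so `for key in mapping` happens to work;
# a TensorDict yields batch slices instead, which is why the comprehension now calls .keys().

With `batch_size=[]`, as in the new tests, plain iteration fails outright, which surfaced as `RuntimeError: generator raised StopIteration` inside the dict comprehension.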
4 changes: 3 additions & 1 deletion src/transformers/models/luke/tokenization_luke.py
@@ -1449,7 +1449,9 @@ def pad(
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
-            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+            # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement
+            # __iter__ differently than standard dicts (e.g., iterating over batch dimensions)
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has be passed for padding
         if self.model_input_names[0] not in encoded_inputs:
4 changes: 3 additions & 1 deletion src/transformers/models/mluke/tokenization_mluke.py
@@ -1287,7 +1287,9 @@ def pad(
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
-            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+            # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement
+            # __iter__ differently than standard dicts (e.g., iterating over batch dimensions)
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has be passed for padding
         if self.model_input_names[0] not in encoded_inputs:
26 changes: 26 additions & 0 deletions src/transformers/testing_utils.py
@@ -551,6 +551,32 @@ def require_torch(test_case):
     return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
 
 
+def is_tensordict_available():
+    """
+    Check if tensordict is available.
+
+    Returns:
+        bool: True if tensordict can be imported, False otherwise.
+    """
+    try:
+        import tensordict  # noqa: F401
+
+        return True
+    except ImportError:
+        return False
+
+
+def require_tensordict(test_case):
+    """
+    Decorator marking a test that requires tensordict.
+
+    These tests are skipped when tensordict isn't installed.
+    TensorDict is used for testing compatibility with dict-like objects
+    that implement __iter__ differently than standard dicts.
+    """
+    return unittest.skipUnless(is_tensordict_available(), "test requires tensordict")(test_case)
+
+
 def require_torch_greater_or_equal(version: str):
     """
     Decorator marking a test that requires PyTorch version >= `version`.
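A small usage sketch of the new decorator, mirroring how the test module added below applies it (the class and method names here are illustrative only):

import unittest

from transformers.testing_utils import require_tensordict, require_torch


@require_torch
@require_tensordict
class ExampleTensorDictTest(unittest.TestCase):
    def test_import(self):
        # Safe to import here: the decorator skips the test when tensordict is missing.
        from tensordict import TensorDict  # noqa: F401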
4 changes: 3 additions & 1 deletion src/transformers/tokenization_mistral_common.py
@@ -1205,7 +1205,9 @@ def pad(
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
         if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
-            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+            # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement
+            # __iter__ differently than standard dicts (e.g., iterating over batch dimensions)
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has been passed for padding
         if self.model_input_names[0] not in encoded_inputs:
10 changes: 8 additions & 2 deletions src/transformers/tokenization_utils_base.py
@@ -3476,8 +3476,14 @@ def pad(
 
         # If we have a list of dicts, let's convert it in a dict of lists
         # We do this to allow using this method as a collate_fn function in PyTorch Dataloader
-        if isinstance(encoded_inputs, (list, tuple)) and isinstance(encoded_inputs[0], Mapping):
-            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0]}
+        if (
+            isinstance(encoded_inputs, (list, tuple))
+            and len(encoded_inputs) > 0
+            and isinstance(encoded_inputs[0], Mapping)
+        ):
+            # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement
+            # __iter__ differently than standard dicts (e.g., iterating over batch dimensions)
+            encoded_inputs = {key: [example[key] for example in encoded_inputs] for key in encoded_inputs[0].keys()}
 
         # The model's main input name, usually `input_ids`, has been passed for padding
         if self.model_input_names[0] not in encoded_inputs:

Review thread on the added comment:

ligz08 (Nov 25, 2025):
I don't think the same comment needs to be repeated everywhere. A short inline one like "call .keys() explicitly to avoid issue #42370" is probably better for readability.

Author reply:
Have done the required changes, sir.
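For reference, a minimal end-to-end sketch of the DataLoader scenario the code comments describe, using `tokenizer.pad` as a `collate_fn` over TensorDict examples. The checkpoint and tensor sizes mirror the new tests below; the lambda collate is illustrative, not part of the diff:

import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

examples = [
    TensorDict({"input_ids": torch.tensor([9, 8, 7]), "attention_mask": torch.tensor([1, 1, 1])}, batch_size=[]),
    TensorDict({"input_ids": torch.tensor([6, 5]), "attention_mask": torch.tensor([1, 1])}, batch_size=[]),
]

# Before this change, the list-of-dicts -> dict-of-lists conversion inside `pad`
# iterated the first example directly and failed on TensorDict inputs.
loader = DataLoader(examples, batch_size=2, collate_fn=lambda batch: tokenizer.pad(batch, return_tensors="pt"))
padded = next(iter(loader))
print(padded["input_ids"].shape)  # expected torch.Size([2, 3]) after padding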
240 changes: 240 additions & 0 deletions tests/trainer/test_tensordict_compatibility.py
@@ -0,0 +1,240 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for TensorDict compatibility with data collators and tokenizers.
This module tests that dict-like objects (specifically TensorDict) work correctly
with transformers' padding and collation functionality. TensorDict implements
__iter__ to iterate over batch dimensions rather than keys, which requires
explicit .keys() calls in the codebase.
"""

import unittest

from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    DataCollatorWithPadding,
    is_torch_available,
)
from transformers.testing_utils import require_tensordict, require_torch


if is_torch_available():
    import torch


Review comment on the new test class:
Imo these tests could be added to tests/tokenization/test_tokenization_utils.py and/or tests/trainer/test_data_collator.py, rather than creating a new .py file.

@require_torch
@require_tensordict
class TensorDictCompatibilityTest(unittest.TestCase):
    """Test suite for TensorDict compatibility with data collators and tokenizers."""

    def setUp(self):
        """Set up test fixtures."""
        from tensordict import TensorDict

        self.TensorDict = TensorDict
        # Use a small, fast-loading tokenizer for tests
        self.tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    def test_data_collator_with_padding_tensordict(self):
        """
        Test that DataCollatorWithPadding works correctly with TensorDict inputs.
        This is a regression test for the issue where TensorDict.__iter__() iterates
        over batch dimensions instead of keys, causing RuntimeError: generator raised StopIteration.
        """
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        # Create batch with TensorDict objects of different lengths
        batch = [
            self.TensorDict(
                {"input_ids": torch.tensor([9, 8, 7]), "attention_mask": torch.tensor([1, 1, 1])},
                batch_size=[],
            ),
            self.TensorDict(
                {"input_ids": torch.tensor([6, 5]), "attention_mask": torch.tensor([1, 1])}, batch_size=[]
            ),
        ]

        # This should not raise RuntimeError
        result = collator(batch)

        # Verify the output is correctly padded (can be dict or Mapping like BatchEncoding)
        from collections.abc import Mapping

        self.assertIsInstance(result, Mapping)
        self.assertIn("input_ids", result)
        self.assertIn("attention_mask", result)

        # Check shapes - should be padded to max length (3)
        self.assertEqual(result["input_ids"].shape, torch.Size([2, 3]))
        self.assertEqual(result["attention_mask"].shape, torch.Size([2, 3]))

        # Check padding is correct
        expected_input_ids = torch.tensor([[9, 8, 7], [6, 5, self.tokenizer.pad_token_id]])
        expected_attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

        self.assertTrue(torch.equal(result["input_ids"], expected_input_ids))
        self.assertTrue(torch.equal(result["attention_mask"], expected_attention_mask))

    def test_data_collator_with_padding_tensordict_variable_lengths(self):
        """Test DataCollatorWithPadding with TensorDict inputs of highly variable lengths."""
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        batch = [
            self.TensorDict(
                {"input_ids": torch.tensor([1, 2, 3, 4, 5]), "attention_mask": torch.tensor([1, 1, 1, 1, 1])},
                batch_size=[],
            ),
            self.TensorDict({"input_ids": torch.tensor([6]), "attention_mask": torch.tensor([1])}, batch_size=[]),
            self.TensorDict(
                {"input_ids": torch.tensor([7, 8, 9]), "attention_mask": torch.tensor([1, 1, 1])}, batch_size=[]
            ),
        ]

        result = collator(batch)

        # Should be padded to max length (5)
        self.assertEqual(result["input_ids"].shape, torch.Size([3, 5]))
        self.assertEqual(result["attention_mask"].shape, torch.Size([3, 5]))

        # Check that shorter sequences are padded
        self.assertEqual(result["input_ids"][1, 1:].tolist(), [self.tokenizer.pad_token_id] * 4)
        self.assertEqual(result["attention_mask"][1, 1:].tolist(), [0] * 4)

    def test_data_collator_language_modeling_tensordict(self):
        """Test DataCollatorForLanguageModeling with TensorDict inputs."""
        collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)

        batch = [
            self.TensorDict(
                {"input_ids": torch.tensor([1, 2, 3, 4])},
                batch_size=[],
            ),
            self.TensorDict(
                {"input_ids": torch.tensor([5, 6])},
                batch_size=[],
            ),
        ]

        result = collator(batch)

        self.assertIn("input_ids", result)
        self.assertIn("labels", result)
        # Should be padded
        self.assertEqual(result["input_ids"].shape[0], 2)
        self.assertEqual(result["labels"].shape[0], 2)

    def test_tokenizer_pad_method_with_tensordict(self):
        """Test tokenizer.pad() method directly with TensorDict inputs."""
        # Create pre-tokenized inputs as TensorDict
        batch = [
            self.TensorDict(
                {
                    "input_ids": torch.tensor([101, 2023, 2003, 102]),
                    "attention_mask": torch.tensor([1, 1, 1, 1]),
                },
                batch_size=[],
            ),
            self.TensorDict(
                {
                    "input_ids": torch.tensor([101, 102]),
                    "attention_mask": torch.tensor([1, 1]),
                },
                batch_size=[],
            ),
        ]

        # This should not raise RuntimeError
        result = self.tokenizer.pad(batch, return_tensors="pt")

        self.assertIn("input_ids", result)
        self.assertIn("attention_mask", result)
        self.assertEqual(result["input_ids"].shape, torch.Size([2, 4]))

    def test_mixed_tensordict_and_dict_inputs(self):
        """Test that collator handles mixed TensorDict and regular dict inputs gracefully."""
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        # Mix of TensorDict and regular dict
        batch = [
            self.TensorDict(
                {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 1])}, batch_size=[]
            ),
            {"input_ids": torch.tensor([4, 5]), "attention_mask": torch.tensor([1, 1])},
        ]

        result = collator(batch)

        self.assertEqual(result["input_ids"].shape, torch.Size([2, 3]))
        self.assertEqual(result["attention_mask"].shape, torch.Size([2, 3]))

    def test_tensordict_with_additional_fields(self):
        """Test TensorDict inputs with additional fields beyond input_ids and attention_mask."""
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        batch = [
            self.TensorDict(
                {
                    "input_ids": torch.tensor([1, 2, 3]),
                    "attention_mask": torch.tensor([1, 1, 1]),
                    "token_type_ids": torch.tensor([0, 0, 0]),
                    "special_tokens_mask": torch.tensor([1, 0, 1]),
                },
                batch_size=[],
            ),
            self.TensorDict(
                {
                    "input_ids": torch.tensor([4, 5]),
                    "attention_mask": torch.tensor([1, 1]),
                    "token_type_ids": torch.tensor([0, 0]),
                    "special_tokens_mask": torch.tensor([1, 0]),
                },
                batch_size=[],
            ),
        ]

        result = collator(batch)

        # All fields should be present and padded
        self.assertIn("input_ids", result)
        self.assertIn("attention_mask", result)
        self.assertIn("token_type_ids", result)
        self.assertIn("special_tokens_mask", result)

        # Check all are padded to same length
        for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
            self.assertEqual(result[key].shape, torch.Size([2, 3]), f"Field {key} has wrong shape")

    def test_single_tensordict_input(self):
        """Test collator with a single TensorDict input."""
        collator = DataCollatorWithPadding(tokenizer=self.tokenizer)

        batch = [
            self.TensorDict(
                {"input_ids": torch.tensor([1, 2, 3]), "attention_mask": torch.tensor([1, 1, 1])}, batch_size=[]
            ),
        ]

        result = collator(batch)

        # Single input should not cause issues
        self.assertEqual(result["input_ids"].shape, torch.Size([1, 3]))
        self.assertEqual(result["attention_mask"].shape, torch.Size([1, 3]))


if __name__ == "__main__":
    unittest.main()