-
Notifications
You must be signed in to change notification settings - Fork 31.2k
fix: Restore explicit .keys() calls for TensorDict compatibility #42373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
fix: Restore explicit .keys() calls for TensorDict compatibility #42373
Conversation
Fixes issue where TensorDict objects cause RuntimeError: generator raised StopIteration when used with data collators and tokenization utilities. Problem: - TensorDict.__iter__() iterates over batch dimensions instead of dictionary keys - PR huggingface#37283 removed explicit .keys() calls, breaking TensorDict compatibility - Affected DataCollatorWithPadding, DataCollatorForLanguageModeling, and other collators Solution: - Restored explicit .keys() calls in 5 critical locations where dict-list conversion happens - Added len() check to handle empty batch edge case - Changes are backward compatible and generalize to all Mapping objects Files modified: - src/transformers/tokenization_utils_base.py: Fixed pad() method - src/transformers/tokenization_mistral_common.py: Fixed pad() method - src/transformers/feature_extraction_sequence_utils.py: Fixed pad() method - src/transformers/models/mluke/tokenization_mluke.py: Fixed pad() method - src/transformers/models/luke/tokenization_luke.py: Fixed pad() method Testing: - Added comprehensive test suite: tests/trainer/test_tensordict_compatibility.py - 7 test cases covering basic padding, variable lengths, mixed inputs, additional fields - Added @require_tensordict decorator and is_tensordict_available() in testing_utils.py - All existing tests pass (54/54 data collator tests, 2/2 padding tests) Impact: - Zero performance regression for standard dict usage - Restores functionality for TensorDict and other Mapping implementations - Fully backward compatible
|
|
||
| @require_torch | ||
| @require_tensordict | ||
| class TensorDictCompatibilityTest(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Imo these tests could be added to tests/tokenization/test_tokenization_utils.py and/or tests/trainer/test_data_collator.py, rather than creating a new .py file
| and len(encoded_inputs) > 0 | ||
| and isinstance(encoded_inputs[0], Mapping) | ||
| ): | ||
| # Use .keys() explicitly to support dict-like objects (e.g., TensorDict) that implement |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the same comment needs to be repeated everywhere. A short inline one like "call .keys() explicitly to avoid issue #42370" is probably better for readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
have done the required changes , sir
- Move TensorDict tests from standalone file to test_data_collator.py - Simplify comments from verbose explanation to short reference: 'Call .keys() explicitly to avoid issue huggingface#42370' - Delete tests/trainer/test_tensordict_compatibility.py (tests now in test_data_collator.py) Addresses feedback from @ligz08 in PR review
|
Hey, thank you for the PR! As far as I know, we aren't using |
In my humble opinion it's good to keep some minimal tests against
Agreed, I take back my previous comment. |
…bility Changed from 'avoid issue huggingface#42370' to 'for compatibility with TensorDict and other Mapping subclasses' so users don't need to look up the issue on GitHub to understand why .keys() is needed. Addresses maintainer feedback.
|
The failing test test_training_gradient_checkpointing_use_reentrant_false in Longformer is unrelated to this PR. Here's why: Scope of changes: This PR only modifies tokenization padding logic (adding explicit .keys() calls) in 5 tokenizer files. No changes to model architecture or gradient computation. Luke's gradient checkpointing tests are already skipped: The Luke model (one of the modified tokenizers) has all three gradient checkpointing tests explicitly disabled with @unittest.skip because "This architecture seem to not compute gradients properly when using GC" (see test_modeling_luke.py). No connection between changes and failure: Tokenization happens during data preprocessing |
|
@ligz08 @Rocketknight1 any more changes required ? |
|
Hmn, I'd still remove the tests and |
- Removed TensorDictCompatibilityTest class from test_data_collator.py - Removed is_tensordict_available() and require_tensordict() from testing_utils.py - TensorDict is not a CI dependency, so these tests would be skipped anyway - The .keys() fix for TensorDict compatibility remains in place
|
[For maintainers] Suggested jobs to run (before merge) run-slow: luke, mluke |
|
@Rocketknight1 I have made the required changes sir |
Fix: Restore explicit .keys() calls for TensorDict compatibility
Fixes: PreTrainedTokenizerBase.pad() broken on tensordict.TensorDict #42370
Open
What does this PR do?
This PR restores TensorDict compatibility that was broken by PR #37283, which removed explicit
.keys()calls from dictionary iteration patterns. The issue causedRuntimeError: generator raised StopIterationwhen using TensorDict objects with data collators and tokenization utilities.Problem
TensorDict objects implement
__iter__()to iterate over batch dimensions rather than dictionary keys (as standard dicts do). When PR #37283 replaced explicit.keys()calls with implicit iteration (e.g.,for key in dict[0]), it broke compatibility with TensorDict and similar Mapping implementations.Error encountered:
This affected:
DataCollatorWithPadding
DataCollatorForLanguageModeling
tokenizer.pad() method
Any code path that converts a list of dict-like objects into a dict of lists
Solution
Restored explicit .keys() calls in 5 critical locations where dict-list conversion happens during padding operations. This ensures compatibility with any object implementing the Mapping protocol, regardless of how it implements iter().
Pattern applied:
Also added a length check to prevent IndexError on empty batches:
if isinstance(encoded_inputs, (list, tuple)) and len(encoded_inputs) > 0 and isinstance(encoded_inputs[0], Mapping):Before submitting
This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
Did you read the contributor guideline, Pull Request section?
Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
Did you make sure to update the documentation with your changes? (Added inline comments explaining the fix)
Did you write any new necessary tests? (Yes - comprehensive test suite with 7 test cases)
Who can review?
@ArthurZucker @Rocketknight1 - This affects tokenizers and data collators @SunMarc - This affects the trainer data collators
Additional Context:
This is a minimal, surgical fix that only touches the specific lines where dict-list conversion happens. The fix is defensive and generalizes well to any Mapping implementation that might have non-standard iter() behavior.