Fix Incorrect Hidden State Extraction with Right Padding #26
xiuyuz wants to merge 1 commit into Gen-Verse:main
Conversation
This might be relevant to #25, discovered by @wonjun-chung.
Hi @xiuyuz, thanks for your great contribution to our LatentMAS work! We will review the code shortly and merge it properly :) We will also make sure to mention your extension to our work later in the README. Thanks again!
Pull request overview
This PR fixes a bug where hidden states were incorrectly extracted from padding tokens when using right-padded batches in latent reasoning steps. The fix uses the attention mask to identify the last non-padding token for each sequence in the batch, ensuring correct initialization of latent reasoning steps.
Key Changes:
- Modified `generate_latent_batch` to use attention mask-based indexing to find the last real token instead of always using `[:, -1, :]` (illustrated in the sketch below)
- Modified `generate_latent_batch_hidden_state` with the same padding-aware logic
- Added conditional logic to only apply the fix when `past_key_values` is `None` (initial call with potential padding)
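For intuition, here is a minimal, self-contained sketch of the behavior the fix targets. The tensors, shapes, and variable names below are illustrative toys, not the repository's actual code:

```python
import torch

# Toy batch: 2 sequences of length 3, hidden size 4; the second sequence
# is right-padded with one PAD token (mask value 0).
hidden = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])

# Buggy extraction: position -1 is the PAD slot for the second sequence.
buggy = hidden[:, -1, :]

# Padding-aware extraction: index of the last real token per sequence.
last_token_indices = attention_mask.sum(1) - 1      # tensor([2, 1])
batch_indices = torch.arange(hidden.shape[0])
fixed = hidden[batch_indices, last_token_indices, :]

print(buggy[1])  # hidden state at the PAD position (index 2)
print(fixed[1])  # hidden state at the last real token (index 1)
```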
```python
             last_token_indices = attention_mask.sum(1) - 1
             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
        else:
             last_hidden = outputs.hidden_states[-1][:, -1, :]
```
The indentation inside this else block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Line 402 should be indented with 12 spaces (8 base + 4 for the else block) rather than 13 spaces.
Suggested change:
```diff
              last_token_indices = attention_mask.sum(1) - 1
              batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
              last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
         else:
-             last_hidden = outputs.hidden_states[-1][:, -1, :]
+            last_hidden = outputs.hidden_states[-1][:, -1, :]
```
```python
            # Identify last token index
            # attention_mask (at this point, if past is None, it is just original mask)
            last_token_indices = attention_mask.sum(1) - 1
            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
```
For clarity and consistency with how attention_mask is handled, batch_indices should be created using self.device instead of input_ids.device. While input_ids.device should be the same as self.device (otherwise the model call would fail), using self.device explicitly makes the device management more clear and matches the pattern used for attention_mask at line 289.
Suggested change:
```diff
-            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
+            batch_indices = torch.arange(input_ids.shape[0], device=self.device)
```
```python
        if past_key_values is None:
            last_token_indices = attention_mask.sum(1) - 1
            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
```
For clarity and consistency with how attention_mask is handled, batch_indices should be created using self.HF_device instead of input_ids.device. While input_ids.device should be the same as self.HF_device (otherwise the model call would fail), using self.HF_device explicitly makes the device management more clear and matches the pattern used for attention_mask at line 375.
Suggested change:
```diff
-            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
+            batch_indices = torch.arange(input_ids.shape[0], device=self.HF_device)
```
```python
        if past_key_values is None:
            # Identify last token index
            # attention_mask (at this point, if past is None, it is just original mask)
            last_token_indices = attention_mask.sum(1) - 1
```
If a sequence is completely padded (all zeros in attention_mask), then attention_mask.sum(1) would be 0, making last_token_indices equal to -1. While -1 is technically a valid Python index (it refers to the last element), it would silently select the wrong position for completely padded sequences. Consider adding validation or handling for this edge case.
Suggested change:
```diff
-            last_token_indices = attention_mask.sum(1) - 1
+            token_counts = attention_mask.sum(1)
+            # Clamp to avoid negative indices when a sequence is fully padded (sum == 0).
+            last_token_indices = torch.clamp(token_counts - 1, min=0)
```
```python
            last_token_indices = attention_mask.sum(1) - 1
            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
            last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
```
If a sequence is completely padded (all zeros in attention_mask), then attention_mask.sum(1) would be 0, making last_token_indices equal to -1. While -1 is technically a valid Python index (it refers to the last element), it would silently select the wrong position for completely padded sequences. Consider adding validation or handling for this edge case.
Suggested change:
```diff
-            last_token_indices = attention_mask.sum(1) - 1
-            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
-            last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
+            seq_lengths = attention_mask.sum(1)
+            # Handle fully padded sequences (sum == 0) to avoid negative indices (-1)
+            if torch.any(seq_lengths == 0):
+                seq_lengths = seq_lengths.clone()
+                seq_lengths[seq_lengths == 0] = 1
+            last_token_indices = seq_lengths - 1
+            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
+            last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
```
```python
             # Identify last token index
             # attention_mask (at this point, if past is None, it is just original mask)
             last_token_indices = attention_mask.sum(1) - 1
             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
             e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
        else:
             # Assume no padding in incremental decoding steps
             e_t = outputs.hidden_states[0][:, -1, :]
             last_hidden = outputs.hidden_states[-1][:, -1, :]
```
The indentation inside this if block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 315-320 should be indented with 12 spaces (8 base + 4 for the if block) rather than 13 spaces.
Suggested change:
```diff
-             # Identify last token index
-             # attention_mask (at this point, if past is None, it is just original mask)
-             last_token_indices = attention_mask.sum(1) - 1
-             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
-             e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
-             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
+            # Identify last token index
+            # attention_mask (at this point, if past is None, it is just original mask)
+            last_token_indices = attention_mask.sum(1) - 1
+            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
+            e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
+            last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
         else:
              # Assume no padding in incremental decoding steps
              e_t = outputs.hidden_states[0][:, -1, :]
              last_hidden = outputs.hidden_states[-1][:, -1, :]
```
```python
             # Identify last token index
             # attention_mask (at this point, if past is None, it is just original mask)
             last_token_indices = attention_mask.sum(1) - 1
             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
             e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
        else:
             # Assume no padding in incremental decoding steps
             e_t = outputs.hidden_states[0][:, -1, :]
             last_hidden = outputs.hidden_states[-1][:, -1, :]
```
The indentation inside this else block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 323-324 should be indented with 12 spaces (8 base + 4 for the else block) rather than 13 spaces.
Suggested change:
```diff
              # Identify last token index
              # attention_mask (at this point, if past is None, it is just original mask)
              last_token_indices = attention_mask.sum(1) - 1
              batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
              e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
              last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
         else:
              # Assume no padding in incremental decoding steps
-             e_t = outputs.hidden_states[0][:, -1, :]
-             last_hidden = outputs.hidden_states[-1][:, -1, :]
+            e_t = outputs.hidden_states[0][:, -1, :]
+            last_hidden = outputs.hidden_states[-1][:, -1, :]
```
```python
             last_token_indices = attention_mask.sum(1) - 1
             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
        else:
             last_hidden = outputs.hidden_states[-1][:, -1, :]
```
The indentation inside this if block appears to use 5 spaces instead of the standard 4 spaces used throughout the rest of the codebase. Lines 398-400 should be indented with 12 spaces (8 base + 4 for the if block) rather than 13 spaces.
Suggested change:
```diff
-             last_token_indices = attention_mask.sum(1) - 1
-             batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
-             last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
+            last_token_indices = attention_mask.sum(1) - 1
+            batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
+            last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
         else:
              last_hidden = outputs.hidden_states[-1][:, -1, :]
```
Fix Incorrect Hidden State Extraction with Right Padding
Summary
This PR fixes a bug in `models.py` where the model was incorrectly extracting hidden states from padding tokens when using right-padded batches.

Bug Description
In `generate_latent_batch` and `generate_latent_batch_hidden_state`, the code used `[:, -1, :]` to extract the hidden states of the last token in the sequence:
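The extraction in question, as it appears in the diff hunks above:

```python
e_t = outputs.hidden_states[0][:, -1, :]
last_hidden = outputs.hidden_states[-1][:, -1, :]
```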
When using batch generation with right padding (e.g., `[token_A, token_B, PAD]`), index `-1` corresponds to the `PAD` token. As a result, the latent reasoning steps were being initialized with the hidden state of the padding token rather than the last actual token of the prompt.

Fix
The fix uses the `attention_mask` to determine the index of the last non-padding token for each sequence in the batch:
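The padding-aware indexing, taken from the diff hunks above:

```python
last_token_indices = attention_mask.sum(1) - 1
batch_indices = torch.arange(input_ids.shape[0], device=input_ids.device)
e_t = outputs.hidden_states[0][batch_indices, last_token_indices, :]
last_hidden = outputs.hidden_states[-1][batch_indices, last_token_indices, :]
```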
This ensures that `last_hidden` (and `e_t`) correctly corresponds to the last real token (e.g., `token_B`), regardless of the padding.