Fixes StaticCache Crashes #42467
base: main
```diff
@@ -331,16 +331,21 @@ def update(
         cache_position = (
             cache_position if cache_position is not None else torch.arange(key_states.shape[-2], device=self.device)
         )

+        k_out = self.keys
+        v_out = self.values
+        batch_size = key_states.shape[0]
+        if k_out.shape[0] != batch_size:
+            k_out = k_out[:batch_size]
+            v_out = v_out[:batch_size]
+
         # Update the cache
         try:
-            self.keys.index_copy_(2, cache_position, key_states)
-            self.values.index_copy_(2, cache_position, value_states)
+            k_out.index_copy_(2, cache_position, key_states)
+            v_out.index_copy_(2, cache_position, value_states)
         except NotImplementedError:
             # Fallback for devices like MPS where index_copy_ might not be supported.
-            self.keys[:, :, cache_position] = key_states
-            self.values[:, :, cache_position] = value_states
-        return self.keys, self.values
+            k_out[:, :, cache_position] = key_states
+            v_out[:, :, cache_position] = value_states
+        return k_out, v_out

     def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
         """Return the length and offset of the cache, used to generate the attention mask"""
```

**Comment on lines +336 to +339**

**Member:** Can you run benchmarks and check whether this creates excessive cudagraph breaks, as per the last comment from mobicham? In any case, a small benchmark run will be needed before merging the PR.

**Contributor (author):** I verified the fix with GPT2 using

**Member:** Can you add the bench script to the PR description please?

**Contributor (author):** Yep, I'll add it.
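To make the change above easier to follow outside the full transformers codebase, here is a minimal, self-contained sketch of the same idea (a hypothetical `TinyStaticCache`, not the actual transformers class): a statically preallocated KV cache whose `update()` slices the buffers down to the incoming batch size before writing, so batches smaller than `max_batch_size` no longer crash.

```python
# Minimal sketch of the batch-slicing fix; TinyStaticCache and its shapes
# are illustrative assumptions, not the transformers implementation.
import torch


class TinyStaticCache:
    def __init__(self, max_batch_size, num_heads, max_seq_len, head_dim):
        # Preallocate for the largest batch we expect to serve.
        self.keys = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)
        self.values = torch.zeros(max_batch_size, num_heads, max_seq_len, head_dim)

    def update(self, key_states, value_states, cache_position=None):
        if cache_position is None:
            cache_position = torch.arange(key_states.shape[-2])

        k_out, v_out = self.keys, self.values
        batch_size = key_states.shape[0]
        # Slice the static buffers down when the incoming batch is smaller.
        # Slicing returns a view, so writes still land in the preallocated storage.
        if k_out.shape[0] != batch_size:
            k_out = k_out[:batch_size]
            v_out = v_out[:batch_size]

        try:
            k_out.index_copy_(2, cache_position, key_states)
            v_out.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # Fallback for backends where index_copy_ is unsupported (e.g. MPS).
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states
        return k_out, v_out
```

With `max_batch_size = 8`, calling `update()` with a batch of 4 returns tensors of batch dimension 4 while leaving the unused rows of the static buffers untouched.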
**Reviewer:** I guess `k_out.shape[0] >= batch_size` is better.

**Reviewer:** When debugging the torch.compile stuff, can you check this: if there's no copy, I don't see why cudagraphs would break with Whisper. What error do you get exactly, by the way?

**Author:** Wait, should it be `<` or `>`, considering `k_out` will be larger than the `batch_size`? About the actual torch.compile error I'm getting: I'm trying `max_batch_size = 8` with batch sizes of 8, 4, 2, 1; on 4 it crashes with

**Reviewer:** Oh sorry, you're right, I meant `current_batch_size <= max_batch_size`. And I meant run it without torch.compile, just to see if it performs any copy.

**Author:** I see, I will try to debug next week too 👍
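The reviewer's question of whether the batch-size slice "performs any copy" can be checked directly: a basic-indexing slice in PyTorch returns a view sharing storage with the original tensor, which the data pointer confirms. A small sketch (illustrative shapes, not from the PR):

```python
# Check that slicing the static buffers yields a view, not a copy.
import torch

keys = torch.zeros(8, 2, 16, 4)   # hypothetical max_batch_size = 8
k_out = keys[:4]                  # slice down to the current batch size

# A view shares storage with the original tensor: same data pointer.
assert k_out.data_ptr() == keys.data_ptr()

# Writes through the view land in the preallocated buffer.
k_out.index_copy_(2, torch.arange(3), torch.ones(4, 2, 3, 4))
assert torch.equal(keys[:4, :, :3], torch.ones(4, 2, 3, 4))
```

Since no copy happens eagerly, any cudagraph breaks would have to come from how torch.compile handles the shape change across calls rather than from the slice itself.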