Fixes StaticCache Crashes #42467
Conversation
@zucchini-nlp This doesn't have the
zucchini-nlp
left a comment
Thanks for the quick fix @i3hz !
I think we need to allow `max_batch_size`, which will take precedence if available when lazily initializing the cache. Early cache initialization is currently used only in export, but we can allow users to re-use the cache across several generations with a max batch size. It would also require us to change a few places in generation imo
After that, we can verify that there are no unwanted graph breaks and run the bench
LMK if this makes sense and you need guidance
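A rough illustration of the precedence being suggested, in plain Python (the class and names here are hypothetical, invented for illustration — this is not the actual transformers API): when the cache is initialized with a `max_batch_size`, later updates with a smaller batch reuse the preallocated buffer instead of reallocating.

```python
# Hypothetical sketch; LazyStaticCacheSketch is a made-up name,
# not a real transformers class.
class LazyStaticCacheSketch:
    def __init__(self, max_cache_len, max_batch_size=None):
        self.max_cache_len = max_cache_len
        self.max_batch_size = max_batch_size  # takes precedence when set
        self.keys = None  # lazily allocated on first update

    def update(self, batch_size, head_dim=64):
        if self.keys is None:
            # max_batch_size, if given, wins over the first observed batch
            # size, so the cache can be reused across generate() calls
            alloc_bs = self.max_batch_size if self.max_batch_size is not None else batch_size
            self.keys = [[0.0] * head_dim for _ in range(alloc_bs)]
        # hand back only the slots for the current batch
        return self.keys[:batch_size]

cache = LazyStaticCacheSketch(max_cache_len=32, max_batch_size=8)
first = cache.update(batch_size=4)  # len(first) == 4, len(cache.keys) == 8
```

Here a later `cache.update(batch_size=2)` would slice the same 8-slot buffer rather than allocate a new one.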
src/transformers/cache_utils.py
Outdated
```python
if key_states.dtype != k_out.dtype:
    key_states = key_states.to(k_out.dtype)
if value_states.dtype != v_out.dtype:
    value_states = value_states.to(v_out.dtype)
```
should be handled by lazy init already, so not really needed imo
Alright I'll remove that
```python
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
    k_out = k_out[:batch_size]
    v_out = v_out[:batch_size]
```
can you run benchmarks and check if it creates excessive cudagraph breaks, as per the last comment from mobicham. In any case, a small benchmark run will be needed before merging the PR
I verified the fix with GPT2 using `torch.compile(mode='reduce-overhead', fullgraph=True)`. For some reason it keeps failing with Whisper models and I can't really figure out why.
can you add the bench script to the PR description pls?
Yep I'll add it
So basically I should add And also change line 1837 from `or cache_to_check.max_batch_size < batch_size`
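For what it's worth, the condition being discussed can be sanity-checked in isolation (the helper name below is hypothetical, just mirroring the quoted expression — it is not the actual generation-code function): the cache only needs reallocating when it is smaller than the incoming batch.

```python
# Hypothetical helper mirroring the quoted check from line 1837;
# not the real transformers function.
def needs_new_cache(cache_max_batch_size: int, batch_size: int) -> bool:
    # reallocate only when the preallocated cache is too small
    return cache_max_batch_size < batch_size

# an 8-slot cache can serve batches of 8, 4, 2, 1 without reallocation
print([needs_new_cache(8, b) for b in (8, 4, 2, 1)])  # [False, False, False, False]
```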
Yep, and a small test as well

hi @zucchini-nlp I'm still stuck on this. I've been testing with
```python
k_out = self.keys
v_out = self.values
batch_size = key_states.shape[0]
if k_out.shape[0] != batch_size:
```
I guess `k_out.shape[0] >= batch_size` is better
When debugging the torch.compile stuff, can you check this:

```python
assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"
assert v_out.data_ptr() == v_out[:batch_size].data_ptr(), "invalid v_out data copy()!"
```

If there's no copy, I don't see why Cudagraphs would break with Whisper.
What error do you get exactly btw?
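The no-copy property being asserted can also be demonstrated outside torch; here is an analogous check with NumPy (used purely for illustration): slicing the leading dimension from index 0 yields a view that starts at the same memory address as the parent buffer.

```python
import numpy as np

# stand-in for the preallocated k_out buffer: (max_batch, heads, seq, head_dim)
k_out = np.zeros((8, 6, 32, 64), dtype=np.float16)
sliced = k_out[:4]  # current batch of 4

assert sliced.base is k_out                     # a view, not a copy
assert sliced.ctypes.data == k_out.ctypes.data  # same starting address
```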
> I guess `k_out.shape[0] <= batch_size` is better

Wait, should it be `<` or `>` considering `k_out` will be larger than the batch_size?

> assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"

This check fails under torch.compile with:

```
Builtin `operator.*` comparison with constant `self` failed
Explanation: Failed to compare DataPtrVariable() with DataPtrVariable(), because DataPtrVariable() is not a Python constant or its mutation check fails.
```
About the actual torch.compile error I'm getting: I'm trying `max_batch_size = 8` with batch sizes `8, 4, 2, 1`. On 4 it crashes with:
```
Dynamo failed to run FX node with fake tensors: call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(s72, 6, 1, 64), dtype=torch.float16,
grad_fn=<TransposeBackward0>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
grad_fn=<Error>), FakeTensor(..., device='cuda:0', size=(8, 6, 32, 64), dtype=torch.float16,
grad_fn=<Error>)), **{'attn_mask': None, 'dropout_p': 0.0, 'scale': 1.0, 'is_causal': False}): got RuntimeError('Attempting to broadcast a dimension of length 8 at -2! Mismatching argument at index 1 had [8, 6]; but expected shape should be broadcastable to [s72, 6]')

from user code:
  File "/home/vedth/stuhdy/z.py", line 21, in decoder_forward
    out = model.model.decoder(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 865, in forward
    layer_outputs = decoder_layer(
  File "/home/vedth/stuhdy/transformers/src/transformers/modeling_layers.py", line 94, in __call__
    return super().__call__(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 501, in forward
    hidden_states, cross_attn_weights = self.encoder_attn(
  File "/home/vedth/stuhdy/transformers/src/transformers/models/whisper/modeling_whisper.py", line 347, in forward
    attn_output, attn_weights = attention_interface(
  File "/home/vedth/stuhdy/transformers/src/transformers/integrations/sdpa_attention.py", line 92, in sdpa_attention_forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
```
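The failure above comes down to SDPA's batch-dimension broadcasting: a query with a (symbolic) batch of 4 cannot broadcast against key/value tensors still sized for the full cache batch of 8. A tiny pure-Python sketch of that compatibility rule (the helper is hypothetical, not the PyTorch implementation):

```python
def sdpa_batch_compatible(q_shape, kv_shape):
    # scaled_dot_product_attention broadcasts leading dims; two batch sizes
    # are compatible only if they are equal or one of them is 1
    qb, kvb = q_shape[0], kv_shape[0]
    return qb == kvb or qb == 1 or kvb == 1

print(sdpa_batch_compatible((8, 6, 1, 64), (8, 6, 32, 64)))  # True
print(sdpa_batch_compatible((4, 6, 1, 64), (8, 6, 32, 64)))  # False: the crash above
```

This suggests the cross-attention key/value buffers are not being sliced down to the current batch the way the self-attention cache is.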
> Wait should it be `<` or `>` considering k_out will be larger than the batch_size

Oh sorry, you're right, I meant `current_batch_size <= max_batch_size`.

> assert k_out.data_ptr() == k_out[:batch_size].data_ptr(), "invalid k_out data copy()!"

I meant run it without torch.compile, just to see if it performs any copy.

> Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace

I see, I will try to debug next week too 👍
@i3hz I will also do some debugging next week

Thanks a lot
What does this PR do?
Fixes #42454
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@zucchini-nlp @Rocketknight1 @mobicham
Benchmarking script -