[CUDA] cuDNN backward attention #2762
I just noticed a behavior change in the Metal backend from this replacement. I think the change is good for performance and tests are passing, but I want to make sure I'm not missing anything?
Actually we added that in the first place because it's faster for training on Metal to use the unfused SDPA for both forward and backward (since the forward and backward can share some computation there).
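To illustrate the sharing mentioned above: an unfused SDPA forward materializes the softmax weights, so the backward pass can reuse them instead of recomputing the scores. A minimal NumPy sketch of that idea (function names, shapes, and the single-head layout here are illustrative assumptions, not MLX's actual API):

```python
import numpy as np

def sdpa_forward(q, k, v, scale):
    """Unfused SDPA forward: returns the output AND the attention
    weights `p`, which the backward pass can reuse."""
    s = (q @ k.T) * scale                   # attention scores
    s = s - s.max(axis=-1, keepdims=True)   # numerical stability
    p = np.exp(s)
    p = p / p.sum(axis=-1, keepdims=True)   # softmax weights
    return p @ v, p

def sdpa_backward(q, k, v, scale, p, dout):
    """Backward pass that reuses the saved softmax weights `p`
    instead of recomputing q @ k.T and the softmax."""
    dv = p.T @ dout
    dp = dout @ v.T
    # softmax Jacobian: ds = p * (dp - rowwise_sum(dp * p))
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))
    dq = (ds @ k) * scale
    dk = (ds.T @ q) * scale
    return dq, dk, dv
```

The trade-off is memory: keeping `p` around means storing the full (seq, seq) attention matrix between forward and backward.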
Thanks for the info, I updated the code to keep behavior unchanged for the Metal backend.
This PR uses cuDNN backward attention for `fast::ScaledDotProductAttention::vjp`. A `ScaledDotProductAttentionVJP` primitive is added, but it is only implemented in the CUDA backend. A `stats` output is generated from the forward attention op, which is required by the backward op.

For training a 0.6B model:
before:
RAM usage: 60511 MiB / 81920 MiB
after:
RAM usage: 52593 MiB / 81920 MiB
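For context on where the memory savings come from: flash-attention-style backward kernels (such as cuDNN's) save only a per-row logsumexp `stats` vector from the forward and recompute the softmax weights inside the backward, instead of storing the full (seq, seq) attention matrix. A hedged NumPy sketch of just the math (illustrative only; the real cuDNN kernels are fused and tiled, and the names here are assumptions):

```python
import numpy as np

def sdpa_forward_with_stats(q, k, v, scale):
    """Forward pass that returns, besides the output, a per-row
    logsumexp `stats` vector instead of the full attention matrix."""
    s = (q @ k.T) * scale
    m = s.max(axis=-1, keepdims=True)
    lse = m + np.log(np.exp(s - m).sum(axis=-1, keepdims=True))
    out = np.exp(s - lse) @ v               # softmax via saved stats
    return out, lse

def sdpa_backward_from_stats(q, k, v, scale, lse, dout):
    """Backward pass that reconstructs the softmax weights from the
    saved `stats`, so the forward never has to store them."""
    s = (q @ k.T) * scale
    p = np.exp(s - lse)                     # recomputed softmax weights
    dv = p.T @ dout
    dp = dout @ v.T
    ds = p * (dp - (dp * p).sum(axis=-1, keepdims=True))
    return (ds @ k) * scale, (ds.T @ q) * scale, dv
```

The extra work is one recomputation of `q @ k.T` in the backward, traded for not holding the attention matrix in memory between the two passes.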