
[CUDA] cuDNN backward attention#2762

Merged
zcbenz merged 6 commits into ml-explore:main from zcbenz:cudnn-sdpa-backward
Nov 18, 2025
Conversation

@zcbenz
Collaborator

@zcbenz zcbenz commented Nov 14, 2025

This PR uses cuDNN backward attention for fast::ScaledDotProductAttention::vjp.

  • A new ScaledDotProductAttentionVJP primitive is added, but it is only implemented in the CUDA backend.
  • For training, a stats output is generated from the forward attention op, which is required by the backward op.
  • The array mask has not been implemented yet, so in actual training the new code may not kick in.
  • There is some duplicate code, which I will clean up later together with the convolution code.
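For context on the "stats" output mentioned above: in flash-attention-style kernels (including cuDNN's fused attention), the forward pass saves the per-row logsumexp of the attention scores, and the backward pass recomputes the softmax from those stats instead of materializing the full attention matrix. A minimal NumPy sketch of that math (illustrative only, not the actual MLX or cuDNN code; single-head, unmasked, unbatched):

```python
import numpy as np

def sdpa_forward(q, k, v, scale):
    # Forward attention; besides the output, return the per-row
    # logsumexp "stats" that the backward pass needs.
    s = scale * (q @ k.T)                 # (L, S) attention scores
    lse = np.log(np.exp(s).sum(axis=-1))  # (L,) logsumexp stats
    p = np.exp(s - lse[:, None])          # softmax probabilities
    return p @ v, lse

def sdpa_backward(q, k, v, o, do, lse, scale):
    # Recompute the probabilities from the saved stats instead of
    # storing the (L, S) softmax matrix in the forward pass.
    s = scale * (q @ k.T)
    p = np.exp(s - lse[:, None])
    dv = p.T @ do
    dp = do @ v.T
    # Softmax backward: dS_i = P_i * (dP_i - <dO_i, O_i>)
    delta = (do * o).sum(axis=-1, keepdims=True)
    ds = p * (dp - delta)
    dq = scale * (ds @ k)
    dk = scale * (ds.T @ q)
    return dq, dk, dv
```

This is why only a small per-row vector has to be kept alive between forward and backward, which also explains the RAM savings in the numbers below.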

For training a 0.6B model:

before:

RAM usage: 60511MiB / 81920MiB

INFO:root:Model has 596049920 parameters.
INFO:root:step: 100, train_loss: 10.0955, grad_norm: 4.4649, its_per_sec: 2.1158, toks_per_sec: 17332.9950, tokens: 819200
INFO:root:step: 200, train_loss: 8.1339, grad_norm: 1.5591, its_per_sec: 2.6921, toks_per_sec: 22053.9309, tokens: 1638400
INFO:root:step: 300, train_loss: 7.6831, grad_norm: 1.5067, its_per_sec: 2.6923, toks_per_sec: 22055.2519, tokens: 2457600
INFO:root:step: 400, train_loss: 7.5040, grad_norm: 1.3949, its_per_sec: 2.6794, toks_per_sec: 21949.9022, tokens: 3276800
INFO:root:step: 500, train_loss: 7.2268, grad_norm: 1.4708, its_per_sec: 2.6820, toks_per_sec: 21970.5679, tokens: 4096000

after:

RAM usage: 52593MiB / 81920MiB

INFO:root:Model has 596049920 parameters.
INFO:root:step: 100, train_loss: 10.0885, grad_norm: 4.4325, its_per_sec: 3.0472, toks_per_sec: 24962.3116, tokens: 819200
INFO:root:step: 200, train_loss: 8.1385, grad_norm: 1.6766, its_per_sec: 3.4142, toks_per_sec: 27968.7873, tokens: 1638400
INFO:root:step: 300, train_loss: 7.6979, grad_norm: 2.3700, its_per_sec: 3.4130, toks_per_sec: 27959.5179, tokens: 2457600
INFO:root:step: 400, train_loss: 7.5094, grad_norm: 1.4560, its_per_sec: 3.3883, toks_per_sec: 27756.6980, tokens: 3276800
INFO:root:step: 500, train_loss: 7.2361, grad_norm: 1.8131, its_per_sec: 3.3830, toks_per_sec: 27713.4482, tokens: 4096000

@zcbenz zcbenz force-pushed the cudnn-sdpa-backward branch from 6a22b32 to 8f2b0c5 Compare November 18, 2025 05:21
Member

@awni awni left a comment


Very nice, LGTM!

@zcbenz zcbenz force-pushed the cudnn-sdpa-backward branch from 05b01fa to aa048f5 Compare November 18, 2025 07:25
@zcbenz
Collaborator Author

zcbenz commented Nov 18, 2025

I just noticed a behavior change in the Metal backend from replacing detail::in_grad_tracing() with output_logsumexp: for training, fast sdpa was falling back to unfused ops for both the forward and backward passes, but with this change it would use the fused sdpa for the forward pass.

I think the change is good for performance and tests are passing, but I want to make sure I'm not missing anything?

@awni
Member

awni commented Nov 18, 2025

I think the change is good for performance and tests are passing, but I want to make sure I'm not missing anything?

Actually we added that in the first place because it's faster for training on Metal to use the unfused SDPA for both forward and backward (since the forward and backward can share some computation there).

@zcbenz zcbenz force-pushed the cudnn-sdpa-backward branch from 0824fe4 to 5160d4c Compare November 18, 2025 21:57
@zcbenz
Collaborator Author

zcbenz commented Nov 18, 2025

Thanks for the info! I updated the code to keep the behavior unchanged for the Metal backend.
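The resulting dispatch, as discussed above, can be summarized in a small sketch (hypothetical function and parameter names, not the actual MLX code):

```python
def use_fused_sdpa(backend: str, grad_tracing: bool, has_sdpa_vjp: bool) -> bool:
    """Hypothetical summary of the fused-vs-unfused SDPA decision."""
    if not grad_tracing:
        # Inference keeps using the fused kernel when otherwise eligible.
        return True
    # During gradient tracing, Metal stays on the unfused path: its
    # forward and backward share computation, so unfused is faster for
    # training there. CUDA takes the fused cuDNN path when the backward
    # (vjp) implementation is available.
    return backend == "cuda" and has_sdpa_vjp
```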

@zcbenz zcbenz merged commit 6f35017 into ml-explore:main Nov 18, 2025
10 checks passed
@zcbenz zcbenz deleted the cudnn-sdpa-backward branch November 18, 2025 23:13
