
Conversation


@timt51 timt51 commented Nov 18, 2022

DO NOT LAND ON MAIN BRANCH

Implementation of causal prefix mask for cross attention (see Dao-AILab#20 (comment) for more info).
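For reference, here is a minimal PyTorch sketch of what a prefix-causal mask can look like for cross attention, assuming the convention that every query may attend to the first `prefix_len` key positions unconditionally and to later key positions only causally. The exact alignment convention in this PR's kernel may differ, and `prefix_causal_mask` is an illustrative name, not part of the library.

```python
# Minimal sketch of a prefix-causal mask for cross attention (illustrative,
# not the kernel's implementation). Assumption: query i may attend to the
# first `prefix_len` keys unconditionally, and to key j >= prefix_len only
# if j - prefix_len <= i.
import torch

def prefix_causal_mask(seqlen_q: int, seqlen_k: int, prefix_len: int) -> torch.Tensor:
    """Boolean mask of shape (seqlen_q, seqlen_k); True = attention allowed."""
    q_idx = torch.arange(seqlen_q).unsqueeze(1)   # (seqlen_q, 1)
    k_idx = torch.arange(seqlen_k).unsqueeze(0)   # (1, seqlen_k)
    in_prefix = k_idx < prefix_len                # fully visible prefix
    causal_part = (k_idx - prefix_len) <= q_idx   # causal beyond the prefix
    return in_prefix | causal_part

# Example: apply the mask to raw attention scores before the softmax.
q, k = torch.randn(4, 8, 64), torch.randn(4, 10, 64)   # (batch, seqlen, head_dim)
scores = q @ k.transpose(-1, -2) / 64 ** 0.5
mask = prefix_causal_mask(q.shape[1], k.shape[1], prefix_len=3)
attn = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
```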

The original FlashAttention tests have been partially adjusted to account for the new causal prefix masking scheme. The modified tests should correctly verify that the output out of flash_attn_unpadded_*_func is correct and that the gradients dq, dk, and dv are correct.
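As a rough illustration of the kind of check the adjusted tests perform (not the actual test code; `attention_ref` is a stand-in name, and the real tests use dtype-appropriate tolerances), the outputs and gradients can be compared against a dense PyTorch reference under the same prefix-causal mask:

```python
# Illustrative only: a dense PyTorch reference for the prefix-causal scheme,
# used to produce reference values for out, dq, dk, and dv. `attention_ref`
# is a stand-in name, not the helper from the test file.
import torch

def attention_ref(q, k, v, mask):
    scores = q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5
    probs = torch.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)
    return probs @ v

torch.manual_seed(0)
q = torch.randn(4, 8, 64, requires_grad=True)   # (batch, seqlen_q, head_dim)
k = torch.randn(4, 10, 64, requires_grad=True)  # (batch, seqlen_k, head_dim)
v = torch.randn(4, 10, 64, requires_grad=True)

# Prefix-causal mask with prefix_len = 3 (same convention as the sketch above).
q_idx = torch.arange(8).unsqueeze(1)
k_idx = torch.arange(10).unsqueeze(0)
mask = (k_idx < 3) | ((k_idx - 3) <= q_idx)

out_ref = attention_ref(q, k, v, mask)
out_ref.sum().backward()
dq_ref, dk_ref, dv_ref = q.grad, k.grad, v.grad
# The real tests compare such reference values against out, dq, dk, dv
# from flash_attn_unpadded_*_func, with dtype-appropriate tolerances.
```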

The tests have not been adjusted to properly check the output S_dmask (which contains information about the attention values and dropout), because doing so requires figuring out the format of S_dmask (a non-standard format; see convert_flash_attn_S_to_softmax in the test file). This means we cannot be sure (1) whether the returned attention values are correct, and (2) whether causal prefix masking works with dropout. My guess is that it does, but until the S_dmask format is worked out this has not been verified.

There may also be an effect on performance: the backward pass appears to take somewhat longer, though it is hard to say definitively.
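To make that observation measurable, a simple CUDA-event timing harness along these lines could be used. `attn_fn` is a placeholder for whichever flash_attn_unpadded_*_func call (and arguments) is being measured, and is assumed here to return a single tensor.

```python
# Sketch of a forward/backward timing harness using CUDA events. `attn_fn`
# is a placeholder for the attention call being benchmarked and is assumed
# to return a single output tensor on which backward() can be called.
import torch

def time_fwd_bwd(attn_fn, *inputs, iters=30):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    fwd_ms = bwd_ms = 0.0
    for _ in range(iters):
        start.record()
        out = attn_fn(*inputs)
        end.record()
        torch.cuda.synchronize()
        fwd_ms += start.elapsed_time(end)

        grad = torch.randn_like(out)
        start.record()
        out.backward(grad)
        end.record()
        torch.cuda.synchronize()
        bwd_ms += start.elapsed_time(end)
    return fwd_ms / iters, bwd_ms / iters

# Usage (hypothetical): compare the averaged timings with and without the
# causal prefix mask enabled to quantify any backward-pass slowdown.
```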

@timt51 timt51 marked this pull request as draft November 18, 2022 02:35
@timt51 timt51 linked an issue Nov 18, 2022 that may be closed by this pull request
Development

Successfully merging this pull request may close these issues.

Release causal prefix flashattn
