Use Triton to implement a high-performance FlashAttention-2 kernel that balances ease of use and speed for users without deep CUDA knowledge.
Sub-Issues:
- Understand Triton Kernel Design for Attention
  - Study Triton's documentation and explore how it handles tiling, memory access, and parallelism for attention mechanisms.
  - Document how Triton improves over manual CUDA implementations in certain cases.
- Develop Triton Kernel for FlashAttention-2 Forward Pass
  - Implement the forward pass of FlashAttention-2 in Triton, taking advantage of Triton's optimized memory management techniques.
  - Ensure efficient work partitioning across warps, reducing the number of memory reads and writes.
- Develop Triton Kernel for FlashAttention-2 Backward Pass
  - Implement the backward pass in Triton, supporting recomputation of intermediate matrices for memory-efficient backpropagation.
- Integrate Triton Kernel with PyTorch
  - Use Triton's Python bindings to expose the Triton-based FlashAttention-2 kernel through the high-level PyTorch API.
  - Ensure seamless switching between the Triton and CUDA implementations for performance comparison.
- Test and Benchmark Triton Kernel
  - Write tests to validate the Triton implementation's correctness.
  - Compare the performance of the Triton implementation against the CUDA and native PyTorch implementations.
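For context on the forward-pass sub-issue: the core of FlashAttention is the tiled "online softmax," which processes key/value blocks one at a time while carrying a running row-wise max and normalizer, so the full N×N score matrix is never materialized. A minimal NumPy reference of that recurrence (not the Triton kernel itself, and all names here are illustrative) might look like:

```python
import numpy as np

def flash_attention_forward_ref(q, k, v, block_size=16):
    """NumPy reference of the FlashAttention-style tiled forward pass.

    Iterates over key/value blocks, maintaining a running max `m`,
    normalizer `l`, and unnormalized output accumulator `acc`, so the
    full attention matrix is never formed at once.
    """
    n, d = q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full(n, -np.inf)   # running row-wise max of scores
    l = np.zeros(n)           # running softmax normalizer
    acc = np.zeros((n, d))    # unnormalized output accumulator
    for start in range(0, k.shape[0], block_size):
        kb = k[start:start + block_size]
        vb = v[start:start + block_size]
        s = (q @ kb.T) * scale                # scores for this block
        m_new = np.maximum(m, s.max(axis=1))
        alpha = np.exp(m - m_new)             # rescale factor for old stats
        p = np.exp(s - m_new[:, None])        # block-local exponentials
        l = l * alpha + p.sum(axis=1)
        acc = acc * alpha[:, None] + p @ vb
        m = m_new
    return acc / l[:, None]
```

A real Triton kernel maps the outer loop over query blocks to `tl.program_id` and keeps `m`, `l`, and `acc` in registers, but the update rule per key/value block is the same recurrence shown here.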
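The integration sub-issue calls for seamless switching between the Triton and CUDA implementations. One simple pattern is a backend registry behind a single entry point; this is only a sketch, and `triton`/`cuda` here are dictionary keys standing in for the real kernel launchers:

```python
# Hypothetical backend-dispatch sketch; the registered callables are
# placeholders for the actual Triton and CUDA attention entry points.
_BACKENDS = {}

def register_backend(name, fn):
    """Register an attention implementation under a backend name."""
    _BACKENDS[name] = fn

def attention(q, k, v, backend="triton"):
    """Dispatch to the chosen backend; raises on unknown names."""
    if backend not in _BACKENDS:
        raise ValueError(f"unknown backend {backend!r}; have {sorted(_BACKENDS)}")
    return _BACKENDS[backend](q, k, v)

# Placeholder implementations used only for illustration.
register_backend("triton", lambda q, k, v: ("triton", (q, k, v)))
register_backend("cuda", lambda q, k, v: ("cuda", (q, k, v)))
```

With both backends behind one call signature, the benchmarking sub-issue reduces to timing `attention(..., backend=name)` for each registered name on identical inputs.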