
Implement FlashAttention v2 in Triton #4

@debashishc

Description

Use Triton to implement a high-performance FlashAttention-2 kernel that balances ease of use and speed for users without deep CUDA knowledge.

Sub-Issues:

  1. Understand Triton Kernel Design for Attention
  • Study Triton’s documentation and explore how it handles tiling, memory access, and parallelism for attention mechanisms.
  • Document how Triton improves over manual CUDA implementations in certain cases.
  2. Develop Triton Kernel for FlashAttention-2 Forward Pass
  • Implement the forward pass of FlashAttention-2 in Triton, taking advantage of Triton’s optimized memory management techniques.
  • Ensure efficient work partitioning across warps, reducing the number of memory reads and writes.
  3. Develop Triton Kernel for FlashAttention-2 Backward Pass
  • Implement the backward pass using Triton, ensuring that it supports recomputation of intermediate matrices and memory-efficient backpropagation.
  4. Integrate Triton Kernel with PyTorch
  • Use Triton’s Python bindings to integrate the Triton-based FlashAttention-2 kernel into the high-level PyTorch API.
  • Ensure seamless switching between Triton and CUDA implementations for performance comparison.
  5. Test and Benchmark Triton Kernel
  • Write tests to validate the Triton implementation’s correctness.
  • Compare the performance of the Triton implementation against CUDA and native PyTorch implementations.
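The forward-pass sub-issue hinges on the online-softmax recurrence: scores are computed tile by tile, and a running row-wise maximum and normalizer let each tile's contribution be folded into the output without ever materializing the full attention matrix. A minimal NumPy sketch of that recurrence (a CPU reference useful for testing, not the Triton kernel itself; names and block size are illustrative):

```python
import numpy as np

def flash_attention_forward(Q, K, V, block_size=32):
    """Tiled attention forward pass using the online-softmax
    recurrence from FlashAttention-2 (NumPy reference sketch)."""
    n, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros_like(Q)
    m = np.full(n, -np.inf)   # running row-wise max of scores
    l = np.zeros(n)           # running softmax normalizer
    for start in range(0, n, block_size):
        Kb = K[start:start + block_size]
        Vb = V[start:start + block_size]
        S = Q @ Kb.T * scale                  # score tile, shape (n, block)
        m_new = np.maximum(m, S.max(axis=1))  # updated running max
        alpha = np.exp(m - m_new)             # rescale factor for old state
        P = np.exp(S - m_new[:, None])        # unnormalized tile probabilities
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ Vb
        m = m_new
    return O / l[:, None]                     # normalize once at the end
```

Because the normalizer is finalized only after the last tile, intermediate outputs stay unnormalized, which is exactly what lets the Triton kernel keep the whole state in registers/SRAM per query block.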
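The backward-pass sub-issue rests on recomputation: instead of storing the n × n probability matrix from the forward pass, it is rebuilt from Q and K during backpropagation, trading FLOPs for memory. A NumPy reference of that recomputation-based backward step (illustrative and untiled; the real kernel would process it block by block):

```python
import numpy as np

def flash_attention_backward(Q, K, V, dO):
    """Attention backward pass that recomputes the probability
    matrix rather than reading it from memory (NumPy reference
    sketch of the memory-efficient backprop idea)."""
    d = Q.shape[1]
    scale = 1.0 / np.sqrt(d)
    # Recompute scores and probabilities instead of storing them.
    S = Q @ K.T * scale
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    dV = P.T @ dO
    dP = dO @ V.T
    # Row-wise softmax Jacobian: dS = P * (dP - rowsum(dP * P)).
    dS = P * (dP - np.sum(dP * P, axis=1, keepdims=True))
    dQ = dS @ K * scale
    dK = dS.T @ Q * scale
    return dQ, dK, dV
```

A finite-difference check against this reference is a useful correctness gate before porting the same algebra into Triton.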
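For the test-and-benchmark sub-issue, a small harness can time candidate implementations on identical inputs. On GPU the comparison would use `triton.testing.do_bench` or CUDA events; the wall-clock sketch below (illustrative names, naive baseline only) just conveys the shape of the harness:

```python
import timeit
import numpy as np

def naive_attention(Q, K, V):
    """Baseline that materializes the full n x n score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    return P @ V

def bench(fn, *args, repeats=5):
    """Median wall-clock seconds for fn(*args) over several runs."""
    times = timeit.repeat(lambda: fn(*args), number=1, repeat=repeats)
    return sorted(times)[len(times) // 2]
```

The same harness can wrap the Triton-backed function and a native PyTorch `scaled_dot_product_attention` call once the integration sub-issue lands, keeping inputs and dtypes identical across all three.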
