From cfda8f7cdf28937bbaaedbb692d4fc20720e5ce5 Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Fri, 14 Nov 2025 06:57:02 +0000 Subject: [PATCH 1/2] Add local mask for varlen --- flash-attn2/flash_attn_xpu/flash_api.cpp | 15 ++++-- .../src/kernel/varlen_kernel.hpp | 10 ++-- flash-attn2/flash_attn_xpu/src/varlen.hpp | 46 ++++++++++++++----- flash-attn2/tests/test_flash_attn.py | 7 --- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index 153b9a2..39dde41 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -170,14 +170,19 @@ mha_varlen_fwd( } else { out_padded = torch::zeros_like(q_padded); } - + + bool is_local = (window_size_left != -1) | (window_size_right != -1); + q_padded = ensure_contiguous(q_padded); k_padded = ensure_contiguous(k_padded); v_padded = ensure_contiguous(v_padded); - cutlass::flash_attention::varlen::cutlass_varlen_impl(q_padded, k_padded, v_padded, out_padded, block_table_, - cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, max_seqlen_k, - softmax_scale, is_causal); + cutlass::flash_attention::varlen::cutlass_varlen_impl( + q_padded, k_padded, v_padded, out_padded, block_table_, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + softmax_scale, + window_size_left, window_size_right, + is_causal, is_local); // Remove padding from output at::Tensor out = out_padded; diff --git a/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp b/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp index 1de8aa8..596f25e 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/varlen_kernel.hpp @@ -111,7 +111,7 @@ class FMHAPrefillChunk { static constexpr bool CausalMask = CollectiveMainloop::CausalMask; static constexpr bool LocalMask = CollectiveMainloop::LocalMask; - static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); + // static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); static constexpr bool PagedKV = CollectiveMainloop::PagedKV; static constexpr int SubgroupSize = @@ -464,15 +464,15 @@ class FMHAPrefillChunk { CUTLASS_PRAGMA_UNROLL for (int m = 0; m < FragsM; m++) { // 2 int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv_cache - seq_len_qo; CUTLASS_PRAGMA_UNROLL for (int row = 0; row < Vec; row++) { // 8 bool left_mask = - col_idx < cute::max(0, row + row_idx + seq_len_kv_cache - - mainloop_params.window_left); + col_idx < cute::max(0, + row + row_idx + col_ref - mainloop_params.window_left); bool right_mask = col_idx > cute::min(seq_len_kv_cache, - row + row_idx + seq_len_kv_cache + - mainloop_params.window_right); + row + row_idx + col_ref + mainloop_params.window_right); if (left_mask || right_mask) { tSr(row, m, n) = ElementAccumulator{-INFINITY}; } diff --git a/flash-attn2/flash_attn_xpu/src/varlen.hpp b/flash-attn2/flash_attn_xpu/src/varlen.hpp index 7729522..ffe3a73 100644 --- a/flash-attn2/flash_attn_xpu/src/varlen.hpp +++ b/flash-attn2/flash_attn_xpu/src/varlen.hpp @@ -40,6 +40,9 @@ struct chunk_prefill_args_t { int block_size; bool is_causal; bool use_paged_kv; + int window_size_left; + int window_size_right; + bool is_local; }; template @@ -117,7 +120,8 @@ struct KernelLauncher { reinterpret_cast(args.key), stride_K_cache, reinterpret_cast(args.value), stride_V_cache, static_cast(args.block_table), args.block_size, - args.max_blocks_per_seq, args.total_seqlen_k, -1, -1}, + 
args.max_blocks_per_seq, args.total_seqlen_k, + args.window_size_left, args.window_size_right}, {args.sm_scale}, {reinterpret_cast(args.out), stride_O}, hw_info}; @@ -241,23 +245,33 @@ struct FMHAKernel { } static void dispatch(const chunk_prefill_args_t& args) { + #define DISPATCH_LOCAL(PagedKV, Causal) \ + if (args.is_local) { \ + run(args); \ + } else { \ + run(args); \ + } + if (args.use_paged_kv) { if (args.is_causal) { - run(args); + DISPATCH_LOCAL(true, true) + } else { - run(args); + DISPATCH_LOCAL(true, false) + } } else { if (args.is_causal) { - run(args); + DISPATCH_LOCAL(false, true) + } else { - run(args); + DISPATCH_LOCAL(false, false) + } } + #undef DISPATCH_LOCAL } }; @@ -313,7 +327,9 @@ void cutlass_varlen_impl( const at::Tensor& value_cache, at::Tensor& out, const std::optional& block_table, const at::Tensor& cu_seqlens_q, const at::Tensor& cu_seqlens_k, int max_seqlen_q, int max_seqlen_k, - double sm_scale, bool is_causal) { + double sm_scale, + int window_size_left = -1, int window_size_right = -1, + bool is_causal = false, bool is_local = false) { int num_heads_q = query.size(1); int head_size = query.size(2); int batch_size = cu_seqlens_q.numel() - 1; @@ -336,6 +352,11 @@ void cutlass_varlen_impl( total_seqlen_k = key_cache.size(0); } + if (is_local) { + window_size_left = window_size_left == -1 ? max_seqlen_k : window_size_left; + window_size_right = window_size_right == -1 ? max_seqlen_k : window_size_right; + } + chunk_prefill_args_t args = {query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), @@ -355,7 +376,10 @@ void cutlass_varlen_impl( max_blocks_per_seq, block_size, is_causal, - use_paged_kv}; + use_paged_kv, + window_size_left, + window_size_right, + is_local}; CutlassType cuType = aten_to_Cutlass_dtype(query); diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index 2cd234f..d2c75d6 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -755,8 +755,6 @@ def test_flash_attn_varlen_qkvpacked( if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if local: - pytest.skip("local attention not supported on xpu currently") if dropout_p != 0.0: pytest.skip("dropout not supported on xpu currently") @@ -1227,8 +1225,6 @@ def test_flash_attn_varlen_output( if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if local: - pytest.skip("local attention not supported on xpu currently") if dropout_p != 0.0: pytest.skip("dropout not supported on xpu currently") if softcap != 0.0: @@ -1667,9 +1663,6 @@ def test_flash_attn_varlen_causal( and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM - if device == "xpu": - if local: - pytest.skip("local attention not supported on xpu currently") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q From ef82558a5879b16174e3a2b75665c68874518e2d Mon Sep 17 00:00:00 2001 From: YangKai0616 Date: Fri, 14 Nov 2025 09:10:26 +0000 Subject: [PATCH 2/2] Add local mask for fixed --- flash-attn2/flash_attn_xpu/flash_api.cpp | 8 ++- .../src/collective/fixed_mma.hpp | 15 ++++-- .../src/collective/fixed_softmax_epilogue.hpp | 46 ++++++++++++++--- flash-attn2/flash_attn_xpu/src/fixed.hpp | 33 +++++++++--- .../src/kernel/fixed_kernel.hpp | 50 +++++++++++++++++++ flash-attn2/flash_attn_xpu/src/varlen.hpp | 12 ++--- flash-attn2/tests/test_flash_attn.py | 17 ++----- 7 files changed, 141 insertions(+), 40 deletions(-) diff --git 
a/flash-attn2/flash_attn_xpu/flash_api.cpp b/flash-attn2/flash_attn_xpu/flash_api.cpp index 39dde41..33966ce 100644 --- a/flash-attn2/flash_attn_xpu/flash_api.cpp +++ b/flash-attn2/flash_attn_xpu/flash_api.cpp @@ -79,10 +79,16 @@ mha_fwd( out_padded = torch::zeros_like(q_padded); } + bool is_local = (window_size_left != -1) | (window_size_right != -1); + q_padded = ensure_contiguous(q_padded); k_padded = ensure_contiguous(k_padded); v_padded = ensure_contiguous(v_padded); - cutlass::flash_attention::fixed::cutlass_fixed_impl(q_padded, k_padded, v_padded, out_padded, softmax_scale, is_causal); + cutlass::flash_attention::fixed::cutlass_fixed_impl( + q_padded, k_padded, v_padded, out_padded, + softmax_scale, + window_size_left, window_size_right, + is_causal, is_local); // Remove padding from output at::Tensor out = out_padded; diff --git a/flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp b/flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp index 47810fc..298b2b0 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fixed_mma.hpp @@ -62,7 +62,7 @@ CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { template + class GmemTiledCopyV_, bool CausalMask_, bool LocalMask_> struct FlashPrefillMma { static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); }; @@ -71,9 +71,9 @@ struct FlashPrefillMma { template + class GmemTiledCopyV_, bool CausalMask_, bool LocalMask_> struct FlashPrefillMma, ProblemShapeType_, ElementQ_, StrideQ_, ElementK_, StrideK_, ElementV_, - StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_> { + StrideV_, MMAOperation_, TileShapeQK_, TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, GmemTiledCopyV_, CausalMask_, LocalMask_> { // // Type Aliases // @@ -97,6 +97,7 @@ struct FlashPrefillMma, ProblemShapeType_, El using TiledMmaPV = typename TiledMMAHelper, SubgroupLayout>::TiledMMA; using ElementAccumulator = typename TiledMmaQK::ValTypeC; static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; using MmaAtomShape = typename MmaAtom::Shape_MNK; @@ -158,12 +159,16 @@ struct FlashPrefillMma, ProblemShapeType_, El StrideK dK; ElementV const *ptr_V; StrideV dV; + int window_left; + int window_right; }; struct Params { XE_Copy_Q gmem_tiled_copy_q; XE_Copy_K gmem_tiled_copy_k; XE_Copy_V gmem_tiled_copy_v; + int window_left; + int window_right; }; // @@ -185,7 +190,7 @@ struct FlashPrefillMma, ProblemShapeType_, El XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; - return Params{copyQ, copyK, copyV}; + return Params{copyQ, copyK, copyV, args.window_left, args.window_right}; } template @@ -376,7 +381,7 @@ struct FlashPrefillMma, ProblemShapeType_, El XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; - return Params{copyQ, copyK, copyV}; + return Params{copyQ, copyK, copyV, params.window_left, params.window_right}; } } }; diff --git a/flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp b/flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp index 0dd71a0..f560edc 100644 --- a/flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp +++ b/flash-attn2/flash_attn_xpu/src/collective/fixed_softmax_epilogue.hpp @@ -50,13 +50,13 @@ namespace collective { 
///////////////////////////////////////////////////////////////////////////////////////////////// -template class FlashPrefillSoftmaxEpilogue { +template class FlashPrefillSoftmaxEpilogue { static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); }; -template -class FlashPrefillSoftmaxEpilogue { +template +class FlashPrefillSoftmaxEpilogue { public: // @@ -66,6 +66,7 @@ class FlashPrefillSoftmaxEpilogue using Element = Element_; static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; using GmemTiledCopyOut = void; @@ -115,8 +116,18 @@ class FlashPrefillSoftmaxEpilogue CUTLASS_PRAGMA_UNROLL for (int z = 0; z < FragsN; z++) { auto base_indx = indx + (z * Vec * FragsM); - Element eq = frag_s(base_indx) - max_scale_bcast; - frag_s(base_indx) = sycl::native::exp2(eq); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } sum(indx) += frag_s(base_indx); } } @@ -155,7 +166,18 @@ class FlashPrefillSoftmaxEpilogue if (!is_first) { auto g = COMPAT::get_nd_item<1>().get_sub_group(); Element max_scale{max * params.scale}; - Element exp_scale{sycl::native::exp2(max_prev * params.scale - max_scale)}; + Element exp_scale; + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || + (std::isinf(max_prev) && max_prev < 0)) { + exp_scale = 0.f; + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + CUTLASS_PRAGMA_UNROLL for (int indx = 0; indx < Vec * FragsM; indx++) { auto max_scale_bcast = group_broadcast(g, max_scale, indx); @@ -164,7 +186,17 @@ class FlashPrefillSoftmaxEpilogue CUTLASS_PRAGMA_UNROLL for (int z = 0; z < FragsNAcc; z++) { auto base_indx = indx + (z * Vec * FragsM); - frag_s(base_indx) = sycl::native::exp2((frag_s(base_indx) - max_scale_bcast)); + if constexpr (LocalMask) { + if ((std::isinf(max_scale_bcast) && max_scale_bcast < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + frag_s(base_indx) = sycl::native::exp2((frag_s(base_indx) - max_scale_bcast)); + } sum(indx) += frag_s(base_indx); } CUTLASS_PRAGMA_UNROLL diff --git a/flash-attn2/flash_attn_xpu/src/fixed.hpp b/flash-attn2/flash_attn_xpu/src/fixed.hpp index 556336d..cd19f41 100644 --- a/flash-attn2/flash_attn_xpu/src/fixed.hpp +++ b/flash-attn2/flash_attn_xpu/src/fixed.hpp @@ -33,6 +33,9 @@ struct prefill_args_fixed_t { int batch_size; int seq_len_q; int seq_len_k; + int window_size_left; + int window_size_right; + bool is_local; }; template @@ -83,7 +86,7 @@ struct KernelLauncher { {reinterpret_cast(args.query), stride_Q, reinterpret_cast(args.key), stride_K, reinterpret_cast(args.value), stride_V, - }, // window_left, window_right for local mask (not supported currently) + args.window_size_left, args.window_size_right,}, {args.softmax_scale}, {reinterpret_cast(args.out), stride_O}, hw_info}; @@ -165,7 +168,7 @@ template struct FMHAKernel { - template + template static void run_impl(const 
prefill_args_fixed_t& args) { cutlass::KernelHardwareInfo hw_info; @@ -186,7 +189,7 @@ struct FMHAKernel { using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashPrefillSoftmaxEpilogue< - Causal, EpilogueDispatchPolicy, ElementAccumulator>; + Causal, Local, EpilogueDispatchPolicy, ElementAccumulator>; using ProblemShape = cute::tuple; @@ -197,7 +200,7 @@ struct FMHAKernel { cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, GmemTiledCopyQ, GmemTiledCopyK, - GmemTiledCopyV, Causal>; + GmemTiledCopyV, Causal, Local>; using FMHAPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefill< ProblemShape, CollectiveMainloop, CollectiveSoftmaxEpilogue, @@ -209,9 +212,17 @@ struct FMHAKernel { static void dispatch(const prefill_args_fixed_t& args) { if (args.is_causal) { - run_impl(args); + if (args.is_local) { + run_impl(args); + } else { + run_impl(args); + } } else { - run_impl(args); + if (args.is_local) { + run_impl(args); + } else { + run_impl(args); + } } } }; @@ -271,7 +282,8 @@ void cutlass_fixed_impl( const at::Tensor& key, // [batch, seq_k, heads, head_size] const at::Tensor& value, // [batch, seq_k, heads, head_size] at::Tensor& out, // [batch, seq_q, heads, head_size] - double softmax_scale, bool is_causal) { + double softmax_scale, int window_size_left = -1, int window_size_right = -1, + bool is_causal = false, bool is_local = false) { int batch_size = query.size(0); int seq_len_q = query.size(1); @@ -286,12 +298,17 @@ void cutlass_fixed_impl( auto v_reshaped = value.transpose(1, 2).contiguous(); auto out_temp = torch::zeros_like(q_reshaped); + if (is_local) { + window_size_left = window_size_left == -1 ? seq_len_q : window_size_left; + window_size_right = window_size_right == -1 ? 
seq_len_k : window_size_right; + } + // Prepare arguments prefill_args_fixed_t args{ q_reshaped.data_ptr(), k_reshaped.data_ptr(), v_reshaped.data_ptr(), out_temp.data_ptr(), static_cast(softmax_scale), num_heads_q, num_heads_kv, head_size, is_causal, batch_size, - seq_len_q, seq_len_k + seq_len_q, seq_len_k, window_size_left, window_size_right, is_local }; dispatch_by_head_size(aten_to_Cutlass_dtype(query), args); diff --git a/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp b/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp index ec1db8e..5d65e84 100644 --- a/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp +++ b/flash-attn2/flash_attn_xpu/src/kernel/fixed_kernel.hpp @@ -101,6 +101,7 @@ class FMHAPrefill { static constexpr int SharedStorageSize = 0; static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr bool LocalMask = CollectiveMainloop::LocalMask; static constexpr int SubgroupSize = CollectiveMainloop::SubgroupSize; // sub_group size static constexpr uint32_t MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 @@ -341,6 +342,30 @@ class FMHAPrefill { } } + if constexpr (LocalMask) { + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + nblock * QK_BLK_N; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv - seq_len_qo; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { + bool left_mask = col_idx < cute::max(0, + row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = col_idx > cute::min(seq_len_kv, + row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + // we only need one block ahead, there is enough gap to prefetch it while doing softmax. 
because the gap between the two MMA is big, // prefetching it the same way as cutlass K matrix does not make sense for(int i=0; i < size<1>(pVgV); i++) { @@ -389,6 +414,31 @@ class FMHAPrefill { } } } + + if constexpr (LocalMask) { + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (nblock_limit - 1) * QK_BLK_N; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { + int row_idx = m * Vec + seq_coord; + int col_ref = seq_len_kv - seq_len_qo; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { + bool left_mask = col_idx < cute::max(0, + row + row_idx + col_ref - mainloop_params.window_left); + bool right_mask = col_idx > cute::min(seq_len_kv, + row + row_idx + col_ref + mainloop_params.window_right); + if (left_mask || right_mask) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + // Add masking for partial tiles at the end block if (seq_len % QK_BLK_N != 0) { const int remainder = seq_len % QK_BLK_N; diff --git a/flash-attn2/flash_attn_xpu/src/varlen.hpp b/flash-attn2/flash_attn_xpu/src/varlen.hpp index ffe3a73..7fb02f5 100644 --- a/flash-attn2/flash_attn_xpu/src/varlen.hpp +++ b/flash-attn2/flash_attn_xpu/src/varlen.hpp @@ -45,7 +45,7 @@ struct chunk_prefill_args_t { bool is_local; }; -template +template struct KernelLauncher { using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; using StrideK = typename FMHAChunkPrefillKernel::StrideK; @@ -188,7 +188,7 @@ template struct FMHAKernel { - template static void run(const chunk_prefill_args_t& args) { cutlass::KernelHardwareInfo hw_info; @@ -219,7 +219,7 @@ struct FMHAKernel { using ProblemShapeVarlen = cute::tuple; using ProblemShapeType = - std::conditional_t; + std::conditional_t; // Mainloop using CollectiveMainloop = @@ -239,7 +239,7 @@ struct FMHAKernel { ProblemShapeType, CollectiveMainloop, CollectiveSoftmaxEpilogue, CollectiveEpilogue, Scheduler>; - KernelLauncher launcher; + KernelLauncher launcher; launcher.run(args, hw_info); } @@ -247,10 +247,10 @@ struct FMHAKernel { static void dispatch(const chunk_prefill_args_t& args) { #define DISPATCH_LOCAL(PagedKV, Causal) \ if (args.is_local) { \ - run(args); \ } else { \ - run(args); \ } diff --git a/flash-attn2/tests/test_flash_attn.py b/flash-attn2/tests/test_flash_attn.py index d2c75d6..8b4aae5 100644 --- a/flash-attn2/tests/test_flash_attn.py +++ b/flash-attn2/tests/test_flash_attn.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F from einops import rearrange, repeat -from flash_attn import ( +from flash_attn2 import ( flash_attn_func, flash_attn_kvpacked_func, flash_attn_qkvpacked_func, @@ -13,9 +13,9 @@ flash_attn_varlen_qkvpacked_func, flash_attn_with_kvcache, ) -from flash_attn.bert_padding import pad_input, unpad_input -from flash_attn.flash_attn_interface import _get_block_size_n -from flash_attn.layers.rotary import apply_rotary_emb +from flash_attn2.bert_padding import pad_input, unpad_input +from flash_attn2.flash_attn_interface import _get_block_size_n +from flash_attn2.layers.rotary import apply_rotary_emb MAX_HEADDIM_SM8x = 192 @@ -594,8 +594,6 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if local: - pytest.skip("local attention not supported on xpu currently") if dropout_p != 0.0: pytest.skip("dropout not supported on xpu currently") @@ -941,8 +939,6 @@ def 
test_flash_attn_output( if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if local: - pytest.skip("local attention not supported on xpu currently") if dropout_p != 0.0: pytest.skip("dropout not supported on xpu currently") if softcap != 0.0: @@ -1541,9 +1537,6 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype, devi and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM - if device == "xpu": - if local: - pytest.skip("local attention not supported on xpu currently") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q @@ -1838,8 +1831,6 @@ def test_flash_attn_splitkv( if device == "xpu": if alibi: pytest.skip("alibi not supported on xpu currently") - if local: - pytest.skip("local attention not supported on xpu currently") if swap_sq_sk: seqlen_q, seqlen_k = seqlen_k, seqlen_q
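
For reference, both kernel hunks above (varlen_kernel.hpp and fixed_kernel.hpp) apply the same sliding-window ("local") mask convention: a query row may attend to key columns in [row + col_ref - window_left, row + col_ref + window_right], where col_ref = seq_len_kv - seq_len_qo right-aligns the query block against the key sequence. On the host side, window_size_left/right == -1 means "unbounded" and is widened to the full key length before launch, and is_local is set whenever either window is non-negative. Below is a minimal scalar sketch of that predicate, assuming the same parameter meanings as in the kernels; the helper name is_masked_out is illustrative and not part of the patch.

    // Scalar reference for the local-mask predicate used in the kernel hunks.
    // Positions for which this returns true are set to -INFINITY before the softmax.
    #include <algorithm>

    bool is_masked_out(int row_idx, int col_idx,
                       int seq_len_qo, int seq_len_kv,
                       int window_left, int window_right) {
      const int col_ref = seq_len_kv - seq_len_qo;  // right-align queries with keys
      const bool left_mask  = col_idx < std::max(0, row_idx + col_ref - window_left);
      const bool right_mask = col_idx > std::min(seq_len_kv, row_idx + col_ref + window_right);
      return left_mask || right_mask;
    }

With window_right = 0 this reduces to the (right-aligned) causal mask, which is consistent with the dispatch now combining the Causal and Local flags instead of asserting that they are mutually exclusive.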
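
The LocalMask branches added in fixed_softmax_epilogue.hpp guard against rows that end up fully masked: there the running row maximum is -inf, and the unguarded update would evaluate exp2(-inf - (-inf)) = NaN and poison the row sum. The guarded path forces such contributions to zero instead. A minimal host-side sketch of that guard, assuming std::exp2 in place of sycl::native::exp2 and using the hypothetical name guarded_exp2:

    // Guarded exp2 mirroring the LocalMask softmax path: a -inf running maximum
    // (or a -inf score) marks a fully masked element, which must contribute 0 to
    // the row sum rather than NaN.
    #include <cmath>

    inline float guarded_exp2(float score, float row_max) {
      if ((std::isinf(row_max) && row_max < 0.f) ||
          (std::isinf(score)   && score   < 0.f)) {
        return 0.f;
      }
      return std::exp2(score - row_max);
    }

The same check is applied to the block-rescaling factor exp2(max_prev * scale - max_scale), which would otherwise be NaN while both the previous and the current running maxima are still -inf.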