diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 75b76e593bbc0..0dfd52217e19f 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -7920,7 +7920,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
     vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;

-    uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
+#define SINK_ENABLE_BIT (1<<24)
+#define SKIP_NEG_INF_BIT (1<<17)
+#define MASK_ENABLE_BIT (1<<16)
+
+    const bool skip_neg_inf = ctx->device->vendor_id != VK_VENDOR_ID_INTEL;
+
+    const uint32_t mask_n_head_log2 =
+        ((sinks != nullptr) ? SINK_ENABLE_BIT : 0) |
+        ((mask != nullptr) ? MASK_ENABLE_BIT : 0) |
+        (skip_neg_inf ? SKIP_NEG_INF_BIT : 0) |
+        n_head_log2;

     const vk_flash_attn_push_constants pc = { N, KV,
                                               (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index 4bef48b006ce7..5dcbe7de786f7 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
@@ -126,18 +126,20 @@ void main() {
                     }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
+            if ((p.mask_n_head_log2 & SKIP_NEG_INF_BIT) != 0) {
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
index eb93903c4681e..e0eb92973ed51 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
@@ -57,6 +57,7 @@ layout (push_constant) uniform parameter {
 } p;

 #define SINK_ENABLE_BIT (1<<24)
+#define SKIP_NEG_INF_BIT (1<<17)
 #define MASK_ENABLE_BIT (1<<16)
 #define N_LOG2_MASK 0xFFFF

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index cd82e4abfabc4..36b08cdeb4f03 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
@@ -165,18 +165,20 @@ void main() {
                     }
                 }
             }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
+            if ((p.mask_n_head_log2 & SKIP_NEG_INF_BIT) != 0) {
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 617d85108698a..7768e3f72163d 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
@@ -160,10 +160,12 @@ void main() {

             coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

-            // skip the block if the mask is entirely -inf
-            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                continue;
+            if ((p.mask_n_head_log2 & SKIP_NEG_INF_BIT) != 0) {
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         } else {
             tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
@@ -176,10 +178,12 @@ void main() {

             coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

-            // skip the block if the mask is entirely -inf
-            coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
-            if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
-                continue;
+            if ((p.mask_n_head_log2 & SKIP_NEG_INF_BIT) != 0) {
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
     }
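
For reference, the standalone sketch below is not part of the patch: it only illustrates how the combined mask_n_head_log2 push-constant word is packed on the host and decoded with the bit masks defined in flash_attn_base.glsl (SINK_ENABLE_BIT, SKIP_NEG_INF_BIT, MASK_ENABLE_BIT, N_LOG2_MASK). The helper name pack_mask_n_head_log2 and the sample values are hypothetical, chosen only for illustration.

// Standalone sketch (not part of the patch): packing and decoding the
// flag word that the shaders receive as p.mask_n_head_log2.
#include <cassert>
#include <cstdint>
#include <cstdio>

constexpr uint32_t SINK_ENABLE_BIT  = 1u << 24;
constexpr uint32_t SKIP_NEG_INF_BIT = 1u << 17;
constexpr uint32_t MASK_ENABLE_BIT  = 1u << 16;
constexpr uint32_t N_LOG2_MASK      = 0xFFFFu;

// Pack the three feature flags and n_head_log2 into one 32-bit word,
// mirroring the expression built in ggml_vk_flash_attn() above.
static uint32_t pack_mask_n_head_log2(bool has_sinks, bool has_mask,
                                      bool skip_neg_inf, uint32_t n_head_log2) {
    assert((n_head_log2 & ~N_LOG2_MASK) == 0); // must fit in the low 16 bits
    return (has_sinks    ? SINK_ENABLE_BIT  : 0) |
           (has_mask     ? MASK_ENABLE_BIT  : 0) |
           (skip_neg_inf ? SKIP_NEG_INF_BIT : 0) |
           n_head_log2;
}

int main() {
    // Example: mask present, -inf block skipping enabled (the non-Intel path).
    const uint32_t packed = pack_mask_n_head_log2(false, true, true, 5);

    // Shader-style decode, matching `(p.mask_n_head_log2 & SKIP_NEG_INF_BIT) != 0`.
    printf("mask enabled:  %d\n", (packed & MASK_ENABLE_BIT)  != 0);
    printf("skip -inf:     %d\n", (packed & SKIP_NEG_INF_BIT) != 0);
    printf("sinks enabled: %d\n", (packed & SINK_ENABLE_BIT)  != 0);
    printf("n_head_log2:   %u\n", (unsigned)(packed & N_LOG2_MASK));
    return 0;
}

The low 16 bits stay reserved for n_head_log2 (masked out via N_LOG2_MASK), so the new bit 17 does not collide with the existing bit 16 and bit 24 flags.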