diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index bc1c278bf49..5cd0785d20f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit } void main() { - const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; + const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; if (row >= n_rows) { return; } @@ -83,17 +83,18 @@ void main() { const uint logits_offset = n_experts * row; const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; + const uint lane = gl_SubgroupInvocationID; float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; + const uint expert = i + lane; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); + softmax_warp_inplace(wt, n_experts, lane, false); } // at this point, each thread holds a portion of softmax, @@ -111,11 +112,11 @@ void main() { for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; - uint max_expert = gl_LocalInvocationID.x; + uint max_expert = lane; [[unroll]] for (int i = 1; i < experts_per_thread; i++) { - const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE; + const uint expert = lane + i * WARP_SIZE; if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { max_val = wt[i]; max_expert = expert; @@ -132,11 +133,11 @@ void main() { } } - if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((k & (WARP_SIZE - 1)) == lane) { output_weights[k / WARP_SIZE] = max_val; } - if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((max_expert & (WARP_SIZE - 1)) == lane) { wt[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; @@ -158,12 +159,12 @@ void main() { } if (late_softmax) { - softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + softmax_warp_inplace(output_weights, n_expert_used, lane, true); } [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { - uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; + uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { weights[weights_offset + idx] = output_weights[i]; }