Skip to content

Commit 4404be4

Browse files
committed
Fix shader to support 2D workgroup mapping to a single subgroup
1 parent 142df17 commit 4404be4

File tree

1 file changed

+10
-9
lines changed

1 file changed

+10
-9
lines changed

ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,26 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
7575
}
7676

7777
void main() {
78-
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
78+
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
7979
if (row >= n_rows) {
8080
return;
8181
}
8282

8383
const uint logits_offset = n_experts * row;
8484
const uint weights_offset = n_expert_used * row;
8585
const uint ids_offset = n_experts * row;
86+
const uint lane = gl_SubgroupInvocationID;
8687

8788
float wt[experts_per_thread];
8889

8990
[[unroll]]
9091
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
91-
const uint expert = i + gl_LocalInvocationID.x;
92+
const uint expert = i + lane;
9293
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
9394
}
9495

9596
if (!late_softmax) {
96-
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
97+
softmax_warp_inplace(wt, n_experts, lane, false);
9798
}
9899

99100
// at this point, each thread holds a portion of softmax,
@@ -111,11 +112,11 @@ void main() {
111112

112113
for (int k = 0; k < n_expert_used; k++) {
113114
float max_val = wt[0];
114-
uint max_expert = gl_LocalInvocationID.x;
115+
uint max_expert = lane;
115116

116117
[[unroll]]
117118
for (int i = 1; i < experts_per_thread; i++) {
118-
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
119+
const uint expert = lane + i * WARP_SIZE;
119120
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
120121
max_val = wt[i];
121122
max_expert = expert;
@@ -132,11 +133,11 @@ void main() {
132133
}
133134
}
134135

135-
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
136+
if ((k & (WARP_SIZE - 1)) == lane) {
136137
output_weights[k / WARP_SIZE] = max_val;
137138
}
138139

139-
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
140+
if ((max_expert & (WARP_SIZE - 1)) == lane) {
140141
wt[max_expert / WARP_SIZE] = -INFINITY;
141142

142143
ids[ids_offset + k] = max_expert;
@@ -158,12 +159,12 @@ void main() {
158159
}
159160

160161
if (late_softmax) {
161-
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
162+
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
162163
}
163164

164165
[[unroll]]
165166
for (uint i = 0; i < experts_per_thread; ++i) {
166-
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
167+
uint idx = i * WARP_SIZE + lane;
167168
if (idx < n_expert_used) {
168169
weights[weights_offset + idx] = output_weights[i];
169170
}

0 commit comments

Comments
 (0)