@@ -75,25 +75,26 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
7575}
7676
7777void main() {
78- const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y ;
78+ const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID ;
7979 if (row >= n_rows) {
8080 return;
8181 }
8282
8383 const uint logits_offset = n_experts * row;
8484 const uint weights_offset = n_expert_used * row;
8585 const uint ids_offset = n_experts * row;
86+ const uint lane = gl_SubgroupInvocationID;
8687
8788 float wt[experts_per_thread];
8889
8990 [[unroll]]
9091 for (uint i = 0; i < n_experts; i += WARP_SIZE) {
91- const uint expert = i + gl_LocalInvocationID.x ;
92+ const uint expert = i + lane ;
9293 wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
9394 }
9495
9596 if (!late_softmax) {
96- softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x , false);
97+ softmax_warp_inplace(wt, n_experts, lane , false);
9798 }
9899
99100 // at this point, each thread holds a portion of softmax,
@@ -111,11 +112,11 @@ void main() {
111112
112113 for (int k = 0; k < n_expert_used; k++) {
113114 float max_val = wt[0];
114- uint max_expert = gl_LocalInvocationID.x ;
115+ uint max_expert = lane ;
115116
116117 [[unroll]]
117118 for (int i = 1; i < experts_per_thread; i++) {
118- const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
119+ const uint expert = lane + i * WARP_SIZE;
119120 if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
120121 max_val = wt[i];
121122 max_expert = expert;
@@ -132,11 +133,11 @@ void main() {
132133 }
133134 }
134135
135- if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x ) {
136+ if ((k & (WARP_SIZE - 1)) == lane ) {
136137 output_weights[k / WARP_SIZE] = max_val;
137138 }
138139
139- if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x ) {
140+ if ((max_expert & (WARP_SIZE - 1)) == lane ) {
140141 wt[max_expert / WARP_SIZE] = -INFINITY;
141142
142143 ids[ids_offset + k] = max_expert;
@@ -158,12 +159,12 @@ void main() {
158159 }
159160
160161 if (late_softmax) {
161- softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x , true);
162+ softmax_warp_inplace(output_weights, n_expert_used, lane , true);
162163 }
163164
164165 [[unroll]]
165166 for (uint i = 0; i < experts_per_thread; ++i) {
166- uint idx = i * WARP_SIZE + gl_LocalInvocationID.x ;
167+ uint idx = i * WARP_SIZE + lane ;
167168 if (idx < n_expert_used) {
168169 weights[weights_offset + idx] = output_weights[i];
169170 }
0 commit comments