@@ -107,7 +107,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
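The rename has to land in this hunk as well as in the task graph below because TornadoVM resolves worker grids by the string key `<taskGraphName>.<taskId>`: the key passed to `GridScheduler.addWorkerGrid` must match the id given to `TaskGraph.task`, so renaming the task to `attn_output_proj` requires renaming its scheduler entry too. A minimal sketch of that coupling, assuming the public TornadoVM API; the `copy` kernel, array sizes, and graph name `layer_0` are illustrative only:

```java
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class GridNamingSketch {

    // Hypothetical kernel, used only to illustrate the naming contract.
    public static void copy(FloatArray in, FloatArray out) {
        for (@Parallel int i = 0; i < in.getSize(); i++) {
            out.set(i, in.get(i));
        }
    }

    public static void main(String[] args) {
        FloatArray in = new FloatArray(1024);
        FloatArray out = new FloatArray(1024);

        // Task id "attn_output_proj" inside task graph "layer_0"...
        TaskGraph graph = new TaskGraph("layer_0")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, in)
                .task("attn_output_proj", GridNamingSketch::copy, in, out)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

        // ...must be mirrored by the scheduler key "layer_0.attn_output_proj";
        // a mismatched key means the custom grid is simply never applied.
        WorkerGrid worker = new WorkerGrid1D(1024);
        GridScheduler scheduler = new GridScheduler();
        scheduler.addWorkerGrid("layer_0.attn_output_proj", worker);

        ImmutableTaskGraph itg = graph.snapshot();
        new TornadoExecutionPlan(itg).withGridScheduler(scheduler).execute();
    }
}
```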
@@ -230,7 +230,7 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.kvDim(),          // kvDim
                 layerIndex,              // layer offset
                 config.contextLength()); // max sequence length
-
+
         // Flash Attention
         unifiedLayer.task("attention",
                 Qwen2Kernels::processHeadsFlashAttention,
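For reference, the `attention` task registered above (its remaining arguments appear in the next hunk) produces, per head, the same result as naive scaled dot-product attention; the flash-attention kernel just computes it in a tiled, memory-efficient pass so the full score row is never materialized. A scalar sketch of that per-head reference, where the array names, layout, and causal loop bound over cached positions are assumptions rather than the kernel's actual signature:

```java
/**
 * Naive reference for one attention head: out = softmax(q·K^T / sqrt(d)) · V,
 * restricted causally to cache positions 0..pos. Illustrative sketch only.
 */
static void attendOneHead(float[] q, float[][] keyCache, float[][] valueCache,
                          float[] out, int pos, int headDim) {
    float scale = 1.0f / (float) Math.sqrt(headDim);

    // Scores against every cached key up to the current position.
    float[] scores = new float[pos + 1];
    float max = Float.NEGATIVE_INFINITY;
    for (int t = 0; t <= pos; t++) {
        float dot = 0.0f;
        for (int d = 0; d < headDim; d++) {
            dot += q[d] * keyCache[t][d];
        }
        scores[t] = dot * scale;
        max = Math.max(max, scores[t]);
    }

    // Numerically stable softmax over the scores.
    float sum = 0.0f;
    for (int t = 0; t <= pos; t++) {
        scores[t] = (float) Math.exp(scores[t] - max);
        sum += scores[t];
    }

    // Weighted sum of the cached values.
    for (int d = 0; d < headDim; d++) {
        float acc = 0.0f;
        for (int t = 0; t <= pos; t++) {
            acc += scores[t] * valueCache[t][d];
        }
        out[d] = acc / sum;
    }
}
```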
@@ -247,9 +247,18 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 layerIndex,              // layer index
                 config.contextLength()); // context length

-        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
-                state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
+        // Output Projection with Residual
+        unifiedLayer.task("attn_output_proj",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+                context,
+                qwen2State.wrapXb,                           // input: attention output
+                qwen2State.wrapX,                            // output: wrapX += Wo · wrapXb
+                weights.woLayered[layerIndex].asByteArray(), // Wo
+                config.dim(),                                // input dim
+                config.dim(),                                // output dim
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                 state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);

         // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
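Functionally, the renamed `attn_output_proj` task fuses the attention output projection with the residual add: each element of the hidden state accumulates the dot product of a row of Wo with the attention output, rather than overwriting it, roughly `x += Wo · xb`. A scalar sketch of that contract with dequantization elided; the real kernel reads Q8_0 bytes (blocks of 32 int8 weights with an fp16 scale per block) and parallelizes the per-row reduction across a work group:

```java
/**
 * Reference semantics for the "attn_output_proj" task: x += Wo · xb.
 * Scalar sketch with plain float weights; illustrative only.
 */
static void matVecWithResidual(float[] x, float[] xb, float[][] wo, int dim) {
    for (int i = 0; i < dim; i++) {
        float acc = 0.0f;
        for (int j = 0; j < dim; j++) {
            acc += wo[i][j] * xb[j]; // dot(Wo[i, :], xb)
        }
        x[i] += acc;                 // residual connection: add, don't overwrite
    }
}
```

The `reductionsOneBlockFFN` task then recomputes the RMS-norm statistics of the updated `wrapX`, feeding the fused gate/up projection noted in the trailing comment, which in Qwen2's FFN takes the standard SwiGLU form h = SiLU(W_gate · x̂) ⊙ (W_up · x̂).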