@@ -108,8 +108,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
108108 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
109109 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
110110 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , configDimRowMajorGlobalWorker );
111+ tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce" , rmsNormWorker );
111112 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".projectionTwo" , configDimRowMajorGlobalWorker );
112- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlockFFN" , rmsNormWorker );
113113 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , configHiddenDimRowMajorWorker );
114114 }
115115 return tornadoForwardScheduler ;
@@ -258,8 +258,15 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
258258 config .dim (), // output dim
259259 LOCAL_WORK_GROUP_SIZE_ALLOC );
260260
261- unifiedLayer .task ("reductionsOneBlockFFN" , TransformerComputeKernelsLayered ::reductionOneBlockWithLayer , context , state .tempFFN ,
262- state .wrapX , config .dim (), config .rmsNormEps (), state .localSize );
261+ // RMS Normalization - compute scale factor
262+ unifiedLayer .task ("ffn_rms_reduce" ,
263+ TransformerComputeKernelsLayered ::reductionOneBlockWithLayer ,
264+ context ,
265+ qwen2State .tempFFN , // output: scale factor
266+ qwen2State .wrapX , // input: hidden state
267+ config .dim (), // dimension
268+ config .rmsNormEps (), // epsilon
269+ qwen2State .localSize ); // local memory size
263270
264271 // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
265272 // (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation)
0 commit comments