@@ -104,7 +104,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
104104 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
105105 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
106106 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
107- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlockFFN " , rmsNormWorker );
107+ tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce " , rmsNormWorker );
108108 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
109109 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".projectionTwo" , projectionTwoWorker );
110110 }
@@ -275,13 +275,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
275275
276276 // ========== FEED-FORWARD BLOCK ==========
277277
278- // RMS norm for FFN input
279- unifiedLayer .task ("reductionsOneBlockFFN " ,
278+ // RMS Normalization - compute scale factor
279+ unifiedLayer .task ("ffn_rms_reduce " ,
280280 TransformerComputeKernelsLayered ::reductionOneBlockWithLayer ,
281- context , qwen3State .tempFFN , qwen3State .wrapX , config .dim (), config .rmsNormEps (), qwen3State .localSize )
282- .task ("mapContextFFN" ,
283- TransformerComputeKernelsLayered ::reductionOneBlock2WithLayer ,
284- context , qwen3State .wrapXb , qwen3State .wrapX , weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (), qwen3State .tempFFN );
281+ context ,
282+ qwen3State .tempFFN , // output: scale factor
283+ qwen3State .wrapX , // input: hidden state
284+ qwen3Config .dim (), // dimension
285+ qwen3Config .rmsNormEps (), // epsilon
286+ qwen3State .localSize ); // local memory size
285287
286288 // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
287289 unifiedLayer .task ("rms_ffn_gate_up" ,
0 commit comments