@@ -105,6 +105,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
105105 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
106106 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
107107 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce" , rmsNormWorker );
108+ if (shouldUseFinalNormalization ()) {
109+ tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_finalize" , rmsNormWorker );
110+ }
108111 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
109112 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_down_proj" , projectionTwoWorker );
110113 }
@@ -285,6 +288,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
285288 qwen3Config .rmsNormEps (), // epsilon
286289 qwen3State .localSize ); // local memory size
287290
291+ // Final normalization (non-NVIDIA only)
292+ if (shouldUseFinalNormalization ()) {
293+ unifiedLayer .task ("ffn_rms_finalize" ,
294+ TransformerComputeKernelsLayered ::reductionFinalNormalization ,
295+ context ,
296+ qwen3State .tempFFN , // scale factor (in/out)
297+ qwen3Config .dim (), // dimension
298+ qwen3Config .rmsNormEps ()); // epsilon
299+ }
300+
288301 // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
289302 unifiedLayer .task ("rms_ffn_gate_up" ,
290303 TransformerComputeKernelsLayered ::fusedRmsNormFFNGateUpQ8_0 ,
0 commit comments