Skip to content

Commit 06ebbc8

Browse files
Add final normalization step for non-NVIDIA devices in Qwen3 Q8_0 FFN layers and update worker grids.
1 parent 05c0048 commit 06ebbc8

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
105105
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
106106
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
107107
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
108+
if (shouldUseFinalNormalization()) {
109+
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker);
110+
}
108111
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
109112
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker);
110113
}
@@ -285,6 +288,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
285288
qwen3Config.rmsNormEps(), // epsilon
286289
qwen3State.localSize); // local memory size
287290

291+
// Final normalization (non-NVIDIA only)
292+
if (shouldUseFinalNormalization()) {
293+
unifiedLayer.task("ffn_rms_finalize",
294+
TransformerComputeKernelsLayered::reductionFinalNormalization,
295+
context,
296+
qwen3State.tempFFN, // scale factor (in/out)
297+
qwen3Config.dim(), // dimension
298+
qwen3Config.rmsNormEps()); // epsilon
299+
}
300+
288301
// Fused RMS Apply + Gate/Up Projection + SiLU + GLU
289302
unifiedLayer.task("rms_ffn_gate_up",
290303
TransformerComputeKernelsLayered::fusedRmsNormFFNGateUpQ8_0,

0 commit comments

Comments
 (0)