Commit 8f637cd
Add final normalization step for non-NVIDIA devices in Qwen2 Q8_0 FFN layers.
1 parent 2b27344 commit 8f637cd

1 file changed: 10 additions, 0 deletions

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 10 additions & 0 deletions
@@ -268,6 +268,16 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.rmsNormEps(), // epsilon
                 qwen2State.localSize); // local memory size
 
+        // Final normalization (non-NVIDIA only)
+        if (shouldUseFinalNormalization()) {
+            unifiedLayer.task("ffn_rms_finalize",
+                    TransformerComputeKernelsLayered::reductionFinalNormalization,
+                    context,
+                    qwen2State.tempFFN, // scale factor (in/out)
+                    config.dim(), // dimension
+                    config.rmsNormEps()); // epsilon
+        }
+
         // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
         // (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation)
         unifiedLayer.task("rms_ffn_gate_up",

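The shouldUseFinalNormalization() guard is defined elsewhere in the layer hierarchy and is not shown in this diff. A hypothetical illustration of the gating the commit message describes, where the platform-name check is an assumption rather than the repository's actual logic:

    // Hypothetical sketch, not the repository's implementation: enable the
    // extra finalization task only when the selected device is not an NVIDIA GPU.
    static boolean shouldUseFinalNormalization(String devicePlatformName) {
        return !devicePlatformName.toLowerCase().contains("nvidia");
    }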
0 commit comments