Skip to content

Commit c480157

Browse files
[refactor] Replace reductionsOneBlockFFN task with ffn_rms_reduce in Qwen3 Q8_0 FFN layers and update worker grids.
1 parent eade8f8 commit c480157

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
104104
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
105105
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
106106
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
107-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
107+
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
108108
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
109109
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
110110
}
@@ -275,13 +275,15 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
275275

276276
// ========== FEED-FORWARD BLOCK ==========
277277

278-
// RMS norm for FFN input
279-
unifiedLayer.task("reductionsOneBlockFFN",
278+
// RMS Normalization - compute scale factor
279+
unifiedLayer.task("ffn_rms_reduce",
280280
TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
281-
context, qwen3State.tempFFN, qwen3State.wrapX, config.dim(), config.rmsNormEps(), qwen3State.localSize)
282-
.task("mapContextFFN",
283-
TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
284-
context, qwen3State.wrapXb, qwen3State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), qwen3State.tempFFN);
281+
context,
282+
qwen3State.tempFFN, // output: scale factor
283+
qwen3State.wrapX, // input: hidden state
284+
qwen3Config.dim(), // dimension
285+
qwen3Config.rmsNormEps(), // epsilon
286+
qwen3State.localSize); // local memory size
285287

286288
// Fused RMS Apply + Gate/Up Projection + SiLU + GLU
287289
unifiedLayer.task("rms_ffn_gate_up",

0 commit comments

Comments
 (0)