Skip to content

Commit 83bb419

Browse files
[refactor] Replace reductionsOneBlockFFN task with ffn_rms_reduce in Qwen2 Q8_0 FFN layers and update worker grids.
1 parent 2e6aaf5 commit 83bb419

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -108,8 +108,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
108108
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
109109
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
110110
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
111+
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
111112
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
112-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
113113
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
114114
}
115115
return tornadoForwardScheduler;
@@ -258,8 +258,15 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
258258
config.dim(), // output dim
259259
LOCAL_WORK_GROUP_SIZE_ALLOC);
260260

261-
unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
262-
state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
261+
// RMS Normalization - compute scale factor
262+
unifiedLayer.task("ffn_rms_reduce",
263+
TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
264+
context,
265+
qwen2State.tempFFN, // output: scale factor
266+
qwen2State.wrapX, // input: hidden state
267+
config.dim(), // dimension
268+
config.rmsNormEps(), // epsilon
269+
qwen2State.localSize); // local memory size
263270

264271
// Fused RMS Apply + Gate/Up Projection + SiLU + GLU
265272
// (Replaces mapContextFFN + fusedFeedForwardWithSiLUAndGLUActivation)

0 commit comments

Comments
 (0)