Commit 2e6aaf5

[refactor] Update attention output projection task in Qwen2 Q8_0 FFN layers
1 parent 5d4edb7 commit 2e6aaf5

File tree

1 file changed (+14 lines, -5 lines)

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 14 additions & 5 deletions
@@ -107,7 +107,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
-        tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
+        tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
@@ -230,7 +230,7 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.kvDim(), // kvDim
                 layerIndex, // layer offset
                 config.contextLength()); // max sequence length
-
+
         // Flash Attention
         unifiedLayer.task("attention",
                 Qwen2Kernels::processHeadsFlashAttention,
@@ -247,9 +247,18 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 layerIndex, // layer index
                 config.contextLength()); // context length

-        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
-                state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
+        // Output Projection with Residual
+        unifiedLayer.task("attn_output_proj",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+                context,
+                qwen2State.wrapXb, // input: attention output
+                qwen2State.wrapX, // output: wrapX += Wo · wrapXb
+                weights.woLayered[layerIndex].asByteArray(), // Wo
+                config.dim(), // input dim
+                config.dim(), // output dim
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                 state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);

         // Fused RMS Apply + Gate/Up Projection + SiLU + GLU
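For readers unfamiliar with the kernel, the renamed attn_output_proj task computes the attention output projection with a residual add, wrapX += Wo · wrapXb, where Wo is Q8_0-quantized (int8 weights in groups of 32 sharing one scale). The sequential plain-Java sketch below only illustrates that arithmetic under a simplified weight layout (one float scale per group, one row per array); it is not the TornadoVM kernel or its actual byte packing.

    // Illustrative sequential version of the "attn_output_proj" step:
    //   x[i] += dot(row i of Wo, xb)   for i in [0, dim)
    // Q8_0-style weights are simplified to per-row int8 quants plus one float
    // scale per 32-element group; the real kernel operates on packed bytes.
    public final class AttnOutputProjMathSketch {

        static final int GROUP = 32; // Q8_0 group size

        static void outputProjWithResidual(float[] x, float[] xb,
                                           float[][] rowScales, byte[][] rowQuants, int dim) {
            for (int i = 0; i < dim; i++) {
                float acc = 0f;
                for (int j = 0; j < dim; j++) {
                    // dequantize Wo[i][j]: group scale times int8 quant
                    float w = rowScales[i][j / GROUP] * rowQuants[i][j];
                    acc += w * xb[j];
                }
                x[i] += acc; // residual accumulation into the hidden state (wrapX)
            }
        }

        public static void main(String[] args) {
            int dim = 64;
            float[] x = new float[dim];   // hidden state (wrapX)
            float[] xb = new float[dim];  // attention output (wrapXb)
            float[][] scales = new float[dim][dim / GROUP];
            byte[][] quants = new byte[dim][dim];
            java.util.Arrays.fill(xb, 0.5f);
            for (float[] s : scales) java.util.Arrays.fill(s, 0.01f);
            for (byte[] q : quants) java.util.Arrays.fill(q, (byte) 2);
            outputProjWithResidual(x, xb, scales, quants, dim);
            System.out.println("x[0] after projection + residual = " + x[0]);
        }
    }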
