
Commit eade8f8

[refactor] Rename matmul1 task to attn_output_proj and update corresponding worker grid in Qwen3 Q8_0 FFN layers.
1 parent: ecb2828

File tree

1 file changed: +10 −6 lines


src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 10 additions & 6 deletions
@@ -103,7 +103,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
-        tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
+        tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
         tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
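
Why the rename touches two places: TornadoVM's GridScheduler resolves a worker grid by the string key "<taskGraphName>.<taskName>", so the name passed to TaskGraph.task(...) and the key passed to addWorkerGrid(...) must change in lockstep. Below is a minimal, self-contained sketch of that coupling, assuming the TornadoVM 1.x API; the GridKeyDemo class and its scale kernel are hypothetical stand-ins, not code from this commit.

import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class GridKeyDemo {

    // Trivial stand-in for the real output-projection kernel.
    public static void scale(FloatArray in, FloatArray out) {
        for (@Parallel int i = 0; i < in.getSize(); i++) {
            out.set(i, 2.0f * in.get(i));
        }
    }

    public static void main(String[] args) {
        FloatArray in = new FloatArray(1024);
        FloatArray out = new FloatArray(1024);
        in.init(1.0f);

        // Task name registered in the graph ...
        TaskGraph layer = new TaskGraph("layer_0")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, in)
                .task("attn_output_proj", GridKeyDemo::scale, in, out)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

        // ... must match the key under which the worker grid is attached.
        // With the pre-rename key "layer_0.matmul1" the grid would match
        // no task and the default launch dimensions would be used instead.
        GridScheduler scheduler = new GridScheduler();
        WorkerGrid worker = new WorkerGrid1D(1024);
        scheduler.addWorkerGrid("layer_0.attn_output_proj", worker);

        new TornadoExecutionPlan(layer.snapshot())
                .withGridScheduler(scheduler)
                .execute();
    }
}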
@@ -262,12 +262,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
                 layerIndex, // layer index
                 qwen3Config.contextLength()); // context length
 
-        // Output projection (Q8_0 weights)
-        unifiedLayer.task("matmul1",
+        // Output Projection with Residual
+        unifiedLayer.task("attn_output_proj",
                 TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
-                context, qwen3State.wrapXb, qwen3State.wrapX,
-                weights.woLayered[layerIndex].asByteArray(),
-                qDim0, config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+                context,
+                qwen3State.wrapXb, // input: attention output
+                qwen3State.wrapX, // output: wrapX += Wo · wrapXb
+                weights.woLayered[layerIndex].asByteArray(), // Wo [dim x qDim]
+                nEmbdHeadK * qwen3Config.numberOfHeads(), // input dim (qDim)
+                config.dim(), // output dim
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
 
         // ========== FEED-FORWARD BLOCK ==========
 
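
For reference, a plain-Java scalar sketch of what the renamed task computes: x[d] += Σ_j Wo[d][j] · xb[j], with qDim = nEmbdHeadK · numberOfHeads inputs and dim outputs, accumulated into x as the residual. This is not the TornadoVM kernel (TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte); the Q8_0 byte layout assumed here (per block, a little-endian fp16 scale followed by 32 signed-byte quants, as in ggml) and the class and method names are assumptions, and Float.float16ToFloat requires Java 20+.

final class Q8_0ProjSketch {

    // Residual output projection over Q8_0 weights:
    //   x[d] += sum_j Wo[d][j] * xb[j]
    // wo holds Wo row-major; one row = qDim / 32 blocks of
    // (2-byte fp16 scale + 32 signed-byte quants).  [assumed ggml layout]
    static void attnOutputProjWithResidual(byte[] wo, float[] xb, float[] x,
                                           int qDim, int dim) {
        final int BLOCK = 32;              // Q8_0 block size
        final int BLOCK_BYTES = 2 + BLOCK; // scale + quants
        final int blocksPerRow = qDim / BLOCK;
        for (int d = 0; d < dim; d++) {
            float acc = 0f;
            for (int b = 0; b < blocksPerRow; b++) {
                int base = (d * blocksPerRow + b) * BLOCK_BYTES;
                // little-endian fp16 scale for this block
                short bits = (short) ((wo[base] & 0xFF) | (wo[base + 1] << 8));
                float scale = Float.float16ToFloat(bits); // Java 20+
                float sum = 0f;
                for (int k = 0; k < BLOCK; k++) {
                    sum += wo[base + 2 + k] * xb[b * BLOCK + k];
                }
                acc += scale * sum;
            }
            x[d] += acc; // residual accumulation into wrapX
        }
    }
}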
