@@ -103,7 +103,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
103103 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".qk_rmsnorm" , qkRmsNormWorker );
104104 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
105105 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
106- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".matmul1 " , matmul1Worker );
106+ tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj " , matmul1Worker );
107107 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".reductionsOneBlockFFN" , rmsNormWorker );
108108 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
109109 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".projectionTwo" , projectionTwoWorker );
@@ -262,12 +262,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
262262 layerIndex , // layer index
263263 qwen3Config .contextLength ()); // context length
264264
265- // Output projection (Q8_0 weights)
266- unifiedLayer .task ("matmul1 " ,
265+ // Output Projection with Residual
266+ unifiedLayer .task ("attn_output_proj " ,
267267 TransformerComputeKernelsLayered ::matrixVectorGenericWithResidualQ8_0Byte ,
268- context , qwen3State .wrapXb , qwen3State .wrapX ,
269- weights .woLayered [layerIndex ].asByteArray (),
270- qDim0 , config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC );
268+ context ,
269+ qwen3State .wrapXb , // input: attention output
270+ qwen3State .wrapX , // output: wrapX += Wo · wrapXb
271+ weights .woLayered [layerIndex ].asByteArray (), // Wo [dim x qDim]
272+ nEmbdHeadK * qwen3Config .numberOfHeads (), // input dim (qDim)
273+ config .dim (), // output dim
274+ LOCAL_WORK_GROUP_SIZE_ALLOC );
271275
272276 // ========== FEED-FORWARD BLOCK ==========
273277
0 commit comments