@@ -106,7 +106,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
106106 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
107107 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce" , rmsNormWorker );
108108 tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
109- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".projectionTwo " , projectionTwoWorker );
109+ tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_down_proj " , projectionTwoWorker );
110110 }
111111 return tornadoForwardScheduler ;
112112 }
@@ -299,11 +299,16 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
299299 qwen3Config .hiddenDim (), // hidden dimension
300300 LOCAL_WORK_GROUP_SIZE_ALLOC );
301301
302- unifiedLayer .task ("projectionTwo" ,
302+ // Down Projection with Residual
303+ unifiedLayer .task ("ffn_down_proj" ,
303304 TransformerComputeKernelsLayered ::matrixVectorGenericWithResidualQ8_0Byte ,
304- context , qwen3State .wrapHb , qwen3State .wrapX ,
305- weights .w2Layered [layerIndex ].asByteArray (),
306- config .hiddenDim (), config .dim (), LOCAL_WORK_GROUP_SIZE_ALLOC )
305+ context ,
306+ qwen3State .wrapHb , // input: FFN intermediate
307+ qwen3State .wrapX , // output: wrapX += W2 · wrapHb
308+ weights .w2Layered [layerIndex ].asByteArray (), // W2 (down)
309+ config .hiddenDim (), // input dim
310+ config .dim (), // output dim
311+ LOCAL_WORK_GROUP_SIZE_ALLOC )
307312 .persistOnDevice (state .wrapX );
308313
309314 return unifiedLayer ;
0 commit comments