@@ -109,8 +109,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
         }
         return tornadoForwardScheduler;
     }
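For context: a GridScheduler key resolves as "<taskGraphName>.<taskName>", so the worker-grid registration has to use the same "ffn_down_proj" name that the per-layer TaskGraph gives the task. A minimal sketch of that naming contract, assuming the per-layer graph is created with the name "layer_" + layerIndex and reusing configDimRowMajorGlobalWorker from the diff (graph construction details here are assumptions, not the project's exact code):

    // Sketch only; everything except the names taken from the diff is assumed.
    TaskGraph ffnLayer = new TaskGraph("layer_" + layerIndex);               // graph name supplies the key prefix
    // ... ffnLayer.task("ffn_down_proj", ...) supplies the key suffix ...
    GridScheduler scheduler = new GridScheduler();
    scheduler.addWorkerGrid("layer_" + layerIndex + ".ffn_down_proj",        // must equal "<graphName>.<taskName>"
            configDimRowMajorGlobalWorker);                                  // worker presumably sized from config.dim()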
@@ -283,8 +283,16 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.hiddenDim(),                                   // hidden dimension
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        unifiedLayer.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
-                state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
+        // Down Projection with Residual
+        unifiedLayer.task("ffn_down_proj",
+                TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
+                context,
+                qwen2State.wrapHb,                                     // input: FFN intermediate
+                qwen2State.wrapX,                                      // output: wrapX += W2 · wrapHb
+                weights.w2Layered[layerIndex].asByteArray(),           // W2 (down)
+                config.hiddenDim(),                                    // input dim
+                config.dim(),                                          // output dim
+                LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .persistOnDevice(
                         state.wrapX
                 );
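For reference, the operation the renamed "ffn_down_proj" task performs is the FFN down projection fused with the residual add noted in the comments above: wrapX[d] += sum over h of W2[d][h] * wrapHb[h], for d in [0, dim). A plain-Java scalar sketch of that computation, with the Q8_0 dequantization and the work-group mapping omitted and all names illustrative:

    // Scalar reference for the fused down projection + residual add; not the GPU kernel.
    final class FfnDownProjSketch {
        static void downProjWithResidual(float[] x, float[] hb, float[] w2, int hiddenDim, int dim) {
            for (int d = 0; d < dim; d++) {                   // each output row handled by one GPU work group in the kernel
                float acc = 0f;
                for (int h = 0; h < hiddenDim; h++) {
                    acc += w2[d * hiddenDim + h] * hb[h];     // row-major W2, matching the "RowMajor" worker naming
                }
                x[d] += acc;                                  // residual accumulate: wrapX += W2 · wrapHb
            }
        }
    }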