@@ -106,11 +106,11 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
         }
         return tornadoForwardScheduler;
     }
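
For reference, the hunk above only re-keys the worker grid; the definition of parallelAttentionWorker is not part of this diff. Below is a minimal sketch of how such a grid could be built in TornadoVM, assuming one work-group per attention head; the dimensions and the use of WorkerGrid1D are illustrative assumptions, not code from this commit. Note that the scheduler key follows the "<taskGraphName>.<taskName>" convention, so renaming the task to "attention" has to happen both here and in the TaskGraph that declares it.

// Hypothetical sketch (not from this commit): one plausible WorkerGrid for the
// per-head attention task, sized as numberOfHeads work-groups of headSize threads.
WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * config.headSize(), 1, 1);
parallelAttentionWorker.setLocalWork(config.headSize(), 1, 1);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
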
@@ -230,12 +230,24 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.kvDim(),           // kvDim
                 layerIndex,               // layer offset
                 config.contextLength());  // max sequence length
+
+        // Flash Attention
+        unifiedLayer.task("attention",
+                Qwen2Kernels::processHeadsFlashAttention,
+                context,
+                qwen2State.wrapQ,          // query vectors
+                qwen2State.wrapKeyCache,   // key cache
+                qwen2State.wrapValueCache, // value cache
+                qwen2State.wrapXb,         // output: attention result
+                config.numberOfHeads(),    // nHeads
+                config.headSize(),         // headSize
+                config.kvDim(),            // kvDim
+                config.kvMul(),            // kvMul (nHeads / nHeadKv)
+                qwen2State.positionHolder, // position
+                layerIndex,                // layer index
+                config.contextLength());   // context length
 
-        unifiedLayer.task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
-                state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
-                config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
-                state.positionHolder, layerIndex, config.contextLength())
-                .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
+        unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
                 state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                         state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
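
TornadoVM binds the arguments of task(...) positionally to the parameters of the referenced kernel method, which is what makes the long argument list above line up with processHeadsFlashAttention. The signature below is reconstructed from that call site purely as a reading aid; the parameter types (KernelContext, FloatArray, IntArray) are assumptions based on common TornadoVM usage and are not shown in this diff.

// Assumed signature, inferred from the "attention" task's argument order;
// types are illustrative TornadoVM API types, not taken from this commit.
public static void processHeadsFlashAttention(KernelContext context,
        FloatArray q, FloatArray keyCache, FloatArray valueCache, FloatArray xb,
        int nHeads, int headSize, int kvDim, int kvMul,
        IntArray positionHolder, int layerIndex, int contextLength) {
    // per head: streaming (flash-style) softmax attention over this layer's
    // slice of the key/value cache, writing the weighted values into xb
}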