Commit 5d4edb7

[refactor] Update attention task
1 parent 5d9ae41 commit 5d4edb7

1 file changed: +18 additions, −6 deletions


src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 18 additions & 6 deletions
@@ -106,11 +106,11 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_qkv_bias", fusedQKVBiasWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
+           tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".matmul1", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
-           tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
        }
        return tornadoForwardScheduler;
    }
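
The scheduler looks tasks up by "<taskGraphName>.<taskName>", so renaming the task from "parallel-attention" to "attention" only takes effect if the worker-grid registration uses the matching id "layer_<i>.attention", as the hunk above does. Below is a minimal sketch of that pairing, assuming TornadoVM's GridScheduler, WorkerGrid and WorkerGrid1D API; the class name, head count and work-group size are hypothetical placeholders, not values from this repository.

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.WorkerGrid;
    import uk.ac.manchester.tornado.api.WorkerGrid1D;

    class AttentionWorkerGridSketch {

        // Placeholder sizes; the real values come from the model configuration.
        private static final int N_HEADS = 16;
        private static final int LOCAL_SIZE = 32;

        static GridScheduler register(GridScheduler scheduler, int layer) {
            // One work-group per attention head (sketch only).
            WorkerGrid attentionWorker = new WorkerGrid1D(N_HEADS * LOCAL_SIZE);
            attentionWorker.setLocalWork(LOCAL_SIZE, 1, 1);

            // Key format is "<taskGraphName>.<taskName>"; after the rename it must be
            // "layer_<i>.attention" to match unifiedLayer.task("attention", ...).
            scheduler.addWorkerGrid("layer_" + layer + ".attention", attentionWorker);
            return scheduler;
        }
    }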
@@ -230,12 +230,24 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                config.kvDim(), // kvDim
                layerIndex, // layer offset
                config.contextLength()); // max sequence length
+
+       // Flash Attention
+       unifiedLayer.task("attention",
+               Qwen2Kernels::processHeadsFlashAttention,
+               context,
+               qwen2State.wrapQ, // query vectors
+               qwen2State.wrapKeyCache, // key cache
+               qwen2State.wrapValueCache, // value cache
+               qwen2State.wrapXb, // output: attention result
+               config.numberOfHeads(), // nHeads
+               config.headSize(), // headSize
+               config.kvDim(), // kvDim
+               config.kvMul(), // kvMul (nHeads / nHeadKv)
+               qwen2State.positionHolder, // position
+               layerIndex, // layer index
+               config.contextLength()); // context length

-       unifiedLayer.task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
-               state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
-               config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
-               state.positionHolder, layerIndex, config.contextLength())
-               .task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
+       unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
                state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asByteArray(), config.dim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                .task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN,
                        state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
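
Splitting the fluent chain works because TaskGraph.task(...) registers the task and returns the graph, so a standalone call and a chained call build the same graph, and the task name is what the GridScheduler key refers to. Below is a minimal, self-contained sketch of that naming contract, assuming TornadoVM's public API (TaskGraph, GridScheduler, TornadoExecutionPlan, WorkerGrid1D, FloatArray); the scale kernel, graph name and array size are hypothetical stand-ins for the attention task, not code from this commit.

    import uk.ac.manchester.tornado.api.GridScheduler;
    import uk.ac.manchester.tornado.api.TaskGraph;
    import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
    import uk.ac.manchester.tornado.api.WorkerGrid1D;
    import uk.ac.manchester.tornado.api.annotations.Parallel;
    import uk.ac.manchester.tornado.api.enums.DataTransferMode;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

    public class TaskNamingSketch {

        // Toy kernel standing in for Qwen2Kernels::processHeadsFlashAttention.
        public static void scale(FloatArray in, FloatArray out, float factor) {
            for (@Parallel int i = 0; i < in.getSize(); i++) {
                out.set(i, in.get(i) * factor);
            }
        }

        public static void main(String[] args) {
            FloatArray in = new FloatArray(1024);
            FloatArray out = new FloatArray(1024);
            in.init(1.0f);

            // Graph name "layer_0" plus task name "attention" gives the
            // scheduler key "layer_0.attention".
            TaskGraph graph = new TaskGraph("layer_0")
                    .transferToDevice(DataTransferMode.FIRST_EXECUTION, in)
                    .task("attention", TaskNamingSketch::scale, in, out, 0.5f)
                    .transferToHost(DataTransferMode.EVERY_EXECUTION, out);

            GridScheduler scheduler = new GridScheduler();
            scheduler.addWorkerGrid("layer_0.attention", new WorkerGrid1D(1024));

            TornadoExecutionPlan plan = new TornadoExecutionPlan(graph.snapshot());
            plan.withGridScheduler(scheduler).execute();
        }
    }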

0 commit comments

Comments
 (0)