
Commit bf168a6

Include previous implementation of parallel attention
1 parent d023c74 commit bf168a6

File tree

1 file changed: +115 -0 lines changed


src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 115 additions & 0 deletions
@@ -172,6 +172,121 @@ public static void ropeRotation(KernelContext context, IntArray positionNlayer,
 
     }
 
+    /**
+     * Orchestrates parallel multi-head attention computation across all heads.
+     * Each head processes attention independently in parallel.
+     *
+     * Attention computation:
+     * 1. Compute attention scores (Q·K)
+     * 2. Apply softmax for attention weights
+     * 3. Compute weighted sum of values (attention·V)
+     *
+     * @param q Query vectors for all heads
+     * @param key_cache Cached key vectors
+     * @param value_cache Cached value vectors
+     * @param xb Output buffer for attention results
+     * @param nHeads Number of attention heads
+     * @param headSize Dimension of each head
+     * @param kvDim Total key/value dimension
+     * @param kvMul Key/value head multiplier for grouped-query attention
+     * @param seqLen Current sequence length
+     * @param positionHolder Array containing position and layer info
+     * @param wrapAtt Buffer for attention weights
+     * @param layer Current transformer layer
+     * @param contextLength Maximum context length
+     */
+    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
+
+        int pos = positionHolder.get(0);
+        int loff = layer * contextLength * kvDim;
+
+        // Parallelize computation across attention heads
+        for (@Parallel int h = 0; h < nHeads; h++) {
+            // Process each head in parallel
+            processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
+        }
+    }
+
+    /**
+     * Computes attention for a single head.
+     * Implements scaled dot-product attention with softmax normalization.
+     *
+     * Steps:
+     * 1. Compute attention scores: Q·K / sqrt(head_size)
+     * 2. Apply softmax (with max subtraction for numerical stability)
+     * 3. Compute weighted sum of values
+     *
+     * @param allQ All query vectors
+     * @param key_cache Cached keys
+     * @param value_cache Cached values
+     * @param allXb Output buffer
+     * @param h Head index to process
+     * @param headSize Dimension per head
+     * @param kvDim Key/value dimension
+     * @param kvMul Key multiplier for grouped attention
+     * @param loff Layer offset in cache
+     * @param pos Current position
+     * @param wrapAtt Attention weights buffer
+     */
+    private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
+            FloatArray wrapAtt) {
+
+        // Base index for this head's attention weights
+        int headOffset = h * (pos + 1);
+
+        // STEP 1: Calculate attention scores for all timesteps
+        for (int t = 0; t <= pos; t++) {
+            int kvHeadIdx = h / kvMul;
+            int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+
+            float score = 0.0f;
+            for (int i = 0; i < headSize; i++) {
+                score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
+            }
+            score = score / TornadoMath.sqrt(headSize);
+
+            // Store in attention buffer
+            wrapAtt.set(headOffset + t, score);
+        }
+
+        // STEP 2: Find max score for softmax stability
+        float maxScore = wrapAtt.get(headOffset);
+        for (int t = 1; t <= pos; t++) {
+            float val = wrapAtt.get(headOffset + t);
+            if (val > maxScore) {
+                maxScore = val;
+            }
+        }
+
+        // STEP 3: Compute exponentials and sum
+        float sum = 0.0f;
+        for (int t = 0; t <= pos; t++) {
+            int idx = headOffset + t;
+            float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore);
+            wrapAtt.set(idx, expScore);
+            sum += expScore;
+        }
+
+        // STEP 4: Normalize
+        float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1));
+        for (int t = 0; t <= pos; t++) {
+            int idx = headOffset + t;
+            wrapAtt.set(idx, wrapAtt.get(idx) * normFactor);
+        }
+
+        // STEP 5: Compute weighted sum of values for each dimension
+        for (int i = 0; i < headSize; i++) {
+            float weightedSum = 0.0f;
+            for (int t = 0; t <= pos; t++) {
+                int kvHeadIdx = h / kvMul;
+                int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+                weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i);
+            }
+            allXb.set(h * headSize + i, weightedSum);
+        }
+    }
+
     public static void processHeadsFlashAttention(
             KernelContext context,
             FloatArray q,

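In equation form, the per-head computation in processHeadTornado is standard scaled dot-product attention over the cached keys and values, using a max-shifted softmax for numerical stability. Writing q_h for this head's query slice, k_t and v_t for the cached key/value rows at timestep t, d_h for headSize, and p for the current position pos, a sketch of the math (notation mine, not from the commit):

    s_t = \frac{q_h \cdot k_t}{\sqrt{d_h}}, \qquad
    m = \max_{0 \le j \le p} s_j, \qquad
    a_t = \frac{e^{s_t - m}}{\sum_{j=0}^{p} e^{s_j - m}}, \qquad
    o_h = \sum_{t=0}^{p} a_t \, v_t

STEP 1 computes the scores s_t, STEPs 2-4 turn them into the weights a_t in place inside wrapAtt, and STEP 5 accumulates the output o_h into xb at this head's offset.
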
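The h / kvMul indexing implements grouped-query attention: several query heads share one cached key/value head, which is why keyOffset and valueOffset are computed from kvHeadIdx rather than from h. A small worked example of the mapping, with hypothetical head counts that are not taken from this commit:

    // Hypothetical sizes, for illustration only: 8 query heads sharing 2 KV heads.
    int nHeads = 8;                    // query heads
    int nKvHeads = 2;                  // key/value heads actually stored in the caches
    int kvMul = nHeads / nKvHeads;     // = 4 query heads per KV head
    for (int h = 0; h < nHeads; h++) {
        int kvHeadIdx = h / kvMul;     // heads 0..3 -> KV head 0, heads 4..7 -> KV head 1
        System.out.println("query head " + h + " reads KV head " + kvHeadIdx);
    }
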
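Because processHeadsParallel relies on TornadoVM's @Parallel loop annotation, it is meant to be launched as a device task rather than called directly on the host. Below is a minimal, self-contained sketch of how such a kernel could be wired into a TornadoVM TaskGraph and executed; the graph/task names, buffer sizes, transfer modes, and the seqLen value are assumptions for illustration, not taken from this repository's planner code.

    import com.example.tornadovm.TransformerComputeKernelsLayered;
    import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
    import uk.ac.manchester.tornado.api.TaskGraph;
    import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
    import uk.ac.manchester.tornado.api.enums.DataTransferMode;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
    import uk.ac.manchester.tornado.api.types.arrays.IntArray;

    public class ParallelAttentionScheduleSketch {
        public static void main(String[] args) {
            // Hypothetical model dimensions, chosen only to make the sketch self-contained.
            int nHeads = 8, headSize = 64, kvMul = 1, contextLength = 512, layer = 0;
            int dim = nHeads * headSize;
            int kvDim = dim / kvMul;

            FloatArray q = new FloatArray(dim);
            FloatArray keyCache = new FloatArray(contextLength * kvDim);    // a single layer, for brevity
            FloatArray valueCache = new FloatArray(contextLength * kvDim);
            FloatArray xb = new FloatArray(dim);
            FloatArray wrapAtt = new FloatArray(nHeads * contextLength);
            IntArray positionHolder = new IntArray(1);
            positionHolder.set(0, 0);                                       // current token position
            int seqLen = positionHolder.get(0) + 1;                         // not read inside the kernel body

            TaskGraph graph = new TaskGraph("attention")
                    .transferToDevice(DataTransferMode.FIRST_EXECUTION, xb, wrapAtt)
                    .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, keyCache, valueCache, positionHolder)
                    .task("heads", TransformerComputeKernelsLayered::processHeadsParallel,
                            q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, seqLen,
                            positionHolder, wrapAtt, layer, contextLength)
                    .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);

            ImmutableTaskGraph immutableGraph = graph.snapshot();
            TornadoExecutionPlan plan = new TornadoExecutionPlan(immutableGraph);
            plan.execute();    // launches all heads in parallel on the selected accelerator
        }
    }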