@@ -172,6 +172,121 @@ public static void ropeRotation(KernelContext context, IntArray positionNlayer,
 
     }
 
+    /**
+     * Orchestrates parallel multi-head attention computation across all heads.
+     * Each head processes attention independently in parallel.
+     *
+     * Attention computation:
+     * 1. Compute attention scores (Q·K)
+     * 2. Apply softmax for attention weights
+     * 3. Compute weighted sum of values (attention·V)
+     *
+     * @param q Query vectors for all heads
+     * @param key_cache Cached key vectors
+     * @param value_cache Cached value vectors
+     * @param xb Output buffer for attention results
+     * @param nHeads Number of attention heads
+     * @param headSize Dimension of each head
+     * @param kvDim Total key/value dimension
+     * @param kvMul Key/value head multiplier for grouped-query attention
+     * @param seqLen Current sequence length
+     * @param positionHolder Array containing position and layer info
+     * @param wrapAtt Buffer for attention weights
+     * @param layer Current transformer layer
+     * @param contextLength Maximum context length
+     */
+    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
+
+        int pos = positionHolder.get(0);
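+        // Offset of this layer's slice within the flattened key/value cache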
+        int loff = layer * contextLength * kvDim;
+
+        // Parallelize computation across attention heads
+        for (@Parallel int h = 0; h < nHeads; h++) {
+            // Process each head in parallel
+            processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
+        }
+    }
+
+    /**
+     * Computes attention for a single head.
+     * Implements scaled dot-product attention with softmax normalization.
+     *
+     * Steps:
+     * 1. Compute attention scores: Q·K / sqrt(head_size)
+     * 2. Apply softmax (with max subtraction for numerical stability)
+     * 3. Compute weighted sum of values
+     *
+     * @param allQ All query vectors
+     * @param key_cache Cached keys
+     * @param value_cache Cached values
+     * @param allXb Output buffer
+     * @param h Head index to process
+     * @param headSize Dimension per head
+     * @param kvDim Key/value dimension
+     * @param kvMul Key multiplier for grouped attention
+     * @param loff Layer offset in cache
+     * @param pos Current position
+     * @param wrapAtt Attention weights buffer
+     */
+    private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
+            FloatArray wrapAtt) {
+
+        // Base index for this head's attention weights
+        int headOffset = h * (pos + 1);
+
+        // STEP 1: Calculate attention scores for all timesteps
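+        // Grouped-query attention: every kvMul query heads read the same cached key/value head (index h / kvMul)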
+        for (int t = 0; t <= pos; t++) {
+            int kvHeadIdx = h / kvMul;
+            int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+
+            float score = 0.0f;
+            for (int i = 0; i < headSize; i++) {
+                score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
+            }
+            score = score / TornadoMath.sqrt(headSize);
+
+            // Store in attention buffer
+            wrapAtt.set(headOffset + t, score);
+        }
+
+        // STEP 2: Find max score for softmax stability
+        float maxScore = wrapAtt.get(headOffset);
+        for (int t = 1; t <= pos; t++) {
+            float val = wrapAtt.get(headOffset + t);
+            if (val > maxScore) {
+                maxScore = val;
+            }
+        }
+
+        // STEP 3: Compute exponentials and sum
+        float sum = 0.0f;
+        for (int t = 0; t <= pos; t++) {
+            int idx = headOffset + t;
+            float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore);
+            wrapAtt.set(idx, expScore);
+            sum += expScore;
+        }
+
+        // STEP 4: Normalize
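+        // Fall back to a uniform distribution over the (pos + 1) timesteps if the sum underflows to zero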
+        float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos + 1));
+        for (int t = 0; t <= pos; t++) {
+            int idx = headOffset + t;
+            wrapAtt.set(idx, wrapAtt.get(idx) * normFactor);
+        }
+
+        // STEP 5: Compute weighted sum of values for each dimension
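+        // Each output element i accumulates att[t] * V[t][i]; the result fills this head's slice of allXb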
+        for (int i = 0; i < headSize; i++) {
+            float weightedSum = 0.0f;
+            for (int t = 0; t <= pos; t++) {
+                int kvHeadIdx = h / kvMul;
+                int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
+                weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i);
+            }
+            allXb.set(h * headSize + i, weightedSum);
+        }
+    }
+
     public static void processHeadsFlashAttention(
             KernelContext context,
             FloatArray q,
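For context, a kernel like processHeadsParallel is typically launched through TornadoVM's TaskGraph and TornadoExecutionPlan API, which is what turns the @Parallel head loop into the parallel dimension on the accelerator. The sketch below is not part of this commit: the enclosing class name TransformerKernels, the buffer sizes, and the single-layer cache layout are illustrative assumptions.

import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

public class AttentionLaunchSketch {
    public static void main(String[] args) {
        // Illustrative sizes: 8 query heads, 8 KV heads (kvMul = 1), a single layer of cache
        int nHeads = 8, headSize = 64, kvDim = nHeads * headSize, kvMul = 1;
        int contextLength = 128, seqLen = contextLength, layer = 0;

        FloatArray q = new FloatArray(nHeads * headSize);
        FloatArray keyCache = new FloatArray(contextLength * kvDim);
        FloatArray valueCache = new FloatArray(contextLength * kvDim);
        FloatArray xb = new FloatArray(nHeads * headSize);
        FloatArray wrapAtt = new FloatArray(nHeads * contextLength); // (pos + 1) scores per head
        IntArray positionHolder = new IntArray(1);
        positionHolder.set(0, 0); // current token position

        TaskGraph graph = new TaskGraph("attention")
                .transferToDevice(DataTransferMode.EVERY_EXECUTION, q, positionHolder)
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, keyCache, valueCache, wrapAtt)
                // TransformerKernels is a placeholder for whichever class hosts processHeadsParallel
                .task("multiHeadAttention", TransformerKernels::processHeadsParallel,
                        q, keyCache, valueCache, xb, nHeads, headSize, kvDim, kvMul, seqLen,
                        positionHolder, wrapAtt, layer, contextLength)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, xb);

        ImmutableTaskGraph itg = graph.snapshot();
        new TornadoExecutionPlan(itg).execute();
    }
}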