diff --git a/set_paths b/set_paths
index 0f356cc8..d5909616 100644
--- a/set_paths
+++ b/set_paths
@@ -22,4 +22,4 @@ echo "[INFO] Environment configured for LLaMA3 with TornadoVM at: $TORNADO_SDK"
 # 3. You can run LLaMA3 with GPU acceleration using TornadoVM
 #
 # To use this script: source ./setup_environment.sh
-# or: . ./setup_environment.sh
+# or: . ./setup_environment.sh
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
index d1803e41..803f38f4 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java
@@ -284,6 +284,7 @@ public static void fusedRmsNormFFNGateUpQ8_0(
      * @param localMemSize
      *            Size of local memory allocation (must match work group size)
      */
+
     public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
         int gid = context.globalIdx;
         int lid = context.localIdx;
@@ -331,20 +332,170 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
     }
 
     /**
-     * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
+     * Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the scaling factor across all work groups,
+     * then it applies the computed normalization factor to input and weight elements.
+     *
      *
      * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
      *
+     * Algorithm: 1. Each thread computes the square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. All threads combine the partial
+     * sums and compute the normalization factor 5. Each thread applies the normalization factor to its input and weight elements.
+     *
      * @param context
      *            Kernel execution context
      * @param output
-     *            Array for normalized output
+     *            Output array for the normalized result
+     * @param x
+     *            Input array to normalize
+     * @param weights
+     *            Weight values for each element
+     * @param temp
+     *            Temporary array holding the per-work-group partial sums
+     * @param size
+     *            Number of elements to process
+     * @param ermsNorm
+     *            Epsilon value squared for numerical stability
+     * @param localMemSize
+     *            Size of local memory allocation (must match work group size)
+     */
+
+    public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;
+
+        // Allocate local memory with the provided size
+        float[] localX = context.allocateFloatLocalArray(localMemSize);
+
+        // Load input value and compute square
+        if (gid < size) {
+            float v = x.get(gid);
+            localX[lid] = v * v;
+        } else {
+            localX[lid] = 0.0f;
+        }
+
+        // Perform parallel reduction within the work group
+        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
+            context.localBarrier();
+            if (lid < stride) {
+                localX[lid] += localX[lid + stride];
+            }
+        }
+
+        // Each workgroup stores its partial sum in a different location
+        if (lid == 0) {
+            // Store the partial sum from each workgroup
+            temp.set(groupId, localX[0]);
+        }
+
+        context.globalBarrier();
+
+        float localss = 0.0f;
+        int numGroups = (size + groupSize - 1) / groupSize;
+        for (int i = 0; i < numGroups; i++) { // Accumulate the partial sums of all work groups
+            localss += temp.get(i);
+        }
+        localss /= size;
+        localss += ermsNorm;
+        localss = 1.0f / TornadoMath.sqrt(localss);
+
+        if (gid < size) {
+            float in = x.get(gid);
+            float w = weights.get(gid);
+            output.set(gid, w * (localss * in));
+        }
+    }
+
+    /**
+     * Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the scaling factor across all work groups,
+     * then it applies the computed normalization factor to input and weight elements.
+     *
+     *
+     * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
+     *
+     * Algorithm: 1. Each thread computes the square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums stored per work group 4. All threads combine the partial
+     * sums and compute the normalization factor 5. Each thread applies the normalization factor to its input and weight elements.
+     *
+     * @param context
+     *            Kernel execution context
+     * @param outputFP16
+     *            Half-float output array for the normalized result
      * @param x
-     *            Input values to normalize
+     *            Input array to normalize
      * @param weights
      *            Weight values for each element
      * @param temp
      *            Temporary array containing normalization factor at index 0
+     * @param size
+     *            Number of elements to process
+     * @param ermsNorm
+     *            Epsilon value squared for numerical stability
+     * @param localMemSize
+     *            Size of local memory allocation (must match work group size)
+     */
+
+    public static void reductionOneBlockWithLayerFuseFP16(KernelContext context, HalfFloatArray outputFP16, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;
+
+        // Allocate local memory with the provided size
+        float[] localX = context.allocateFloatLocalArray(localMemSize);
+
+        // Load input value and compute square
+        if (gid < size) {
+            float v = x.get(gid);
+            localX[lid] = v * v;
+        } else {
+            localX[lid] = 0.0f;
+        }
+
+        // Perform parallel reduction within the work group
+        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
+            context.localBarrier();
+            if (lid < stride) {
+                localX[lid] += localX[lid + stride];
+            }
+        }
+
+        // Each workgroup stores its partial sum in a different location
+        if (lid == 0) {
+            // Store the partial sum from each workgroup
+            temp.set(groupId, localX[0]);
+        }
+
+        context.globalBarrier();
+
+        float localss = 0.0f;
+        int numGroups = (size + groupSize - 1) / groupSize;
+        for (int i = 0; i < numGroups; i++) { // Accumulate the partial sums of all work groups
+            localss += temp.get(i);
+        }
+        localss /= size;
+        localss += ermsNorm;
+        localss = 1.0f / TornadoMath.sqrt(localss);
+
+        if (gid < size) {
+            float in = x.get(gid);
+            float w = weights.get(gid);
+            outputFP16.set(gid, new HalfFloat(w * (localss * in)));
+        }
+    }
+
+
+    /**
+     * Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
+     *
+     * Formula: output[i] = weight[i] * (normalizationFactor * x[i])
+     *
+     * @param context Kernel execution context
+     * @param output Array for normalized output
+     * @param x Input values to normalize
+     * @param weights Weight values for each element
+     * @param temp Temporary array containing normalization factor at index 0
      */
     public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
         int gid = context.globalIdx;
@@ -355,25 +506,17 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
     /**
      * Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
      *
      * Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
      *
-     * @param destKeyCache
-     *            Destination array for key cache
-     * @param srcKey
-     *            Source keys to copy
-     * @param destValueCache
-     *            Destination array for value cache
-     * @param srcValue
-     *            Source values to copy
-     * @param positioNlayer
-     *            Array containing current position
-     * @param kvDim
-     *            Dimension of key/value vectors
-     * @param layer
-     *            Current transformer layer index
-     * @param contextLength
-     *            Maximum sequence length
+     * @param destKeyCache Destination array for key cache
+     * @param srcKey Source keys to copy
+     * @param destValueCache Destination array for value cache
+     * @param srcValue Source values to copy
+     * @param positioNlayer Array containing current position
+     * @param kvDim Dimension of key/value vectors
+     * @param layer Current transformer layer index
+     * @param contextLength Maximum sequence length
      */
     public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
@@ -463,21 +606,15 @@ public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArr
     /**
      * Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
      * information.
      *
      * For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
      *
-     * @param context
-     *            Kernel execution context
-     * @param positionHolder
-     *            Array containing current position
-     * @param sq
-     *            Query vectors to rotate
-     * @param sk
-     *            Key vectors to rotate
-     * @param kv_dim
-     *            Dimension of key/value vectors
-     * @param head_size
-     *            Dimension of each attention head
+     * @param context Kernel execution context
+     * @param positionHolder Array containing current position
+     * @param sq Query vectors to rotate
+     * @param sk Key vectors to rotate
+     * @param kv_dim Dimension of key/value vectors
+     * @param head_size Dimension of each attention head
      */
     public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
         int i = context.globalIdx * 2;
@@ -552,31 +689,20 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
     /**
      * Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
      *
      * Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. Compute weighted sum of values
      *
-     * @param allQ
-     *            All query vectors
-     * @param key_cache
-     *            Cached keys
-     * @param value_cache
-     *            Cached values
-     * @param allXb
-     *            Output buffer
-     * @param h
-     *            Head index to process
-     * @param headSize
-     *            Dimension per head
-     * @param kvDim
-     *            Key/value dimension
-     * @param kvMul
-     *            Key multiplier for grouped attention
-     * @param loff
-     *            Layer offset in cache
-     * @param pos
-     *            Current position
-     * @param wrapAtt
-     *            Attention weights buffer
+     * @param allQ All query vectors
+     * @param key_cache Cached keys
+     * @param value_cache Cached values
+     * @param allXb Output buffer
+     * @param h Head index to process
+     * @param headSize Dimension per head
+     * @param kvDim Key/value dimension
+     * @param kvMul Key multiplier for grouped attention
+     * @param loff Layer offset in cache
+     * @param pos Current position
+     * @param wrapAtt Attention weights buffer
      */
     private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
             FloatArray wrapAtt) {
@@ -1117,23 +1243,16 @@ public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArr
     /**
      * Performs optimized matrix-vector multiplication where each work group processes one row of the matrix.
      *
      * Algorithm: 1. Each work group handles one output dimension 2. Threads in work group compute partial dot products 3. Parallel reduction yields final row result
      *
-     * @param context
-     *            Kernel execution context
-     * @param x
-     *            Input vector
-     * @param hb
-     *            Output vector
-     * @param w
-     *            Weight matrix (row-major)
-     * @param n
-     *            Input dimension
-     * @param d
-     *            Output dimension
-     * @param localWorkGroupSize
-     *            Number of threads per work group
+     * @param context Kernel execution context
+     * @param x Input vector
+     * @param hb Output vector
+     * @param w Weight matrix (row-major)
+     * @param n Input dimension
+     * @param d Output dimension
+     * @param localWorkGroupSize Number of threads per work group
      */
     public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
index 8d105e89..b22ff923 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/LlamaFP16FFNLayers.java
@@ -50,7 +50,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
         for (int i = 0; i < config.numberOfLayers(); i++) {
             // === Attention Block ===
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply_fp16", rmsNormWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQKVWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
@@ -199,21 +198,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
 
         // === Attention Block ===
         // RMS Normalization
         unifiedLayer.task("attn_rms_reduce",
-                TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
-                context, state.temp, state.wrapX,
+                TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuseFP16,
+                context, state.wrapXbFP16, state.wrapX,
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("attn_rms_finalize",
-                    TransformerComputeKernelsLayered::reductionFinalNormalization,
-                    context, state.temp, config.dim(), config.rmsNormEps());
-        }
-
-        unifiedLayer.task("attn_rms_apply_fp16",
-                TransformerComputeKernels::mapContextWithQuantize,
-                context, state.wrapXbFP16, state.wrapX,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
-
         // QKV Projection (fused)
         unifiedLayer.task("qkv_projection",
                 TransformerComputeKernelsLayered::fusedQKVMatmulX,
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
index ba1b6a79..c170b039 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/LlamaQ8_0FFNLayers.java
@@ -161,21 +161,10 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
 
         // === Attention Block ===
         // RMS Normalization
         unifiedLayer.task("attn_rms_reduce",
-                TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
-                context, state.temp, state.wrapX,
+                TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse,
+                context, state.wrapXb, state.wrapX,
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize);
-        if (shouldUseFinalNormalization()) {
-            unifiedLayer.task("attn_rms_finalize",
-                    TransformerComputeKernelsLayered::reductionFinalNormalization,
-                    context, state.temp, config.dim(), config.rmsNormEps());
-        }
-
-        unifiedLayer.task("attn_rms_apply",
-                TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
-                context, state.wrapXb, state.wrapX,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp);
-
         // QKV Projection (fused with Q8 dequantization)
         unifiedLayer.task("qkv_projection",
                 TransformerComputeKernelsLayered::fusedQKVMatmulQ8,
@@ -306,7 +295,6 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
             // --- Attention Block ---
             // RMS Normalization
             tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
-            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_apply", rmsNormWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qkv_projection", fusedQkvWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWithCacheWorker);
            tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
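
For reference, a minimal single-threaded Java sketch of the computation performed by the new fused kernels (the class and method names below are illustrative only and are not part of this patch). The GPU kernels distribute the square-and-reduce steps across work groups and apply the scale per thread, but the arithmetic is the same: output[i] = weight[i] * (x[i] / sqrt(mean(x^2) + eps)).

// Scalar reference for reductionOneBlockWithLayerFuse / reductionOneBlockWithLayerFuseFP16.
final class RmsNormReference {
    static void rmsNorm(float[] output, float[] x, float[] weights, float eps) {
        float sumOfSquares = 0.0f;
        for (float v : x) {
            sumOfSquares += v * v;                                              // squares + reduction (kernel steps 1-3)
        }
        float scale = 1.0f / (float) Math.sqrt(sumOfSquares / x.length + eps);  // normalization factor (kernel step 4)
        for (int i = 0; i < x.length; i++) {
            output[i] = weights[i] * (scale * x[i]);                            // apply factor and weights (kernel step 5)
        }
    }
}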