2 changes: 1 addition & 1 deletion external/tornadovm
Submodule tornadovm updated 181 files
2 changes: 1 addition & 1 deletion llama-tornado
@@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser:
)
debug_group.add_argument(
"--profiler-dump-dir",
default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
default="/home/ruiqi/GPULlama3.java/prof.json",
help="Directory for profiler output",
)

4 changes: 2 additions & 2 deletions set_paths
@@ -6,10 +6,10 @@

# Resolve root of this project (LLaMA3) and TornadoVM
export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm"
export TORNADO_ROOT="/home/ruiqi/TornadoVM_OCL/TornadoVM"

# Set the path to TornadoVM SDK binaries
export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk"
export TORNADO_SDK="/home/ruiqi/TornadoVM_OCL/TornadoVM/bin/sdk"

# Add TornadoVM and LLaMA bin directories to PATH
export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}"
@@ -33,6 +33,7 @@ public TransformerComputeKernelsLayered() {
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
@@ -80,20 +81,92 @@ public static void reductionOneBlockWithLayer(KernelContext context, FloatArray
}

/**
* Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
* Performs RMS (Root Mean Square) normalization using parallel reduction. It first computes the mean of squares and the corresponding scaling factor across all work groups,
* then applies the computed normalization factor to the input and weight elements.
*
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* Algorithm: 1. Each thread computes the square of its input element 2. Work group performs parallel reduction of squares 3. Partial sums are stored per work group 4. Each thread combines all
* partial sums and computes the normalization factor 5. Applies the computed normalization factor to input and weight elements.
*
* @param context
* Kernel execution context
* @param output
* Array for normalized output
* Array to store partial sums and final normalization factor
* @param x
* Input values to normalize
* Input array to normalize
* @param weights
* Weight values for each element
* @param temp
* Temporary array holding one partial sum per work group, used to compute the normalization factor
* @param size
* Number of elements to process
* @param ermsNorm
* Epsilon value squared for numerical stability
* @param localMemSize
* Size of local memory allocation (must match work group size)
*/

public static void reductionOneBlockWithLayerFuse(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp, int size, float ermsNorm, int localMemSize) {
int gid = context.globalIdx;
int lid = context.localIdx;
int groupId = context.groupIdx;
int groupSize = context.localGroupSizeX;

// Allocate local memory with the provided size
float[] localX = context.allocateFloatLocalArray(localMemSize);

// Load input value and compute square
if (gid < size) {
float v = x.get(gid);
localX[lid] = v * v;
} else {
localX[lid] = 0.0f;
}

// Perform parallel reduction within the work group
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
context.localBarrier();
if (lid < stride) {
localX[lid] += localX[lid + stride];
}
}

// Each workgroup stores its partial sum in a different location
if (lid == 0) {
// Store the partial sum from each workgroup
temp.set(groupId, localX[0]);
}

context.globalBarrier();

float localss = 0.0f;
int numGroups = (size + groupSize - 1) / groupSize;
for (int i = 0; i < numGroups; i++) { // accumulate partial sums from all work groups
localss += temp.get(i);
}
localss /= size;
localss += ermsNorm;
localss = 1.0f / TornadoMath.sqrt(localss);

if (gid < size) {
float in = x.get(gid);
float w = weights.get(gid);
output.set(gid, w * (localss * in));
}
}
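
For reference, the fused kernel's arithmetic collapses to the following sequential computation. This is a minimal host-side sketch, not part of this PR: the method name rmsNormReference and the plain float[] arrays are illustrative, and it assumes the same ermsNorm epsilon term used above.

// Sequential reference for the fused RMS-norm kernel (illustrative only).
// output[i] = weights[i] * (x[i] / sqrt(mean(x^2) + ermsNorm))
static void rmsNormReference(float[] output, float[] x, float[] weights, float ermsNorm) {
    float ss = 0.0f;
    for (float v : x) {
        ss += v * v;                         // sum of squares (the parallel reduction)
    }
    ss = ss / x.length + ermsNorm;           // mean of squares plus epsilon term
    float factor = 1.0f / (float) Math.sqrt(ss);
    for (int i = 0; i < x.length; i++) {
        output[i] = weights[i] * (factor * x[i]);
    }
}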

/**
* Applies the computed normalization factor to input and weight elements. This is the second phase of RMS normalization.
* <p>
* Formula: output[i] = weight[i] * (normalizationFactor * x[i])
*
* @param context Kernel execution context
* @param output Array for normalized output
* @param x Input values to normalize
* @param weights Weight values for each element
* @param temp Temporary array containing normalization factor at index 0
*/
public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray output, FloatArray x, FloatArray weights, FloatArray temp) {
int gid = context.globalIdx;
@@ -104,25 +177,17 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
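
Since temp[0] holds the normalization factor by the time this second phase runs, the remaining work is an element-wise scale. A sequential sketch under that assumption (the size bound is illustrative; the kernel instead processes one element per global thread id):

// Second phase as a plain loop: apply the precomputed factor element-wise.
float factor = temp.get(0);                  // normalization factor from phase one
for (int i = 0; i < size; i++) {
    output.set(i, weights.get(i) * (factor * x.get(i)));
}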

/**
* Copies keys and values into the key-value cache for attention computation. Enables efficient access to past key-value pairs during autoregressive generation.
*
* <p>
* Cache layout: [layer][position][dimension] - Each layer has its own key and value cache - Each position in sequence has a key and value vector
*
* @param destKeyCache
* Destination array for key cache
* @param srcKey
* Source keys to copy
* @param destValueCache
* Destination array for value cache
* @param srcValue
* Source values to copy
* @param positioNlayer
* Array containing current position
* @param kvDim
* Dimension of key/value vectors
* @param layer
* Current transformer layer index
* @param contextLength
* Maximum sequence length
* @param destKeyCache Destination array for key cache
* @param srcKey Source keys to copy
* @param destValueCache Destination array for value cache
* @param srcValue Source values to copy
* @param positioNlayer Array containing current position
* @param kvDim Dimension of key/value vectors
* @param layer Current transformer layer index
* @param contextLength Maximum sequence length
*/
public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {

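Given the [layer][position][dimension] layout described above, the destination offset arithmetic works out as in this sketch. It assumes the current position sits at index 0 of positioNlayer; the loop and the loff variable are illustrative, not the exact body of this method:

// Illustrative offset arithmetic for the cache layout above.
int pos = positioNlayer.get(0);                          // current sequence position
int loff = layer * contextLength * kvDim;                // start of this layer's cache region
for (int i = 0; i < kvDim; i++) {
    destKeyCache.set(loff + pos * kvDim + i, srcKey.get(i));
    destValueCache.set(loff + pos * kvDim + i, srcValue.get(i));
}
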
@@ -158,21 +223,15 @@ public static void splitQKV(FloatArray qkv, FloatArray q, FloatArray k, FloatArr
/**
* Applies Rotary Position Encoding (RoPE) to query and key vectors. RoPE rotates pairs of dimensions based on their position in the sequence, enabling the model to learn relative positional
* information.
*
* <p>
* For each pair of dimensions (2*i, 2*i+1): - Compute rotation angle based on position and frequency - Apply 2D rotation to the pair
*
* @param context
* Kernel execution context
* @param positionHolder
* Array containing current position
* @param sq
* Query vectors to rotate
* @param sk
* Key vectors to rotate
* @param kv_dim
* Dimension of key/value vectors
* @param head_size
* Dimension of each attention head
* @param context Kernel execution context
* @param positionHolder Array containing current position
* @param sq Query vectors to rotate
* @param sk Key vectors to rotate
* @param kv_dim Dimension of key/value vectors
* @param head_size Dimension of each attention head
*/
public static void ropeRotation(KernelContext context, IntArray positionHolder, FloatArray sq, FloatArray sk, int kv_dim, int head_size) {
int i = context.globalIdx * 2;
@@ -247,31 +306,20 @@ public static void ropeRotationPhi3(KernelContext context, IntArray positionHold
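
For one dimension pair (2*i, 2*i+1), the rotation the ropeRotation doc describes reduces to the following scalar sketch. It assumes the standard RoPE frequency base of 10000 and that pos is the current position; variable names are illustrative:

// Rotate one (v0, v1) pair of query dimensions by a position-dependent angle.
int j = i % head_size;                                   // dimension index within the head
float freq = (float) (1.0 / Math.pow(10000.0, (double) j / head_size));
float angle = pos * freq;                                // rotation angle for this pair
float fcr = (float) Math.cos(angle);
float fci = (float) Math.sin(angle);
float v0 = sq.get(i);
float v1 = sq.get(i + 1);
sq.set(i, v0 * fcr - v1 * fci);                          // 2D rotation of the pair
sq.set(i + 1, v0 * fci + v1 * fcr);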

/**
* Computes attention for a single head. Implements scaled dot-product attention with softmax normalization.
*
* <p>
* Steps: 1. Compute attention scores: Q·K / sqrt(head_size) 2. Apply softmax (with max subtraction for numerical stability) 3. Compute weighted sum of values
*
* @param allQ
* All query vectors
* @param key_cache
* Cached keys
* @param value_cache
* Cached values
* @param allXb
* Output buffer
* @param h
* Head index to process
* @param headSize
* Dimension per head
* @param kvDim
* Key/value dimension
* @param kvMul
* Key multiplier for grouped attention
* @param loff
* Layer offset in cache
* @param pos
* Current position
* @param wrapAtt
* Attention weights buffer
* @param allQ All query vectors
* @param key_cache Cached keys
* @param value_cache Cached values
* @param allXb Output buffer
* @param h Head index to process
* @param headSize Dimension per head
* @param kvDim Key/value dimension
* @param kvMul Key multiplier for grouped attention
* @param loff Layer offset in cache
* @param pos Current position
* @param wrapAtt Attention weights buffer
*/
private static void processHeadTornado(FloatArray allQ, FloatArray key_cache, FloatArray value_cache, FloatArray allXb, int h, int headSize, int kvDim, int kvMul, long loff, int pos,
FloatArray wrapAtt) {
@@ -627,23 +675,16 @@ public static void processHeadsFlashAttentionOpt(KernelContext context, FloatArr
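
The three steps documented for processHeadTornado correspond to this sequential single-head sketch. Plain float[] buffers replace the FloatArray wrappers, the names q, keyCache, valueCache, and xb are illustrative, and layer/head offsets are omitted for brevity:

// Sequential scaled dot-product attention for one head (illustrative reference).
float scale = 1.0f / (float) Math.sqrt(headSize);
float[] att = new float[pos + 1];
float maxScore = Float.NEGATIVE_INFINITY;
for (int t = 0; t <= pos; t++) {                 // 1. scores: Q.K / sqrt(head_size)
    float score = 0.0f;
    for (int d = 0; d < headSize; d++) {
        score += q[d] * keyCache[t * headSize + d];
    }
    att[t] = score * scale;
    maxScore = Math.max(maxScore, att[t]);
}
float sum = 0.0f;
for (int t = 0; t <= pos; t++) {                 // 2. softmax with max subtraction
    att[t] = (float) Math.exp(att[t] - maxScore);
    sum += att[t];
}
for (int d = 0; d < headSize; d++) {             // 3. weighted sum of values
    float acc = 0.0f;
    for (int t = 0; t <= pos; t++) {
        acc += (att[t] / sum) * valueCache[t * headSize + d];
    }
    xb[d] = acc;
}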

/**
* Performs optimized matrix-vector multiplication where each work group processes one row of the matrix.
*
* <p>
* Algorithm: 1. Each work group handles one output dimension 2. Threads in work group compute partial dot products 3. Parallel reduction yields final row result
*
* @param context
* Kernel execution context
* @param x
* Input vector
* @param hb
* Output vector
* @param w
* Weight matrix (row-major)
* @param n
* Input dimension
* @param d
* Output dimension
* @param localWorkGroupSize
* Number of threads per work group
* @param context Kernel execution context
* @param x Input vector
* @param hb Output vector
* @param w Weight matrix (row-major)
* @param n Input dimension
* @param d Output dimension
* @param localWorkGroupSize Number of threads per work group
*/
public static void matrixVectorGeneric(KernelContext context, FloatArray x, FloatArray hb, FloatArray w, int n, int d, int localWorkGroupSize) {
// One row per workgroup (not per thread)
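Each output element of matrixVectorGeneric is the dot product of one weight row with x; the kernel assigns that row to a work group and reduces the partial products in local memory. A sequential sketch of the same result, with plain arrays and illustrative names:

// Sequential equivalent: hb = W * x with W stored row-major (d rows, n columns).
for (int row = 0; row < d; row++) {              // one work group per row in the kernel
    float acc = 0.0f;
    for (int col = 0; col < n; col++) {          // threads accumulate partial dot products
        acc += w[row * n + col] * x[col];
    }
    hb[row] = acc;                               // parallel reduction yields this value
}
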
@@ -31,6 +31,7 @@ public LlamaFP16FFNLayers(String taskGraph, State state, Weights weights, Config
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
WorkerGrid ropeWorker = WorkerGridFactory.genericWorker(config.dim()/2, 128);
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), 256);
//System.out.println("llama config dim: " + config.dim());

int configDimRowMajorGlobal = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid configDimRowMajorGlobalWorker = WorkerGridFactory.genericWorker(configDimRowMajorGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
@@ -54,9 +55,7 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", configHiddenDimRowMajorWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
}
@@ -112,13 +111,8 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
weights.w3Layered[layerIndex].asHalfFloatArray());
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
unifiedLayer
.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.temp, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
if (shouldUseFinalNormalization()) {
unifiedLayer.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.temp,
config.dim(), config.rmsNormEps());
}
unifiedLayer.task("mapContext", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp)
.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
.task("reductionsOneBlock", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), state.temp, config.dim(), config.rmsNormEps(), state.localSize);
unifiedLayer.task("qmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("kmatmul", TransformerComputeKernelsLayered::matrixVectorGeneric, context, state.wrapXb, state.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), config.dim(), config.kvDim(),
LOCAL_WORK_GROUP_SIZE_ALLOC)
@@ -130,12 +124,8 @@ TaskGraph setupSingleFFNLayer(LlamaTornadoWeights weights, Configuration config,
configureAttention(unifiedLayer, layerIndex);
unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapXb, state.wrapX, weights.woLayered[layerIndex].asHalfFloatArray(), config.dim(), config.dim(),
LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer, context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize);
if (shouldUseFinalNormalization()) {
unifiedLayer.task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN, config.dim(), config.rmsNormEps());
}
unifiedLayer.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN)
.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayerFuse, context, state.wrapXb, state.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), state.tempFFN, config.dim(), config.rmsNormEps(), state.localSize);
unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context, state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(),
weights.w3Layered[layerIndex].asHalfFloatArray(), config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context, state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asHalfFloatArray(), config.hiddenDim(),
config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC).persistOnDevice(state.wrapX);
@@ -63,7 +63,6 @@ public Phi3FP16FFNLayers(String taskGraphName, Phi3State state, Phi3TornadoWeigh
public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
// RMS norm worker
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);

// Combined QKV matmul worker
int matmulQkvGlobal = opSize * LOCAL_WORK_GROUP_SIZE_ALLOC;
WorkerGrid matmulQkvRowMajorWorker = WorkerGridFactory.genericWorker(matmulQkvGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);