From d0966eb62ec6e07304c53c28d7303e4f1adffeab Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 26 Nov 2025 16:05:01 +0200 Subject: [PATCH 1/3] Refactor tensor loading and introduce support for Half-Float precision in TornadoVM acceleration. --- set_paths | 4 ++-- .../gpullama3/inference/InferenceCore.java | 2 +- .../gpullama3/inference/state/LlamaState.java | 3 +++ .../gpullama3/inference/state/Phi3State.java | 2 ++ .../gpullama3/inference/state/Qwen2State.java | 2 ++ .../gpullama3/inference/state/Qwen3State.java | 5 ++++- .../gpullama3/inference/state/State.java | 5 +++++ .../model/loader/LlamaModelLoader.java | 2 +- .../model/loader/MistralModelLoader.java | 2 +- .../kernels/TransformerComputeKernels.java | 14 ++++++++++++++ .../gpullama3/tornadovm/layers/Activation.java | 18 ++++++++++++------ 11 files changed, 47 insertions(+), 12 deletions(-) diff --git a/set_paths b/set_paths index fd807c5e..fe79810e 100644 --- a/set_paths +++ b/set_paths @@ -6,10 +6,10 @@ # Resolve root of this project (LLaMA3) and TornadoVM export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" +#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" # Set the path to TornadoVM SDK binaries -export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" +#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" # Add TornadoVM and LLaMA bin directories to PATH export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}" diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 8104e561..33f8c0e8 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -583,7 +583,7 @@ public static FloatArray forwardTornadoVM(Model model, State state, int token, i final Configuration configuration = model.configuration(); final TornadoWeights weights = (TornadoWeights) model.weights(); - 
MemorySegment.copy(weights.getTokenEmbeddingTable().asFloatArray().getSegment(), (long) token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES); + MemorySegment.copy(weights.getTokenEmbeddingTable().asHalfFloatArray().getSegment(), (long) token * configuration.dim() * Short.BYTES, state.embeddingX.getSegment(), 0, configuration.dim() * Short.BYTES); return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position); } diff --git a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java index 9f9fdcdb..38af3877 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java @@ -4,6 +4,7 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -52,6 +53,8 @@ protected StateFields createStateFields(Configuration config) { fields.wrapHb = new FloatArray(config.hiddenDim()); fields.wrapHb2 = new FloatArray(config.hiddenDim()); + fields.embeddingX = new HalfFloatArray(config.dim()); + fields.wrapLogits = new FloatArray(config.vocabularySize()); fields.wrapQ = new FloatArray(config.dim()); fields.wrapK = new FloatArray(config.dim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java index d29ba130..dad2fb58 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; 
import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -79,6 +80,7 @@ protected StateFields createStateFields(Configuration config) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(contextLength, kvDim)).limit(nLayers).toArray(FloatTensor[]::new); // TornadoVM wrapper arrays for GPU acceleration + fields.embeddingX = new HalfFloatArray(config.dim()); fields.wrapX = new FloatArray(dim); fields.wrapXb = new FloatArray(dim); fields.wrapXb2 = new FloatArray(dim); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java index da6d7046..577c7cdc 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -40,6 +41,7 @@ protected StateFields createStateFields(Configuration configuration) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); // TornadoVM wrappers with Qwen2 dimensions + fields.embeddingX = new HalfFloatArray(config.dim()); fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(config.dim()); fields.wrapXb2 = new FloatArray(config.dim()); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java index d6a6d087..722bc75d 100644 
--- a/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java @@ -5,6 +5,7 @@ import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.qwen3.Qwen3Configuration; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; import java.util.stream.Stream; @@ -65,6 +66,8 @@ protected StateFields createStateFields(Configuration configuration) { fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new); // TornadoVM wrappers with Qwen3-specific sizes + + fields.embeddingX = new HalfFloatArray(config.dim()); fields.wrapX = new FloatArray(config.dim()); fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads()); fields.wrapXb2 = new FloatArray(config.dim()); @@ -74,7 +77,7 @@ protected StateFields createStateFields(Configuration configuration) { fields.wrapQ = new FloatArray(nEmbdHeadK * config.numberOfHeads()); fields.wrapK = new FloatArray(nEmbdKGqa); fields.wrapV = new FloatArray(nEmbdKGqa); fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers()); fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers()); fields.wrapValueCache.init(0.f); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/State.java b/src/main/java/org/beehive/gpullama3/inference/state/State.java index 01d94936..9de2b314 100644 --- a/src/main/java/org/beehive/gpullama3/inference/state/State.java +++ b/src/main/java/org/beehive/gpullama3/inference/state/State.java @@ -2,7 +2,9 @@ import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.model.Configuration; +import
uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; import uk.ac.manchester.tornado.api.types.arrays.IntArray; /** @@ -57,6 +59,7 @@ public abstract class State { public final FloatArray wrapValueCache; // FloatArray wrapper for the value cache, optimized for TornadoVM. public final IntArray positionHolder; + public HalfFloatArray embeddingX; // stores the intermediate FP16 token embedding before on-device conversion to FP32 public int localSize; public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size. @@ -88,6 +91,7 @@ protected State(Configuration config, int batchsize) { this.keyCache = fields.keyCache; this.valueCache = fields.valueCache; + this.embeddingX = fields.embeddingX; this.wrapX = fields.wrapX; this.wrapXb = fields.wrapXb; this.wrapXb2 = fields.wrapXb2; @@ -121,6 +125,7 @@ protected static class StateFields { public FloatArray wrapQ, wrapK, wrapV, wrapAtt, wrapKeyCache, wrapValueCache; public IntArray positionHolder; public FloatArray temp, tempFFN, tempLogits; + public HalfFloatArray embeddingX; } @Override diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java index 069704a7..d0f6f758 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java @@ -120,7 +120,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new LlamaTornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk."
+ i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index 25c493db..79a4edfe 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -130,7 +130,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new LlamaTornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java index 7f69e496..97d7ab6a 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernels.java @@ -2,7 +2,9 @@ import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.HalfFloat; import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; public class TransformerComputeKernels { @@ -19,6 +21,18 @@ public static void emptyTaskToForceCopyIn(FloatArray buffer) { } } + public static void convertFP16toFP32(KernelContext context, HalfFloatArray x, FloatArray wrapX) { + int i = context.globalIdx; + wrapX.set(i, x.get(i).getFloat32()); + } + + public static void convertFP32toFP16(KernelContext context, FloatArray wrapX, HalfFloatArray x) { + 
int i = context.globalIdx; + float valInput = wrapX.get(i); + HalfFloat val = new HalfFloat(valInput); + x.set(i,val); + } + /** * Performs RMS (Root Mean Square) normalization using parallel reduction. * This is a two-phase reduction: first within work groups, then across work groups. diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java index 16783829..5f156512 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/Activation.java @@ -7,8 +7,10 @@ import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; import uk.ac.manchester.tornado.api.GridScheduler; import uk.ac.manchester.tornado.api.ImmutableTaskGraph; +import uk.ac.manchester.tornado.api.KernelContext; import uk.ac.manchester.tornado.api.TaskGraph; import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.WorkerGrid1D; import uk.ac.manchester.tornado.api.enums.DataTransferMode; public class Activation extends AbstractLayer { @@ -17,16 +19,20 @@ public class Activation extends AbstractLayer { public Activation(String taskGraphHandle, State state, Weights weights, Configuration config) { super(taskGraphHandle, state, weights, config); - // formatter:off - this.activationUpdate = new TaskGraph(taskGraphHandle).transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX) - .task("updateX", TransformerComputeKernels::emptyTaskToForceCopyIn, state.wrapX).persistOnDevice(state.wrapX); - // formatter:on + KernelContext kernelContext = new KernelContext(); + // @formatter:off + this.activationUpdate = new TaskGraph(taskGraphHandle) + .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.embeddingX) + .task("updateX", TransformerComputeKernels::convertFP16toFP32, kernelContext, state.embeddingX, state.wrapX) + .persistOnDevice(state.wrapX); + // @formatter:on } @Override public 
GridScheduler updateGridScheduler(GridScheduler scheduler) { - WorkerGrid singleWorker = WorkerGridFactory.createSingleWorker(); - scheduler.addWorkerGrid("activationUpdate.updateX", singleWorker); + WorkerGrid worker = new WorkerGrid1D(config.dim()); + worker.setLocalWork(128, 1, 1); + scheduler.addWorkerGrid("activationUpdate.updateX", worker); return scheduler; } From 61993fcbce88729c1f7e4dc224dc6c32c4621dd6 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 26 Nov 2025 16:08:36 +0200 Subject: [PATCH 2/3] Replace `loadTornadoTensorAsFP32` with `loadTornadoTensor` across model loaders for consistent tensor loading. --- .../org/beehive/gpullama3/model/loader/Phi3ModelLoader.java | 2 +- .../org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java | 2 +- .../org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java index f32249ed..2bdf0c32 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java @@ -140,7 +140,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new Phi3TornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_output.weight")), diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java index c957c029..b7e9b691 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java @@ -137,7 +137,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr // Load all tensors uniformly as TornadoTensor hierarchy return new Qwen2TornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java index 008af2b3..59a8b7ae 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java @@ -137,7 +137,7 @@ protected Weights createTornadoVMWeights(Map tensorEntr final int nl = config.numberOfLayers(); return new Qwen3TornadoWeights( - loadTornadoTensorAsFP32(tokenEmbeddings), + loadTornadoTensor(tokenEmbeddings), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // fp32 loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." 
+ i + ".attn_k.weight")), From 937408fd1e43ddf23206c773efd296201941d8c9 Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Wed, 26 Nov 2025 16:15:38 +0200 Subject: [PATCH 3/3] [CI] point to master tornadovm --- .github/workflows/build-and-run.yml | 2 +- set_paths | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build-and-run.yml b/.github/workflows/build-and-run.yml index 4eacedf1..0a857a0d 100644 --- a/.github/workflows/build-and-run.yml +++ b/.github/workflows/build-and-run.yml @@ -32,7 +32,7 @@ jobs: - name: Clone TornadoVM explicitly run: | - git clone --depth 1 --branch develop \ + git clone --depth 1 --branch master \ https://github.com/beehive-lab/TornadoVM.git \ GPULlama3.java/external/tornadovm - name: Set up Python venv for TornadoVM diff --git a/set_paths b/set_paths index fe79810e..fd807c5e 100644 --- a/set_paths +++ b/set_paths @@ -6,10 +6,10 @@ # Resolve root of this project (LLaMA3) and TornadoVM export LLAMA_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -#export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" +export TORNADO_ROOT="${LLAMA_ROOT}/external/tornadovm" # Set the path to TornadoVM SDK binaries -#export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" +export TORNADO_SDK="${TORNADO_ROOT}/bin/sdk" # Add TornadoVM and LLaMA bin directories to PATH export PATH="${PATH}:${TORNADO_SDK}:${LLAMA_ROOT}"