
Commit a968bae

Merge pull request #78 from beehive-lab/feat/deq-n-compute
[FP16] Improved performance by fusing dequantize with compute in kernels: 20-30% Inference Speedup
2 parents e01c2e3 + 664f160 commit a968bae

22 files changed: +4673 −1040 lines
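
The fused kernels themselves are spread across the 22 changed files and are not shown on this page. As a rough illustration of the technique the title describes, folding dequantization into the loop that consumes the weights instead of materializing a dequantized FP32 buffer first, a minimal plain-Java sketch might look as follows (names, quantization scheme, and scale handling are illustrative, not the project's actual kernels):

    // Unfused: pass 1 dequantizes a quantized row into FP32 scratch memory,
    // pass 2 reduces it; two traversals plus an extra buffer.
    static float dotUnfused(byte[] q, float scale, float[] x, float[] scratch) {
        for (int i = 0; i < q.length; i++) {
            scratch[i] = q[i] * scale;   // dequantize
        }
        float acc = 0f;
        for (int i = 0; i < q.length; i++) {
            acc += scratch[i] * x[i];    // compute
        }
        return acc;
    }

    // Fused: each weight is dequantized in the same loop that multiplies it,
    // so the scratch buffer and one full pass over memory disappear.
    static float dotFused(byte[] q, float scale, float[] x) {
        float acc = 0f;
        for (int i = 0; i < q.length; i++) {
            acc += (q[i] * scale) * x[i];
        }
        return acc;
    }

Matrix-vector kernels in LLM inference are typically memory-bound, so removing an intermediate write and re-read of every weight row is consistent with the 20-30% speedup claimed in the title.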

llama-tornado

Lines changed: 6 additions & 1 deletion
@@ -422,7 +422,7 @@ def create_parser() -> argparse.ArgumentParser:
     )
     debug_group.add_argument(
         "--profiler-dump-dir",
-        default="/home/mikepapadim/repos/gpu-llama3.java/prof.json",
+        default=None,
         help="Directory for profiler output",
     )

@@ -498,6 +498,11 @@ def main():
     parser = create_parser()
     args = parser.parse_args()

+    # Set default profiler log path relative to LLAMA_ROOT
+    if args.profiler_dump_dir is None:
+        llama_root = os.environ.get("LLAMA_ROOT")
+        args.profiler_dump_dir = os.path.join(llama_root, "profiler-log.json")
+
     # Set default seed if not provided
     if args.seed is None:
         args.seed = int(time.time())
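
Note: os.environ.get("LLAMA_ROOT") returns None when the variable is unset, in which case os.path.join(None, "profiler-log.json") raises a TypeError; the fallback therefore assumes the launcher is always run with LLAMA_ROOT exported.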

pom.xml

Lines changed: 2 additions & 2 deletions
@@ -54,12 +54,12 @@
         <dependency>
             <groupId>io.github.beehive-lab</groupId>
             <artifactId>tornado-api</artifactId>
-            <version>2.0.1-dev</version>
+            <version>2.1.0</version>
         </dependency>
         <dependency>
             <groupId>io.github.beehive-lab</groupId>
             <artifactId>tornado-runtime</artifactId>
-            <version>2.0.1-dev</version>
+            <version>2.1.0</version>
         </dependency>
     </dependencies>

src/main/java/org/beehive/gpullama3/inference/state/LlamaState.java

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,8 @@ protected StateFields createStateFields(Configuration config) {
         fields.wrapK = new FloatArray(config.dim());
         fields.wrapV = new FloatArray(config.dim());

+        fields.wrapXFP16 = new HalfFloatArray(config.dim());
+        fields.wrapXbFP16 = new HalfFloatArray(config.dim());
         // dim vs kvdim
         fields.wrapKeyCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());
         fields.wrapValueCache = new FloatArray(config.contextLength() * kvDim * config.numberOfLayers());

src/main/java/org/beehive/gpullama3/inference/state/Phi3State.java

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,8 @@ protected StateFields createStateFields(Configuration config) {
         }
         fields.wrapX = new FloatArray(dim);
         fields.wrapXb = new FloatArray(dim);
+        fields.wrapXFP16 = new HalfFloatArray(dim);
+        fields.wrapXbFP16 = new HalfFloatArray(dim);
         fields.wrapXb2 = new FloatArray(dim);
         fields.wrapHb = new FloatArray(2 * hiddenDim);
         fields.wrapHb2 = new FloatArray(hiddenDim);

src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@ protected StateFields createStateFields(Configuration configuration) {
         }
         fields.wrapX = new FloatArray(config.dim());
         fields.wrapXb = new FloatArray(config.dim());
+        fields.wrapXbFP16 = new HalfFloatArray(config.dim());
         fields.wrapXb2 = new FloatArray(config.dim());
         fields.wrapHb = new FloatArray(config.hiddenDim());
         fields.wrapHb2 = new FloatArray(config.hiddenDim());

src/main/java/org/beehive/gpullama3/inference/state/Qwen3State.java

Lines changed: 2 additions & 0 deletions
@@ -75,6 +75,8 @@ protected StateFields createStateFields(Configuration configuration) {

         fields.wrapX = new FloatArray(config.dim());
         fields.wrapXb = new FloatArray(nEmbdHeadK * config.numberOfHeads());
+        fields.wrapXbFP16 = new HalfFloatArray(nEmbdHeadK * config.numberOfHeads());
+
         fields.wrapXb2 = new FloatArray(config.dim());
         fields.wrapHb = new FloatArray(config.hiddenDim());
         fields.wrapHb2 = new FloatArray(config.hiddenDim());

src/main/java/org/beehive/gpullama3/inference/state/State.java

Lines changed: 11 additions & 0 deletions
@@ -4,6 +4,9 @@
 import org.beehive.gpullama3.model.Configuration;
 import uk.ac.manchester.tornado.api.types.HalfFloat;
 import uk.ac.manchester.tornado.api.types.arrays.*;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;

 /**
  * Represents the base state structure used during LLM inference.

@@ -58,13 +61,17 @@ public abstract class State {
     public final IntArray positionHolder;

     public TornadoNativeArray embeddingX;
+
+    public final HalfFloatArray wrapXbFP16; // FloatArray wrapper for xb (residual branch activation), optimized for TornadoVM usage.
+
     // store inter
     public int localSize;
     public FloatArray temp; // Temporary buffer for intermediate calculations, size adjusted for local workgroup size.
     public FloatArray tempFFN; // Temporary buffer for feed-forward network calculations, size adjusted for local workgroup size.
     public FloatArray tempLogits; // Temporary buffer for logits calculations, size adjusted for local workgroup size.
     public int latestToken; // Keeps track of the most recent token processed by the model. Useful for stateful or autoregressive models.

+    public HalfFloatArray wrapXFP16;
     /** last index in previous block */

     protected State(Configuration config, int batchsize) {

@@ -100,6 +107,9 @@ protected State(Configuration config, int batchsize) {
         this.wrapK = fields.wrapK;
         this.wrapV = fields.wrapV;

+        this.wrapXFP16 = fields.wrapXFP16;
+        this.wrapXbFP16 = fields.wrapXbFP16;
+
         // dim vs kvdim
         this.wrapKeyCache = fields.wrapKeyCache;
         this.wrapValueCache = fields.wrapValueCache;

@@ -136,6 +146,7 @@ public void createActivationQ8_0(int size) {
         int q8BytesNeeded = blocksNeeded * Q8_0_BLOCK_BYTES;
         this.embeddingX = new ByteArray(q8BytesNeeded);
     }
+    public HalfFloatArray wrapXFP16, wrapXbFP16;
 }

 @Override
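
Taken together with the per-model state changes above, the new wrapXFP16 and wrapXbFP16 buffers appear to keep FP16 twins of the FP32 activations so that FP16 weight kernels can consume matching-precision inputs directly. A minimal, self-contained sketch of that mirroring pattern, using the JDK's Float.floatToFloat16/float16ToFloat (JDK 20+) in place of TornadoVM's HalfFloatArray, with an illustrative dimension:

    import java.util.Arrays;

    public final class Fp16TwinSketch {
        public static void main(String[] args) {
            int dim = 8;                        // illustrative model dimension
            float[] wrapX = new float[dim];     // existing FP32 activation buffer
            short[] wrapXFP16 = new short[dim]; // FP16 twin (raw half-float bits)
            Arrays.fill(wrapX, 0.5f);

            // Mirror the activations into FP16 once per step ...
            for (int i = 0; i < dim; i++) {
                wrapXFP16[i] = Float.floatToFloat16(wrapX[i]);
            }

            // ... so an FP16 kernel can read them without a dequantize pass.
            float acc = 0f;
            for (int i = 0; i < dim; i++) {
                acc += Float.float16ToFloat(wrapXFP16[i]);
            }
            System.out.println(acc); // prints 4.0
        }
    }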

src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java

Lines changed: 6 additions & 3 deletions
@@ -1,9 +1,9 @@
 package org.beehive.gpullama3.tornadovm;

-import org.beehive.gpullama3.tensor.GGMLType;
 import org.beehive.gpullama3.inference.state.State;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tensor.GGMLType;
 import org.beehive.gpullama3.tornadovm.layerplanner.base.QuantizationPlannerFactory;
 import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
 import uk.ac.manchester.tornado.api.TornadoExecutionPlan;

@@ -133,6 +133,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {

         // Set the position in the state object (used by attention layers)
         state.positionHolder.set(0, position);
+        state.temp.clear();
+        state.tempFFN.clear();

         // 2. Execute each transformer layer graph sequentially
         // Each graph computes attention and feed-forward transformations for one layer

@@ -141,7 +143,8 @@ public FloatArray tornadoVMForwardExecuteLayered(int position) {
                     .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())
                     .execute();
         }
-
+        state.tempLogits.clear(); // Clear the intermediate logits tensor -> set to 0f
+        state.wrapLogits.clear(); // Clear the output logits tensor -> set to 0f
         // 3. Execute the final graph that projects the last hidden state to output logits
         executionPlan.withGraph(getFinalLogitsGraphIndex())
                 .withGridScheduler(tornadoVMLayerPlanner.getGridScheduler())

@@ -179,7 +182,7 @@ private int getFinalLogitsGraphIndex() {
     /// Execute the forward pass of the LLaMA transformer model using TornadoVM acceleration just once to copy the data into the read-only data layer.
     public void forceCopyInReadOnlyDataLayered() {
         // Execute all TornadoVM graphs
-        state.wrapX.init(0.0f);
+        state.wrapX.clear();
         state.positionHolder.init(0);

         // Execute activation update graph
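
Functional note: clear() on TornadoVM's native arrays zeroes the backing buffer, so state.wrapX.clear() should be equivalent to the previous state.wrapX.init(0.0f). The added clear() calls on temp, tempFFN, tempLogits, and wrapLogits reset the reduction scratch buffers and logits between forward passes instead of letting stale values carry over.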
