Skip to content

Commit 0060fc6

Browse files
authored
Merge pull request beehive-lab#66 from beehive-lab/cleanup/nv
Add `SchedulerType` support to all TornadoVM layer planners and layer…
2 parents 9a790e6 + 0636662 commit 0060fc6

23 files changed

+181
-66
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/kernels/TransformerComputeKernelsLayered.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,4 +1055,50 @@ public static void fusedFeedForwardWithSiLUAndGLUActivation(KernelContext contex
10551055
hb.set(rowId, result);
10561056
}
10571057
}
1058+
1059+
/**
1060+
* Orchestrates parallel multi-head attention computation across all heads. Each head processes attention independently in parallel.
1061+
*
1062+
* Attention computation: (1) compute attention scores (Q·K); (2) apply softmax to obtain attention weights; (3) compute the weighted sum of values (attention·V).
1063+
*
1064+
* @param q
1065+
* Query vectors for all heads
1066+
* @param key_cache
1067+
* Cached key vectors
1068+
* @param value_cache
1069+
* Cached value vectors
1070+
* @param xb
1071+
* Output buffer for attention results
1072+
* @param nHeads
1073+
* Number of attention heads
1074+
* @param headSize
1075+
* Dimension of each head
1076+
* @param kvDim
1077+
* Total key/value dimension
1078+
* @param kvMul
1079+
* Key/value head multiplier for grouped-query attention
1080+
* @param seqLen
1081+
* Current sequence length (not read by this method's body)
1082+
* @param positionHolder
1083+
* Array whose element 0 holds the current position (the layer index is supplied separately via {@code layer})
1084+
* @param wrapAtt
1085+
* Buffer for attention weights
1086+
* @param layer
1087+
* Current transformer layer
1088+
* @param contextLength
1089+
* Maximum context length; used to compute the per-layer offset ({@code layer * contextLength * kvDim})
1090+
*/
1091+
public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
1092+
IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
1093+
1094+
int pos = positionHolder.get(0);
1095+
int loff = layer * contextLength * kvDim;
1096+
1097+
// Parallelize computation across attention heads
1098+
for (@Parallel int h = 0; h < nHeads; h++) {
1099+
// Process each head in parallel
1100+
processHeadTornado(q, key_cache, value_cache, xb, h, headSize, kvDim, kvMul, loff, pos, wrapAtt);
1101+
}
1102+
}
1103+
10581104
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/base/QuantizedLayerPlanner.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import org.beehive.gpullama3.model.Configuration;
66
import org.beehive.gpullama3.model.Model;
77
import org.beehive.gpullama3.tornadovm.GenericLayerPlanner;
8+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerDetectionService;
9+
import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
810
import uk.ac.manchester.tornado.api.KernelContext;
911

1012
/**
@@ -22,16 +24,19 @@ public abstract class QuantizedLayerPlanner<S extends State, C extends Configura
2224
protected final C config;
2325
protected final W weights;
2426
protected final KernelContext context;
27+
protected final Model model;
28+
protected final SchedulerType schedulerType;
2529

2630
/**
2731
* Constructor: extract model components, determine the scheduler type, then validate the quantization type
2832
*/
2933
protected QuantizedLayerPlanner(S state, Model model) {
3034
this.state = state;
35+
this.model = model;
3136
this.config = (C) model.configuration();
3237
this.weights = (W) model.weights();
3338
this.context = new KernelContext();
34-
39+
this.schedulerType = SchedulerDetectionService.determineSchedulerType(model);
3540
validateQuantizationType();
3641
}
3742

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/LlamaFP16LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public LlamaFP16LayerPlanner(LlamaState state, Model model) {
2020
@Override
2121
protected void initializeLayerComponents() {
2222
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
23-
this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config);
24-
this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
23+
this.ffnLayers = new LlamaFP16FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
24+
this.logitsLayer = new LogitsFP16Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
2525
}
2626

2727
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Phi3FP16LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ public Phi3FP16LayerPlanner(Phi3State state, Model model) {
2727
@Override
2828
protected void initializeLayerComponents() {
2929
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
30-
this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config);
31-
this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
30+
this.ffnLayers = new Phi3FP16FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
31+
this.logitsLayer = new LogitsFP16Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(),this.schedulerType);
3232
}
3333

3434
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen2FP16LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public Qwen2FP16LayerPlanner(Qwen2State state, Model model) {
2727
@Override
2828
protected void initializeLayerComponents() {
2929
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
30-
this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config);
31-
this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
30+
this.ffnLayers = new Qwen2FP16FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
31+
this.logitsLayer = new LogitsFP16Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
3232
}
3333
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Qwen3FP16LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ public Qwen3FP16LayerPlanner(Qwen3State state, Model model) {
2727
@Override
2828
protected void initializeLayerComponents() {
2929
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
30-
this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config);
31-
this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
30+
this.ffnLayers = new Qwen3FP16FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
31+
this.logitsLayer = new LogitsFP16Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
3232
}
3333

3434
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/LlamaQ8_0LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ public LlamaQ8_0LayerPlanner(LlamaState state, Model model) {
2020
@Override
2121
protected void initializeLayerComponents() {
2222
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
23-
this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config);
24-
this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
23+
this.ffnLayers = new LlamaQ8_0FFNLayers("llamaFFN", this.state, this.weights, this.config, this.schedulerType);
24+
this.logitsLayer = new LogitsQ8_0Layer("llamaLogits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
2525
}
2626

2727
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Phi3Q8_0LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public Phi3Q8_0LayerPlanner(Phi3State state, Model model) {
2828
@Override
2929
protected void initializeLayerComponents() {
3030
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
31-
this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config);
32-
this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
31+
this.ffnLayers = new Phi3Q8_0FFNLayers("phi3FFN", this.state, this.weights, this.config, this.schedulerType);
32+
this.logitsLayer = new LogitsQ8_0Layer("phi3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
3333
}
3434

3535
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen2Q8_0LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ public Qwen2Q8_0LayerPlanner(Qwen2State state, Model model) {
2828
@Override
2929
protected void initializeLayerComponents() {
3030
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
31-
this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config);
32-
this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
31+
this.ffnLayers = new Qwen2Q8_0FFNLayers("qwen2FFN", this.state, this.weights, this.config, this.schedulerType);
32+
this.logitsLayer = new LogitsQ8_0Layer("qwen2Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(), this.schedulerType);
3333
}
3434

3535
}

src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/q8_0/Qwen3Q8_0LayerPlanner.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public Qwen3Q8_0LayerPlanner(Qwen3State state, Model model) {
2828
@Override
2929
protected void initializeLayerComponents() {
3030
this.activationLayer = new Activation("activationUpdate", this.state, this.weights, this.config);
31-
this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config);
32-
this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID());
31+
this.ffnLayers = new Qwen3Q8_0FFNLayers("qwen3FFN", this.state, this.weights, this.config, this.schedulerType);
32+
this.logitsLayer = new LogitsQ8_0Layer("qwen3Logits", this.state, this.weights, this.config, ffnLayers.getLastTaskGraphID(),this.schedulerType);
3333
}
3434
}

0 commit comments

Comments
 (0)