Skip to content

Commit 1c2be61

Browse files
[refactor] Rename grid scheduler parameter in Qwen3 Q8_0 FFN layers.
1 parent 06ebbc8 commit 1c2be61

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe
6464
}
6565

6666
@Override
67-
public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
67+
public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
6868
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
6969

7070
int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
@@ -98,20 +98,20 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
9898
WorkerGrid fusedQKVWorker = WorkerGridFactory.genericWorker(fusedQKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
9999

100100
for (int i = 0; i < config.numberOfLayers(); i++) {
101-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
102-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
103-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker);
104-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
105-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
106-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
107-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
101+
gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_reduce", rmsNormWorker);
102+
gridScheduler.addWorkerGrid("layer_" + i + ".attn_rms_qkv_projection", fusedQKVWorker);
103+
gridScheduler.addWorkerGrid("layer_" + i + ".qk_rmsnorm", qkRmsNormWorker);
104+
gridScheduler.addWorkerGrid("layer_" + i + ".rope_and_kv_cache", ropeWorker);
105+
gridScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
106+
gridScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", matmul1Worker);
107+
gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
108108
if (shouldUseFinalNormalization()) {
109-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker);
109+
gridScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_finalize", rmsNormWorker);
110110
}
111-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
112-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker);
111+
gridScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", fusedFFNW1W3Worker);
112+
gridScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", projectionTwoWorker);
113113
}
114-
return tornadoForwardScheduler;
114+
return gridScheduler;
115115
}
116116

117117
@Override

0 commit comments

Comments
 (0)