@@ -64,7 +64,7 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe
6464 }
6565
6666 @ Override
67- public GridScheduler updateGridScheduler (GridScheduler tornadoForwardScheduler ) {
67+ public GridScheduler updateGridScheduler (GridScheduler gridScheduler ) {
6868 WorkerGrid rmsNormWorker = WorkerGridFactory .createRmsNormWorker (config .dim (), state .localSize );
6969
7070 int matmulQGlobal = nEmbdHeadK * config .numberOfHeads () * LOCAL_WORK_GROUP_SIZE_ALLOC ;
@@ -98,20 +98,20 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
9898 WorkerGrid fusedQKVWorker = WorkerGridFactory .genericWorker (fusedQKVGlobal , LOCAL_WORK_GROUP_SIZE_ALLOC );
9999
100100 for (int i = 0 ; i < config .numberOfLayers (); i ++) {
101- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_reduce" , rmsNormWorker );
102- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_qkv_projection" , fusedQKVWorker );
103- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".qk_rmsnorm" , qkRmsNormWorker );
104- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
105- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
106- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
107- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce" , rmsNormWorker );
101+ gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_reduce" , rmsNormWorker );
102+ gridScheduler .addWorkerGrid ("layer_" + i + ".attn_rms_qkv_projection" , fusedQKVWorker );
103+ gridScheduler .addWorkerGrid ("layer_" + i + ".qk_rmsnorm" , qkRmsNormWorker );
104+ gridScheduler .addWorkerGrid ("layer_" + i + ".rope_and_kv_cache" , ropeWorker );
105+ gridScheduler .addWorkerGrid ("layer_" + i + ".attention" , parallelAttentionWorker );
106+ gridScheduler .addWorkerGrid ("layer_" + i + ".attn_output_proj" , matmul1Worker );
107+ gridScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_reduce" , rmsNormWorker );
108108 if (shouldUseFinalNormalization ()) {
109- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_finalize" , rmsNormWorker );
109+ gridScheduler .addWorkerGrid ("layer_" + i + ".ffn_rms_finalize" , rmsNormWorker );
110110 }
111- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
112- tornadoForwardScheduler .addWorkerGrid ("layer_" + i + ".ffn_down_proj" , projectionTwoWorker );
111+ gridScheduler .addWorkerGrid ("layer_" + i + ".rms_ffn_gate_up" , fusedFFNW1W3Worker );
112+ gridScheduler .addWorkerGrid ("layer_" + i + ".ffn_down_proj" , projectionTwoWorker );
113113 }
114- return tornadoForwardScheduler ;
114+ return gridScheduler ;
115115 }
116116
117117 @ Override
0 commit comments