Skip to content

Commit 664f160

Browse files
Remove redundant workers in Qwen3 Q8_0 FFN layers.
1 parent 1c2be61 commit 664f160

File tree

1 file changed

+2
-11
lines changed

1 file changed

+2
-11
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,21 +67,12 @@ public Qwen3Q8_0FFNLayers(String taskGraphName, Qwen3State state, Qwen3TornadoWe
6767
public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
6868
WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(config.dim(), state.localSize);
6969

70-
int matmulQGlobal = nEmbdHeadK * config.numberOfHeads() * LOCAL_WORK_GROUP_SIZE_ALLOC;
71-
WorkerGrid matmulQRowMajorWorker = WorkerGridFactory.genericWorker(matmulQGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
72-
73-
int matmulKVGlobal = nEmbdGqa * LOCAL_WORK_GROUP_SIZE_ALLOC;
74-
WorkerGrid matmulKVRowMajorWorker = WorkerGridFactory.genericWorker(matmulKVGlobal, LOCAL_WORK_GROUP_SIZE_ALLOC);
75-
7670
int qkRmsNormGroups = config.numberOfHeads() + config.numberOfKeyValueHeads();
7771
WorkerGrid qkRmsNormWorker = WorkerGridFactory.genericWorker(qkRmsNormGroups * nEmbdHead, nEmbdHead);
7872

79-
int h = config.numberOfHeads();
80-
int ic = nEmbdHead / 2;
81-
WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(h, nEmbdHead);
82-
WorkerGrid copyToCachesWorker = WorkerGridFactory.genericWorker(nEmbdGqa, 128);
73+
WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(config.numberOfHeads(), nEmbdHead);
8374
WorkerGrid parallelAttentionWorker = WorkerGridFactory.createAttentionWorker(config.numberOfHeads(), nEmbdHead);
84-
75+
// attn_output_proj worker (output projection)
8576
int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
8677
WorkerGrid matmul1Worker = WorkerGridFactory.genericWorker(matmul1Global, LOCAL_WORK_GROUP_SIZE_ALLOC);
8778

0 commit comments

Comments
 (0)