Skip to content

Commit 2b27344

Browse files
[refactor] Replace projectionTwo task with ffn_down_proj in Qwen2 Q8_0 FFN layers and update worker grids.
1 parent 83bb419 commit 2b27344

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,8 +109,8 @@ public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler)
109109
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attention", parallelAttentionWorker);
110110
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".attn_output_proj", configDimRowMajorGlobalWorker);
111111
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_rms_reduce", rmsNormWorker);
112-
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", configDimRowMajorGlobalWorker);
113112
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".rms_ffn_gate_up", configHiddenDimRowMajorWorker);
113+
tornadoForwardScheduler.addWorkerGrid("layer_" + i + ".ffn_down_proj", configDimRowMajorGlobalWorker);
114114
}
115115
return tornadoForwardScheduler;
116116
}
@@ -283,8 +283,16 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
283283
config.hiddenDim(), // hidden dimension
284284
LOCAL_WORK_GROUP_SIZE_ALLOC);
285285

286-
unifiedLayer.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte, context,
287-
state.wrapHb, state.wrapX, weights.w2Layered[layerIndex].asByteArray(), config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
286+
// Down Projection with Residual
287+
unifiedLayer.task("ffn_down_proj",
288+
TransformerComputeKernelsLayered::matrixVectorGenericWithResidualQ8_0Byte,
289+
context,
290+
qwen2State.wrapHb, // input: FFN intermediate
291+
qwen2State.wrapX, // output: wrapX += W2 · wrapHb
292+
weights.w2Layered[layerIndex].asByteArray(), // W2 (down)
293+
config.hiddenDim(), // input dim
294+
config.dim(), // output dim
295+
LOCAL_WORK_GROUP_SIZE_ALLOC)
288296
.persistOnDevice(
289297
state.wrapX
290298
);

0 commit comments

Comments
 (0)