Commit ecb2828

Add comments for attention and FFN blocks in Qwen2 Q8_0 FFN layers.
Parent: 8f637cd


src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen2Q8_0FFNLayers.java

Lines changed: 15 additions & 5 deletions
@@ -159,22 +159,28 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
         TaskGraph unifiedLayer = new TaskGraph("layer_" + layerIndex);
         unifiedLayer.consumeFromDevice(state.wrapX);
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
-                //Copy-in weights per layer for batched-layered layout
+                // Attention weights
                 weights.rms_att_weightLayered[layerIndex].asFloatArray(),
                 weights.wqLayered[layerIndex].asByteArray(),
                 weights.wkLayered[layerIndex].asByteArray(),
                 weights.wvLayered[layerIndex].asByteArray(),
                 weights.woLayered[layerIndex].asByteArray(),
+                // Qwen2-specific bias terms
                 weights.q_biasLayered[layerIndex].asFloatArray(),
                 weights.k_biasLayered[layerIndex].asFloatArray(),
                 weights.v_biasLayered[layerIndex].asFloatArray(),
+                // FFN weights
                 weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
                 weights.w1Layered[layerIndex].asByteArray(),
                 weights.w2Layered[layerIndex].asByteArray(),
                 weights.w3Layered[layerIndex].asByteArray()
         );
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);

+        // ═══════════════════════════════════════════════════════════════════════
+        // ATTENTION BLOCK
+        // ═══════════════════════════════════════════════════════════════════════
+
         // RMS Normalization - compute scale factor
         unifiedLayer.task("attn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
@@ -258,6 +264,10 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 config.dim(), // output dim
                 LOCAL_WORK_GROUP_SIZE_ALLOC);

+        // ═══════════════════════════════════════════════════════════════════════
+        // FFN BLOCK
+        // ═══════════════════════════════════════════════════════════════════════
+
         // RMS Normalization - compute scale factor
         unifiedLayer.task("ffn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
@@ -302,10 +312,10 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
                 weights.w2Layered[layerIndex].asByteArray(), // W2 (down)
                 config.hiddenDim(), // input dim
                 config.dim(), // output dim
-                LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .persistOnDevice(
-                        state.wrapX
-                );
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        unifiedLayer.persistOnDevice(state.wrapX);
+
         return unifiedLayer;

     }
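For readers unfamiliar with the pattern above: each transformer layer is built as one TornadoVM TaskGraph that copies its weights in once (FIRST_EXECUTION), runs the attention and FFN tasks, and keeps the activation buffer (state.wrapX) resident on the device so the next layer's graph can pick it up via consumeFromDevice. The sketch below shows that consume/persist chaining with TornadoVM's public API (TaskGraph, ImmutableTaskGraph, TornadoExecutionPlan). It is a minimal stand-alone example, not code from this repository: the scale kernel, the array size, and the class name LayerChainSketch are hypothetical stand-ins for the real attention/FFN kernels and weight buffers, and the call shapes simply mirror the ones in the diff (exact availability of persistOnDevice/consumeFromDevice depends on the TornadoVM version).

import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.annotations.Parallel;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class LayerChainSketch {

    // Hypothetical kernel standing in for the real attention/FFN tasks.
    public static void scale(FloatArray x, float factor) {
        for (@Parallel int i = 0; i < x.getSize(); i++) {
            x.set(i, x.get(i) * factor);
        }
    }

    public static void main(String[] args) {
        FloatArray wrapX = new FloatArray(1024); // stand-in for state.wrapX
        wrapX.init(1.0f);

        // "layer_0": copy the input once, compute, keep wrapX on the device.
        TaskGraph layer0 = new TaskGraph("layer_0")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, wrapX)
                .task("scale", LayerChainSketch::scale, wrapX, 2.0f)
                .persistOnDevice(wrapX);

        // "layer_1": reuse the device-resident wrapX, then copy the result back.
        TaskGraph layer1 = new TaskGraph("layer_1")
                .consumeFromDevice(wrapX)
                .task("scale", LayerChainSketch::scale, wrapX, 0.5f)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, wrapX);

        ImmutableTaskGraph g0 = layer0.snapshot();
        ImmutableTaskGraph g1 = layer1.snapshot();

        // The execution plan runs the graphs in order; wrapX stays on the
        // device between them instead of round-tripping through the host.
        new TornadoExecutionPlan(g0, g1).execute();

        System.out.println(wrapX.get(0)); // 1.0f: scaled by 2.0, then by 0.5
    }
}

On the last hunk of the diff: the chained form was functionally equivalent, since task(...) returns the TaskGraph itself, so .persistOnDevice(...) was already being invoked on the whole graph. Splitting it into a separate statement on unifiedLayer makes that explicit: persistOnDevice is a graph-level data-residency declaration, not a property of the final task.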
