Commit 05c0048

[refactor] Add detailed comments for attention and FFN blocks in Qwen3 Q8_0 layers and improve code readability.
1 parent 4bccf37 commit 05c0048

File tree

1 file changed (+27, -25)

src/main/java/org/beehive/gpullama3/tornadovm/layers/type/q8_0/Qwen3Q8_0FFNLayers.java

Lines changed: 27 additions & 25 deletions
@@ -165,40 +165,38 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
 
         // === Data Setup ===
         unifiedLayer.consumeFromDevice(qwen3State.wrapX);
-        // Transfer Q8_0 weights for this layer (quants and scales)
         unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
-                weights.rms_att_weightLayered[layerIndex].asFloatArray(), //
-                weights.wqLayered[layerIndex].asByteArray(),
-                weights.wkLayered[layerIndex].asByteArray(),
-                weights.wvLayered[layerIndex].asByteArray(),
-                weights.woLayered[layerIndex].asByteArray(),
-                weights.rms_att_KNormLayered[layerIndex].asFloatArray(), //
-                weights.rms_att_QNormLayered[layerIndex].asFloatArray(),//
-                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), //
-                weights.w1Layered[layerIndex].asByteArray(),
-                weights.w2Layered[layerIndex].asByteArray(),
-                weights.w3Layered[layerIndex].asByteArray());
-
-        // Configure layer data transfers (EVERY_EXECUTION and device persistence)
+                // Attention weights
+                weights.rms_att_weightLayered[layerIndex].asFloatArray(),  // RMS norm weights
+                weights.wqLayered[layerIndex].asByteArray(),               // Q projection
+                weights.wkLayered[layerIndex].asByteArray(),               // K projection
+                weights.wvLayered[layerIndex].asByteArray(),               // V projection
+                weights.woLayered[layerIndex].asByteArray(),               // Output projection
+                // Qwen3-specific Q/K norm weights
+                weights.rms_att_KNormLayered[layerIndex].asFloatArray(),   // K RMSNorm weights
+                weights.rms_att_QNormLayered[layerIndex].asFloatArray(),   // Q RMSNorm weights
+                // FFN weights
+                weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),  // FFN RMSNorm weights
+                weights.w1Layered[layerIndex].asByteArray(),               // FFN gate projection
+                weights.w2Layered[layerIndex].asByteArray(),               // FFN down projection
+                weights.w3Layered[layerIndex].asByteArray());              // FFN up projection
         unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
 
+        // ═══════════════════════════════════════════════════════════════════════
+        // ATTENTION BLOCK
+        // ═══════════════════════════════════════════════════════════════════════
 
         // RMS Normalization - compute scale factor
         unifiedLayer.task("attn_rms_reduce",
                 TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context,
                 qwen3State.temp,        // output: scale factor
                 qwen3State.wrapX,       // input: hidden state
-                config.dim(), // dimension
-                config.rmsNormEps(), // epsilon
+                config.dim(),           // dimension
+                config.rmsNormEps(),    // epsilon
                 qwen3State.localSize);  // local memory size
 
-        // QKV projections with Qwen3 GQA dimensions
-        // Q8_0 weights pass both quants and scales
-        int qDim0 = nEmbdHeadK * config.numberOfHeads(); // Query dimension
-        int kvDim0 = nEmbdGqa; // KV dimension (smaller due to GQA)
-        int qkvDim1 = config.dim(); // Input dimension
-
+        // Fused RMS Apply + QKV Projection
         unifiedLayer.task("attn_rms_qkv_projection",
                 Qwen3Kernels::fusedRmsNormQKVMatmulQ8_0,
                 context,
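
A note on what those asByteArray() buffers carry: Q8_0 is the GGML/GGUF 8-bit quantization format in which every block of 32 weights stores one float16 scale followed by 32 signed int8 quants, and a weight decodes as w[i] = scale * q[i]. The sketch below is a minimal CPU-side reference decoder under that standard block layout; the interleaved scale-plus-quants layout is an assumption about these buffers, the Q8_0Ref class is hypothetical, and Float.float16ToFloat requires Java 20+.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class Q8_0Ref {
    // Standard GGML Q8_0 block: 2-byte fp16 scale + 32 int8 quants (assumed layout).
    static final int BLOCK = 32;
    static final int BYTES_PER_BLOCK = 2 + BLOCK;

    static float[] dequantize(byte[] raw) {
        int nBlocks = raw.length / BYTES_PER_BLOCK;
        float[] out = new float[nBlocks * BLOCK];
        ByteBuffer buf = ByteBuffer.wrap(raw).order(ByteOrder.LITTLE_ENDIAN);
        for (int b = 0; b < nBlocks; b++) {
            float scale = Float.float16ToFloat(buf.getShort()); // per-block scale (Java 20+)
            for (int i = 0; i < BLOCK; i++) {
                out[b * BLOCK + i] = scale * buf.get(); // w = scale * int8 quant
            }
        }
        return out;
    }
}

On the shapes in the removed helper ints: with Qwen3's grouped-query attention, the Q projection produces nEmbdHeadK * numberOfHeads() values per token while K and V produce only nEmbdGqa, which is why the wk/wv buffers are smaller than wq.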
@@ -273,7 +271,9 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
                 config.dim(),           // output dim
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        // ========== FEED-FORWARD BLOCK ==========
+        // ═══════════════════════════════════════════════════════════════════════
+        // FFN BLOCK
+        // ═══════════════════════════════════════════════════════════════════════
 
         // RMS Normalization - compute scale factor
         unifiedLayer.task("ffn_rms_reduce",
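
For reference, attn_rms_reduce above and ffn_rms_reduce here compute the same quantity: the RMSNorm scale factor 1 / sqrt(mean(x^2) + eps), written into temp and later multiplied into the hidden state together with the per-layer norm weights. A scalar sketch of that reduce/apply split (math only; the real kernels are parallel reductions over work-group local memory, and these helper names are illustrative):

// Reduce step: what *_rms_reduce produces into qwen3State.temp.
static float rmsReduce(float[] x, int dim, float eps) {
    float ss = 0f;
    for (int i = 0; i < dim; i++) {
        ss += x[i] * x[i]; // sum of squares
    }
    return (float) (1.0 / Math.sqrt(ss / dim + eps));
}

// Apply step: folded into the fused matmul tasks (e.g. fusedRmsNormQKVMatmulQ8_0).
static void rmsApply(float[] y, float[] x, float[] weight, float scale, int dim) {
    for (int i = 0; i < dim; i++) {
        y[i] = scale * weight[i] * x[i];
    }
}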
@@ -308,11 +308,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
                 weights.w2Layered[layerIndex].asByteArray(), // W2 (down)
                 config.hiddenDim(),     // input dim
                 config.dim(),           // output dim
-                LOCAL_WORK_GROUP_SIZE_ALLOC)
-                .persistOnDevice(state.wrapX);
+                LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+        unifiedLayer.persistOnDevice(state.wrapX);
 
         return unifiedLayer;
     }
+    // @formatter:on
 
     /**
      * Configure data transfers for first and subsequent layers

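The gate/down/up labels added to the transfer comments follow the SwiGLU feed-forward used by Qwen3: out = W2 · (silu(W1 · x) ⊙ (W3 · x)), where ⊙ is the elementwise product. A dense float reference of the math the fused Q8_0 FFN kernels implement (row-major layout; an illustrative sketch, not code from this repo):

// Reference SwiGLU FFN: out = W2 * (silu(W1 * x) ⊙ (W3 * x)).
// W1 (gate) and W3 (up) are hiddenDim x dim; W2 (down) is dim x hiddenDim.
static float[] swigluFfn(float[] x, float[] w1, float[] w2, float[] w3,
                         int dim, int hiddenDim) {
    float[] h = new float[hiddenDim];
    for (int i = 0; i < hiddenDim; i++) {
        float gate = 0f, up = 0f;
        for (int j = 0; j < dim; j++) {
            gate += w1[i * dim + j] * x[j]; // gate projection (W1)
            up   += w3[i * dim + j] * x[j]; // up projection (W3)
        }
        float silu = gate / (1f + (float) Math.exp(-gate)); // SiLU(gate)
        h[i] = silu * up;                                   // gated hidden
    }
    float[] out = new float[dim];
    for (int i = 0; i < dim; i++) {
        float acc = 0f;
        for (int j = 0; j < hiddenDim; j++) {
            acc += w2[i * hiddenDim + j] * h[j]; // down projection (W2)
        }
        out[i] = acc;
    }
    return out;
}

Also worth noting in this hunk: persistOnDevice(state.wrapX) moves out of the task-call chain into its own statement, which makes it clearer that the hidden state stays resident on the device, to be picked up by the next layer's consumeFromDevice.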