@@ -165,40 +165,38 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
165165
166166 // === Data Setup ===
167167 unifiedLayer .consumeFromDevice (qwen3State .wrapX );
168- // Transfer Q8_0 weights for this layer (quants and scales)
169168 unifiedLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
170- weights .rms_att_weightLayered [layerIndex ].asFloatArray (), //
171- weights .wqLayered [layerIndex ].asByteArray (),
172- weights .wkLayered [layerIndex ].asByteArray (),
173- weights .wvLayered [layerIndex ].asByteArray (),
174- weights .woLayered [layerIndex ].asByteArray (),
175- weights .rms_att_KNormLayered [layerIndex ].asFloatArray (), //
176- weights .rms_att_QNormLayered [layerIndex ].asFloatArray (),//
177- weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (), //
178- weights .w1Layered [layerIndex ].asByteArray (),
179- weights .w2Layered [layerIndex ].asByteArray (),
180- weights .w3Layered [layerIndex ].asByteArray ());
181-
182- // Configure layer data transfers (EVERY_EXECUTION and device persistence)
169+ // Attention weights
170+ weights .rms_att_weightLayered [layerIndex ].asFloatArray (), // RMS norm weights
171+ weights .wqLayered [layerIndex ].asByteArray (), // Q projection
172+ weights .wkLayered [layerIndex ].asByteArray (), // K projection
173+ weights .wvLayered [layerIndex ].asByteArray (), // V projection
174+ weights .woLayered [layerIndex ].asByteArray (), // Output projection
175+ // Qwen3-specific Q/K norm weights
176+ weights .rms_att_KNormLayered [layerIndex ].asFloatArray (), // K RMSNorm weights
177+ weights .rms_att_QNormLayered [layerIndex ].asFloatArray (), // Q RMSNorm weights
178+ // FFN weights
179+ weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (), // FFN RMSNorm weights
180+ weights .w1Layered [layerIndex ].asByteArray (), // FFN gate projection
181+ weights .w2Layered [layerIndex ].asByteArray (), // FFN down projection
182+ weights .w3Layered [layerIndex ].asByteArray ()); // FFN up projection
183183 unifiedLayer = configureLayerDataTransfers (unifiedLayer , layerIndex );
184184
185+ // ═══════════════════════════════════════════════════════════════════════
186+ // ATTENTION BLOCK
187+ // ═══════════════════════════════════════════════════════════════════════
185188
186189 // RMS Normalization - compute scale factor
187190 unifiedLayer .task ("attn_rms_reduce" ,
188191 TransformerComputeKernelsLayered ::reductionOneBlockWithLayer ,
189192 context ,
190193 qwen3State .temp , // output: scale factor
191194 qwen3State .wrapX , // input: hidden state
192- config .dim (), // dimension
193- config .rmsNormEps (), // epsilon
195+ config .dim (), // dimension
196+ config .rmsNormEps (), // epsilon
194197 qwen3State .localSize ); // local memory size
195198
196- // QKV projections with Qwen3 GQA dimensions
197- // Q8_0 weights pass both quants and scales
198- int qDim0 = nEmbdHeadK * config .numberOfHeads (); // Query dimension
199- int kvDim0 = nEmbdGqa ; // KV dimension (smaller due to GQA)
200- int qkvDim1 = config .dim (); // Input dimension
201-
199+ // Fused RMS Apply + QKV Projection
202200 unifiedLayer .task ("attn_rms_qkv_projection" ,
203201 Qwen3Kernels ::fusedRmsNormQKVMatmulQ8_0 ,
204202 context ,
@@ -273,7 +271,9 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
273271 config .dim (), // output dim
274272 LOCAL_WORK_GROUP_SIZE_ALLOC );
275273
276- // ========== FEED-FORWARD BLOCK ==========
274+ // ═══════════════════════════════════════════════════════════════════════
275+ // FFN BLOCK
276+ // ═══════════════════════════════════════════════════════════════════════
277277
278278 // RMS Normalization - compute scale factor
279279 unifiedLayer .task ("ffn_rms_reduce" ,
@@ -308,11 +308,13 @@ TaskGraph setupSingleQwen3FFNLayer(Qwen3TornadoWeights weights, int layerIndex)
308308 weights .w2Layered [layerIndex ].asByteArray (), // W2 (down)
309309 config .hiddenDim (), // input dim
310310 config .dim (), // output dim
311- LOCAL_WORK_GROUP_SIZE_ALLOC )
312- .persistOnDevice (state .wrapX );
311+ LOCAL_WORK_GROUP_SIZE_ALLOC );
312+
313+ unifiedLayer .persistOnDevice (state .wrapX );
313314
314315 return unifiedLayer ;
315316 }
317+ // @formatter:on
316318
317319 /**
318320 * Configure data transfers for first and subsequent layers
0 commit comments