@@ -159,22 +159,28 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
159159 TaskGraph unifiedLayer = new TaskGraph ("layer_" + layerIndex );
160160 unifiedLayer .consumeFromDevice (state .wrapX );
161161 unifiedLayer .transferToDevice (DataTransferMode .FIRST_EXECUTION ,
162- //Copy-in weights per layer for batched-layered layout
162+ // Attention weights
163163 weights .rms_att_weightLayered [layerIndex ].asFloatArray (),
164164 weights .wqLayered [layerIndex ].asByteArray (),
165165 weights .wkLayered [layerIndex ].asByteArray (),
166166 weights .wvLayered [layerIndex ].asByteArray (),
167167 weights .woLayered [layerIndex ].asByteArray (),
168+ // Qwen2-specific bias terms
168169 weights .q_biasLayered [layerIndex ].asFloatArray (),
169170 weights .k_biasLayered [layerIndex ].asFloatArray (),
170171 weights .v_biasLayered [layerIndex ].asFloatArray (),
172+ // FFN weights
171173 weights .rms_ffn_weightLayered [layerIndex ].asFloatArray (),
172174 weights .w1Layered [layerIndex ].asByteArray (),
173175 weights .w2Layered [layerIndex ].asByteArray (),
174176 weights .w3Layered [layerIndex ].asByteArray ()
175177 );
176178 unifiedLayer = configureLayerDataTransfers (unifiedLayer , layerIndex );
177179
180+ // ═══════════════════════════════════════════════════════════════════════
181+ // ATTENTION BLOCK
182+ // ═══════════════════════════════════════════════════════════════════════
183+
178184 // RMS Normalization - compute scale factor
179185 unifiedLayer .task ("attn_rms_reduce" ,
180186 TransformerComputeKernelsLayered ::reductionOneBlockWithLayer ,
@@ -258,6 +264,10 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
258264 config .dim (), // output dim
259265 LOCAL_WORK_GROUP_SIZE_ALLOC );
260266
267+ // ═══════════════════════════════════════════════════════════════════════
268+ // FFN BLOCK
269+ // ═══════════════════════════════════════════════════════════════════════
270+
261271 // RMS Normalization - compute scale factor
262272 unifiedLayer .task ("ffn_rms_reduce" ,
263273 TransformerComputeKernelsLayered ::reductionOneBlockWithLayer ,
@@ -302,10 +312,10 @@ TaskGraph setupSingleQwen2Q8_0FFNLayer(Qwen2TornadoWeights weights, int layerInd
302312 weights .w2Layered [layerIndex ].asByteArray (), // W2 (down)
303313 config .hiddenDim (), // input dim
304314 config .dim (), // output dim
305- LOCAL_WORK_GROUP_SIZE_ALLOC )
306- . persistOnDevice (
307- state .wrapX
308- );
315+ LOCAL_WORK_GROUP_SIZE_ALLOC );
316+
317+ unifiedLayer . persistOnDevice ( state .wrapX );
318+
309319 return unifiedLayer ;
310320
311321 }
0 commit comments