From 664b48ccfb6a3fd4c1a52f6c97c8921f27268e41 Mon Sep 17 00:00:00 2001 From: Vivek Trivedi Date: Mon, 1 Dec 2025 20:23:28 -0800 Subject: [PATCH] Some minor performance improvements to buffer 4b mat mul. (#15989) Summary: The code change in this diff aims to improve the performance of buffer 4b matrix multiplication by reducing unnecessary computations and by spreading operations to allow better latency hiding. Reviewed By: yipjustin Differential Revision: D87910988 --- .../graph/ops/glsl/linear_qcsnw_tiled.glsl | 40 ++++++++++++------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 1e5de21cffc..d966de7282e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -78,6 +78,12 @@ void main() { const int in_row_txstride = div4(in_sizes.x); + $if WEIGHT_STORAGE == "buffer": + $if QUANT_NBITS == 4: + uint qmat2_bufi = weight_txcol; + $else: + uint qmat2_bufi = out_txcol; + for (int pos = 0, txpos = 0; txpos < in_row_txstride; pos += 4, txpos += 1) { @@ -99,7 +105,6 @@ void main() { } $if WEIGHT_STORAGE == "buffer": - uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); uint encoded_weight; @@ -114,26 +119,31 @@ void main() { $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; - packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28)); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF); $else: packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); - - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: @@ -146,6 +156,8 @@ void main() { $for j in range(4): sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; } + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi += weight_row_txstride; } }