diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index 1e5de21cffc..d966de7282e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -78,6 +78,12 @@ void main() { const int in_row_txstride = div4(in_sizes.x); + $if WEIGHT_STORAGE == "buffer": + $if QUANT_NBITS == 4: + uint qmat2_bufi = weight_txcol; + $else: + uint qmat2_bufi = out_txcol; + for (int pos = 0, txpos = 0; txpos < in_row_txstride; pos += 4, txpos += 1) { @@ -99,7 +105,6 @@ void main() { } $if WEIGHT_STORAGE == "buffer": - uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); uint encoded_weight; @@ -114,26 +119,31 @@ void main() { $if QUANT_NBITS == 4: $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; - packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28)); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF); $else: packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); - - qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); - qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); - - qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); - qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4); + qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4); + + qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF); + qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF); $else: $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": - qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; encoded_weight = t_weight[qmat2_bufi + ${c}]; packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: @@ -146,6 +156,8 @@ void main() { $for j in range(4): sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r]; } + $if WEIGHT_STORAGE == "buffer": + qmat2_bufi += weight_row_txstride; } }