Skip to content

Commit a1081e6

Browse files
authored
Some minor performance improvements to buffer 4b mat mul.
Differential Revision: D87910988 Pull Request resolved: #15989
1 parent e2c8c60 commit a1081e6

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ void main() {
7878

7979
const int in_row_txstride = div4(in_sizes.x);
8080

81+
$if WEIGHT_STORAGE == "buffer":
82+
$if QUANT_NBITS == 4:
83+
uint qmat2_bufi = weight_txcol;
84+
$else:
85+
uint qmat2_bufi = out_txcol;
86+
8187
for (int pos = 0, txpos = 0;
8288
txpos < in_row_txstride;
8389
pos += 4, txpos += 1) {
@@ -99,7 +105,6 @@ void main() {
99105
}
100106

101107
$if WEIGHT_STORAGE == "buffer":
102-
uint qmat2_bufi;
103108
uint weight_row_txstride = div4(weight_sizes.x);
104109
uint encoded_weight;
105110

@@ -114,26 +119,31 @@ void main() {
114119
$if QUANT_NBITS == 4:
115120
$for c in range(0, TILE_TXCOLS, 2):
116121
$if WEIGHT_STORAGE == "buffer":
117-
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
118122
encoded_weight = t_weight[qmat2_bufi + ${c}];
119-
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
123+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T((encoded_weight >> 4) & 0xF);
124+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T((encoded_weight >> 12) & 0xF);
125+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T((encoded_weight >> 20) & 0xF);
126+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T((encoded_weight >> 28));
127+
128+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T((encoded_weight) & 0xF);
129+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T((encoded_weight >> 8) & 0xF);
130+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T((encoded_weight >> 16) & 0xF);
131+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T((encoded_weight >> 24) & 0xF);
120132
$else:
121133
packed_weight_tex = texelFetch(
122134
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
123-
124-
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
125-
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
126-
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
127-
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);
128-
129-
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
130-
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
131-
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
132-
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
135+
qmat2[${c} * 4 * TILE_TXCOLS + 0] = T(packed_weight_tex.x >> 4);
136+
qmat2[${c} * 4 * TILE_TXCOLS + 1] = T(packed_weight_tex.y >> 4);
137+
qmat2[${c} * 4 * TILE_TXCOLS + 2] = T(packed_weight_tex.z >> 4);
138+
qmat2[${c} * 4 * TILE_TXCOLS + 3] = T(packed_weight_tex.w >> 4);
139+
140+
qmat2[${c} * 4 * TILE_TXCOLS + 4] = T(packed_weight_tex.x & 0xF);
141+
qmat2[${c} * 4 * TILE_TXCOLS + 5] = T(packed_weight_tex.y & 0xF);
142+
qmat2[${c} * 4 * TILE_TXCOLS + 6] = T(packed_weight_tex.z & 0xF);
143+
qmat2[${c} * 4 * TILE_TXCOLS + 7] = T(packed_weight_tex.w & 0xF);
133144
$else:
134145
$for c in range(TILE_TXCOLS):
135146
$if WEIGHT_STORAGE == "buffer":
136-
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
137147
encoded_weight = t_weight[qmat2_bufi + ${c}];
138148
packed_weight_tex = ivec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
139149
$else:
@@ -146,6 +156,8 @@ void main() {
146156
$for j in range(4):
147157
sums[tr * TILE_TXCOLS * 4 + ${c} * 4 + ${j}] += qmat2[${c} * 4 + ${j}] * mat1[tr * 4 + r];
148158
}
159+
$if WEIGHT_STORAGE == "buffer":
160+
qmat2_bufi += weight_row_txstride;
149161
}
150162
}
151163

0 commit comments

Comments
 (0)