Description
I have thoroughly reviewed the T-MAC code, and I find it to be exceptionally well-written. However, both the TVM part of the code and the C++ code generated by TVM are quite challenging for me to understand due to my limited expertise. Additionally, I am struggling to correlate some details with the explanations provided in the paper. I would greatly appreciate some clarification.
My main point of confusion is the purpose of the "Bit-serial linear transformation." As I understand it, the code below loads 16 bytes of packed weights at once and splits each byte into two 4-bit values (the high and low nibbles). These values are used as indices into the LUT, so table lookups take the place of multiplications. adder_bot and adder_top then accumulate the results, with some handling to avoid overflow. At that point, multiplying the accumulated result by the GPTQ model's per-group scale should already give the result of the matrix multiplication. Or does this approach cause significant precision loss? If so, where does the loss come from? Is it introduced when the LUT is quantized, and could it be avoided by not quantizing the LUT?
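```cpp
for (int k = 0; k < ActK; k++) {
    // Load 16 bytes; each byte packs two 4-bit LUT indices.
    uint8x16_t vec_as = vld1q_u8(a + i * K + (kk + k) * 16);
    uint8x16_t vec_a_top = vshrq_n_u8(vec_as, 4);      // high nibbles
    uint8x16_t vec_a_bot = vandq_u8(vec_as, vec_mask); // low nibbles (vec_mask is presumably 0x0F)
    // 16 parallel lookups into the int8 table for this activation group
    int8x16_t vec_v_bot_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_bot);
    int8x16_t vec_v_top_tmp = vqtbl1q_s8(vec_lut[kk + k], vec_a_top);
    // Accumulate the int8 lookup results
    adder_bot.push(vec_v_bot_tmp, k);
    adder_top.push(vec_v_top_tmp, k);
}
```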
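The sketch below is a minimal scalar illustration of the table-lookup idea that the loop above vectorizes. It rests on stated assumptions: a group of g = 4 activations, weights that are single bits (one bit plane), and a 16-entry table indexed by the 4-bit weight pattern. It also shows one place where rounding error can enter when the table is stored as int8, which is the element type the `int8x16_t` lookups work on. The names `dot_bits`, `build_lut`, and `quantize_lut`, and the per-table symmetric int8 scale, are hypothetical and are not T-MAC's actual API or quantization scheme.

```c
#include <assert.h>
#include <stdint.h>

#define G   4          /* activations per lookup group (assumed g = 4) */
#define TBL (1 << G)   /* 16 entries, one per 4-bit weight pattern */

/* Reference: dot product of G activations with G one-bit weights packed in w_bits. */
static float dot_bits(const float a[G], uint8_t w_bits) {
    assert(w_bits < TBL);
    float s = 0.0f;
    for (int j = 0; j < G; j++)
        if (w_bits & (1u << j)) s += a[j];
    return s;
}

/* Hypothetical helper: precompute the result for every possible 4-bit pattern
 * once per activation group, so each later multiply-accumulate is one lookup. */
static void build_lut(const float a[G], float lut[TBL]) {
    for (int idx = 0; idx < TBL; idx++)
        lut[idx] = dot_bits(a, (uint8_t)idx);
}

/* Hypothetical symmetric int8 quantization of one table; returns its scale.
 * This rounding step is where precision can be lost in this sketch. */
static float quantize_lut(const float lut[TBL], int8_t q[TBL]) {
    float maxabs = 0.0f;
    for (int i = 0; i < TBL; i++) {
        float v = lut[i] < 0.0f ? -lut[i] : lut[i];
        if (v > maxabs) maxabs = v;
    }
    float scale = maxabs > 0.0f ? maxabs / 127.0f : 1.0f;
    for (int i = 0; i < TBL; i++) {
        float x = lut[i] / scale;  /* within [-127, 127] up to float rounding */
        int r = x >= 0.0f ? (int)(x + 0.5f) : -(int)(-x + 0.5f); /* round half away from zero */
        if (r > 127) r = 127;
        if (r < -127) r = -127;
        q[i] = (int8_t)r;
    }
    return scale;
}
```

In this sketch, the float table reproduces `dot_bits` exactly for every pattern. The int8 table, once multiplied back by its scale, is off by at most half a quantization step per lookup, and those errors add up during accumulation. Whether T-MAC's actual table quantization behaves the same way is not established here.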
I also have some difficulty with the formulas in the "Bit-serial linear transformation" section of the arXiv version:
```c
C_global[cse_var_1] = ((CBits[cse_var_2] * 5.000000e-01f) + CBits[(cse_var_2 + 8)]);
C_global[(cse_var_1 + 1)] = ((CBits[(cse_var_2 + 1)] * 5.000000e-01f) + CBits[(cse_var_2 + 9)]);
C_global[(cse_var_1 + 2)] = ((CBits[(cse_var_2 + 2)] * 5.000000e-01f) + CBits[(cse_var_2 + 10)]);
```
But where exactly is
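The 0.5f coefficient in the generated code above is consistent with the decomposition sketched here, which depends on an assumption not confirmed from the T-MAC source: each bit b_i in {0, 1} of an unsigned Bits-bit weight q is remapped to s_i = 2*b_i - 1 in {-1, +1}. Under that remapping, q = sum_i 2^i * b_i = sum_i 2^(i-1) * s_i + (2^Bits - 1)/2. For Bits = 2, the per-plane results therefore combine as 0.5*C_0 + 1.0*C_1, and a constant times sum(a) has to be added back separately. Reading `CBits[cse_var_2]` and `CBits[cse_var_2 + 8]` as C_0 and C_1 is also an assumption. The function names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Reference: dot product with unsigned integer weights q[k]. */
static float dot_direct(const float *a, const uint8_t *q, int n) {
    float s = 0.0f;
    for (int k = 0; k < n; k++) s += a[k] * (float)q[k];
    return s;
}

/* Partial sum for one bit plane, with each bit remapped to {-1, +1}:
 * C_i = sum_k a[k] * (2 * bit_i(q[k]) - 1). */
static float plane_sum(const float *a, const uint8_t *q, int n, int bit) {
    float s = 0.0f;
    for (int k = 0; k < n; k++)
        s += ((q[k] >> bit) & 1) ? a[k] : -a[k];
    return s;
}

/* Bit-serial reconstruction: sum_i 2^(i-1) * C_i + (2^bits - 1) / 2 * sum(a).
 * For bits = 2 this is 0.5 * C_0 + 1.0 * C_1 + 1.5 * sum(a). */
static float dot_bit_serial(const float *a, const uint8_t *q, int n, int bits) {
    assert(bits >= 1 && bits <= 8);
    float acc = 0.0f, weight = 0.5f, sum_a = 0.0f;
    for (int i = 0; i < bits; i++, weight *= 2.0f)
        acc += weight * plane_sum(a, q, n, i);
    for (int k = 0; k < n; k++) sum_a += a[k];
    return acc + 0.5f * (float)((1 << bits) - 1) * sum_a;
}
```

For example, with a = {0.5, -1.25, 2.0, 0.75, -0.25} and q = {0, 1, 2, 3, 2}, `dot_direct(a, q, 5)` and `dot_bit_serial(a, q, 5, 2)` both give 4.5. Processing one bit plane at a time means a single 1-bit table serves weights of any bit width, at the cost of one extra lookup pass per bit.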
Additionally, in the zero-point code, why is a computation applied to the weights? The comment states `w = (w - default_zero - (zeros - default_zero)) * scales`. Does `w` refer to the weights? And is the purpose of `add_zero` to implement the formula above?
```cpp
// load scales: each 16-element block holds 8 scales followed by 8 zero points,
// and Bits consecutive output blocks share one block
float16x8_t vec_s0 = vld1q_f16(scales + ((i / 4    ) / Bits) * 16);
float16x8_t vec_s1 = vld1q_f16(scales + ((i / 4 + 1) / Bits) * 16);
float16x8_t vec_s2 = vld1q_f16(scales + ((i / 4 + 2) / Bits) * 16);
float16x8_t vec_s3 = vld1q_f16(scales + ((i / 4 + 3) / Bits) * 16);
// scale * accumulated result, added to what is already stored in c
vec_c0 = vld1q_f16(c + i * 2)      + vec_c0 * vec_s0;
vec_c1 = vld1q_f16(c + i * 2 + 8)  + vec_c1 * vec_s1;
vec_c2 = vld1q_f16(c + i * 2 + 16) + vec_c2 * vec_s2;
vec_c3 = vld1q_f16(c + i * 2 + 24) + vec_c3 * vec_s3;
// load zero_point (the 8 values after the scales in the same block)
float16x8_t vec_z0 = vld1q_f16(scales + ((i / 4    ) / Bits) * 16 + 8);
float16x8_t vec_z1 = vld1q_f16(scales + ((i / 4 + 1) / Bits) * 16 + 8);
float16x8_t vec_z2 = vld1q_f16(scales + ((i / 4 + 2) / Bits) * 16 + 8);
float16x8_t vec_z3 = vld1q_f16(scales + ((i / 4 + 3) / Bits) * 16 + 8);
// lut_bias
partial_sum *= 2;
// add the zero-point term only when (ib) % Bits == 0, i.e. once per Bits blocks
#define add_zero(cs, zs, ib) \
    ((ib) % Bits) ? ((cs)) \
                  : ((cs) + zs * partial_sum)
vst1q_f16(c + i * 2,      add_zero(vec_c0, vec_z0, (i / 4    )));
vst1q_f16(c + i * 2 + 8,  add_zero(vec_c1, vec_z1, (i / 4 + 1)));
vst1q_f16(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2)));
vst1q_f16(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3)));
#undef add_zero
```
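The comment `w = (w - default_zero - (zeros - default_zero)) * scales` reads like a dequantization formula for a weight w. The sketch below shows why a zero-point term can be applied once per group rather than per weight. It assumes the usual asymmetric form w = (q - z) * s with a fixed default_zero d. Expanding gives (q - d) * s - (z - d) * s, and the second term does not depend on the weight index. Over a group, sum_k a_k * w_k = s * sum_k a_k * (q_k - d) - (z - d) * s * sum_k a_k. This has the same shape as adding `zs * partial_sum` once per group, but what `vec_z*` and `partial_sum` actually hold in T-MAC is an assumption here. All names are hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Reference: dequantize each weight as (q - z) * s, then take the dot product. */
static float dot_dequant(const float *a, const uint8_t *q, int n, float s, float z) {
    float acc = 0.0f;
    for (int k = 0; k < n; k++) acc += a[k] * (((float)q[k] - z) * s);
    return acc;
}

/* Folded form: the per-weight part uses only the fixed default zero d,
 * and the group-specific zero point enters once, multiplied by sum(a). */
static float dot_folded(const float *a, const uint8_t *q, int n,
                        float s, float z, float d) {
    assert(n >= 0);
    float main_term = 0.0f, sum_a = 0.0f;
    for (int k = 0; k < n; k++) {
        main_term += a[k] * ((float)q[k] - d);
        sum_a += a[k];
    }
    float zero_term = -(z - d) * s;  /* one constant per group */
    return main_term * s + zero_term * sum_a;
}
```

With a = {1.5, -0.5, 2.0, 0.25}, q = {3, 0, 7, 12}, s = 0.125, z = 5, and an example d = 8, both functions return 0.65625. In the kernel above, `add_zero` adds its term only when `(ib) % Bits == 0`. One plausible reading is that this prevents the constant from being added once per bit plane, but the sketch does not model that detail.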
My questions might be somewhat unclear, mainly because my understanding is limited, and I haven't been able to systematically formulate them. If possible, I would greatly appreciate your assistance in clarifying these points!