diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c
index aadbb487ec0e4..735d4fe785620 100644
--- a/ggml/src/ggml-cpu/arch/arm/quants.c
+++ b/ggml/src/ggml-cpu/arch/arm/quants.c
@@ -2044,6 +2044,26 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 }
 
+#ifdef __ARM_FEATURE_SVE
+static inline svuint32_t ggml_decode_q4scales_and_mins_for_mmla(const uint32_t *vx_scales) {
+    const svbool_t pg_all   = svptrue_pat_b32(SV_VL4);
+    const svbool_t pg_false = svpfalse_b();           // 0x0000
+    const svbool_t pg_lo_8  = svwhilelt_b8_s32(0, 8); // 0x00ff
+    const svbool_t pg_odd   = svzip1_b32(pg_false, pg_lo_8);
+
+    svuint32_t vutmp_hi, vutmp_lo;
+    svuint32_t vx01 = svld1_u32(pg_lo_8, vx_scales);
+    vutmp_hi = svzip1_u32(vx01, vx01);
+    vutmp_hi = svlsr_n_u32_m(pg_odd, vutmp_hi, 2);
+    vutmp_hi = svreinterpret_u32_u64(svand_n_u64_x(pg_all, svreinterpret_u64_u32(vutmp_hi), UINT64_C(0x303030303f3f3f3f)));
+    const svuint32_t vx2 = svdup_n_u32(vx_scales[2]);
+    vutmp_lo = svlsr_u32_x(pg_all, vx2, svreinterpret_u32_s32(svindex_s32(-2, 2)));
+    vutmp_lo = svand_n_u32_z(pg_odd, vutmp_lo, UINT32_C(0x0f0f0f0f));
+    svuint32_t vutmp = svorr_u32_z(pg_all, vutmp_hi, vutmp_lo);
+    return vutmp;
+}
+#endif
+
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
 #ifdef __ARM_FEATURE_MATMUL_INT8
@@ -2066,8 +2086,215 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
     static const uint32_t kmask3 = 0x03030303;
 
     uint32_t utmp[4];
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
 
 #if defined(__ARM_FEATURE_MATMUL_INT8)
+#ifdef __ARM_FEATURE_SVE
+    if (nrc == 2) {
+        svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+        const block_q4_K * GGML_RESTRICT vx0 = vx;
+        const block_q8_K * GGML_RESTRICT vy0 = vy;
+        const block_q4_K * GGML_RESTRICT vx1 = (const block_q4_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+        union {
+            uint32_t u32[8];
+            uint64_t u64[4];
+        } new_utmp;
+        svfloat32_t sumf1 = svdup_n_f32(0);
+        switch (vector_length) {
+            case 128:
+            {
+                svbool_t pg_false = svpfalse_b();
+                svbool_t pg_lo_8 = svwhilelt_b8_s32(0, 8);
+                svbool_t vmins_mask1 = svzip1_b32(pg_lo_8, pg_false);
+                svbool_t vmins_mask2 = svzip1_b32(pg_false, pg_lo_8);
+                svbool_t pg128_all = svptrue_pat_b8(SV_VL16);
+                for (int i = 0; i < nb; ++i) {
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat32_t vy_dmins = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg128_all, svmul_f32_x(pg128_all, vy_dmins, vx_dmins), -1);
+                    const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t  * GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint16_t lo = svld1_s16(pg128_all, vy0[i].bsums + 0);
+                    svint16_t hi = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    svint16_t sum_tmp1 = svuzp1_s16(lo, hi);
+                    svint16_t sum_tmp2 = svuzp2_s16(lo, hi);
+                    svint16_t svq8sums_0 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
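+                    // fold the 16 per-16-element bsums of row 0 into 8 per-sub-block sums
+                    // (bsums[2k] + bsums[2k+1]); the same folding is repeated for row 1 below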
+                    lo = svld1_s16(pg128_all, vy1[i].bsums + 0);
+                    hi = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    sum_tmp1 = svuzp1_s16(lo, hi);
+                    sum_tmp2 = svuzp2_s16(lo, hi);
+                    svint16_t svq8sums_1 = svadd_s16_x(pg128_all, sum_tmp1, sum_tmp2);
+                    svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg128_all, new_utmp.u32, decoded_scales);
+                    svint16_t svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp1_u32(svld1_u32(vmins_mask1, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint16_t svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u32(svuzp2_u32(svld1_u32(vmins_mask2, new_utmp.u32+4), svdup_n_u32(0)))));
+                    svint32_t svsumfs_tmp1 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_0));
+                    svint32_t svsumfs_tmp2 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_0, svmins8_1));
+                    svint32_t svsumfs_tmp3 = svtrn1_s32(svsumfs_tmp1, svsumfs_tmp2);
+                    svint32_t svsumfs_tmp4 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_0));
+                    svint32_t svsumfs_tmp5 = svreinterpret_s32_s64(svdot_s64(svdup_n_s64(0), svq8sums_1, svmins8_1));
+                    svint32_t svsumfs_tmp6 = svtrn1_s32(svsumfs_tmp4, svsumfs_tmp5);
+                    svint32_t svsumfs_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(svsumfs_tmp3), svreinterpret_s64_s32(svsumfs_tmp6)));
+                    svint32_t svsumfs_tmp = svadd_s32_x(pg128_all, svsumfs_tmp7, svsumfs_tmp8);
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t q4bytes_0_l, q4bytes_0_h, q4bytes_1_l, q4bytes_1_h, l0, l1, l2, l3,
+                             q8bytes_0_h, q8bytes_0_l, q8bytes_1_h, q8bytes_1_l, r0, r1, r2, r3;
+#pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        q4bytes_0_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 0xf));
+                        q4bytes_1_l = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 0xf));
+                        q4bytes_0_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 0xf));
+                        q4bytes_1_h = svreinterpret_s8_u8(svand_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 0xf));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+16);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+16);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi1 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
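+                        // each svmmla accumulates a 2x2 int32 tile per 128-bit segment with
+                        // lanes {y0·x0, y0·x1, y1·x0, y1·x1} for this 32-value chunk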
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg128_all, acc_sumif1, svscales, sumi1);
+
+                        q4bytes_0_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0), 4));
+                        q4bytes_1_l = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1), 4));
+                        q4bytes_0_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_0+16), 4));
+                        q4bytes_1_h = svreinterpret_s8_u8(svlsr_n_u8_x(pg128_all, svld1_u8(pg128_all, q4_1+16), 4));
+                        l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_l), svreinterpret_s64_s8(q4bytes_1_l)));
+                        l2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        l3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q4bytes_0_h), svreinterpret_s64_s8(q4bytes_1_h)));
+                        q8bytes_0_h = svld1_s8(pg128_all, q8_0+32);
+                        q8bytes_1_h = svld1_s8(pg128_all, q8_1+32);
+                        q8bytes_0_l = svld1_s8(pg128_all, q8_0+48);
+                        q8bytes_1_l = svld1_s8(pg128_all, q8_1+48);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_h), svreinterpret_s64_s8(q8bytes_1_h)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0_l), svreinterpret_s64_s8(q8bytes_1_l)));
+                        sumi2 = svmmla_s32(svmmla_s32(svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg128_all, svlsl_n_u32_x(pg128_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
+                        acc_sumif2 = svmla_s32_x(pg128_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    sumf1 = svmla_f32_x(pg128_all,
+                                        svmla_f32_x(pg128_all,
+                                                    sumf1,
+                                                    svcvt_f32_s32_x(pg128_all,
+                                                                    svadd_s32_x(pg128_all, acc_sumif1, acc_sumif2)),
+                                                    svsuper_block_scales),
+                                        svdmins,
+                                        svcvt_f32_s32_x(pg128_all, svsumfs_tmp));
+                } // end of for nb
+            } // end of case 128
+            break;
+            case 256:
+            case 512:
+            {
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                const svbool_t pg8_16 = svptrue_pat_b8(SV_VL16);
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT q4_0 = vx0[i].qs;
+                    const int8_t  * GGML_RESTRICT q8_0 = vy0[i].qs;
+                    const uint8_t * GGML_RESTRICT q4_1 = vx1[i].qs;
+                    const int8_t  * GGML_RESTRICT q8_1 = vy1[i].qs;
+                    svint32_t svscales, sumi1, sumi2;
+                    svint32_t acc_sumif1 = svdup_n_s32(0);
+                    svint32_t acc_sumif2 = svdup_n_s32(0);
+                    svint8_t l0, l1, l2, l3, r0, r1, r2, r3;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_z(pg32_4, vy_d, vx_d);
+                    svfloat32_t vx_dmins = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].dmin)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].dmin)));
+                    svfloat64_t vy_dmins_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
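+                    // same uzp trick as vy_d above: replicate {y0.d, y0.d, y1.d, y1.d} so that,
+                    // against vx_dmins = {x0.dmin, x1.dmin, x0.dmin, x1.dmin}, the lanes line up
+                    // with the 2x2 mmla tile order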
+                    svfloat32_t vy_dmins = svreinterpret_f32_f64(svuzp1_f64(vy_dmins_tmp, vy_dmins_tmp));
+                    svfloat32_t svdmins = svmul_n_f32_x(pg32_4, svmul_f32_x(pg32_4, vx_dmins, vy_dmins), -1);
+                    svint16_t rc1 = svuzp1_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t rc2 = svuzp2_s16(svld1_s16(pg256_all, vy0[i].bsums), svld1_s16(pg256_all, vy1[i].bsums));
+                    svint16_t svq8sums = svadd_s16_x(pg256_all, rc1, rc2);
+                    svuint32_t decoded_scales0 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx0[i].scales);
+                    svuint32_t decoded_scales1 = ggml_decode_q4scales_and_mins_for_mmla((const uint32_t *)vx1[i].scales);
+                    svuint32x2_t decoded_scales = svcreate2_u32(decoded_scales0, decoded_scales1);
+                    svst2_u32(pg8_16, new_utmp.u32, decoded_scales);
+                    svint16_t new_svq8sums_0 = svreinterpret_s16_u64(svtrn1_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svint16_t new_svq8sums_1 = svreinterpret_s16_u64(svtrn2_u64(svreinterpret_u64_s16(svq8sums), svreinterpret_u64_s16(svq8sums)));
+                    svuint64_t new_mins_0 = svdup_n_u64(new_utmp.u64[2]);
+                    svuint64_t new_mins_1 = svdup_n_u64(new_utmp.u64[3]);
+                    svint16_t new_svmins8_0 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_0)));
+                    svint16_t new_svmins8_1 = svreinterpret_s16_u16(svunpklo_u16(svreinterpret_u8_u64(new_mins_1)));
+                    svint64_t dot_prod_0 = svdot_s64(svdup_n_s64(0), new_svmins8_0, new_svq8sums_0);
+                    svint64_t dot_prod_1 = svdot_s64(dot_prod_0, new_svmins8_1, new_svq8sums_1);
+                    svfloat32_t converted_dot_prod_1 = svcvt_f32_s64_x(pg256_all, dot_prod_1);
+                    svfloat32_t svsumfs_tmp = svuzp1_f32(converted_dot_prod_1, converted_dot_prod_1);
+
+#pragma GCC unroll 1
+                    for (int j = 0; j < QK_K/64; ++j) {
+                        svuint8_t q4bytes_0 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 0xf);
+                        svuint8_t q4bytes_1 = svand_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 0xf);
+                        svuint8_t q4bytes_2 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_0), 4);
+                        svuint8_t q4bytes_3 = svlsr_n_u8_x(pg256_all, svld1_u8(pg256_all, q4_1), 4);
+                        l0 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l1 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_0), svreinterpret_u64_u8(q4bytes_1)));
+                        l2 = svreinterpret_s8_u64(svzip1_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        l3 = svreinterpret_s8_u64(svzip2_u64(svreinterpret_u64_u8(q4bytes_2), svreinterpret_u64_u8(q4bytes_3)));
+                        svint8_t q8bytes_0 = svld1_s8(pg256_all, q8_0);
+                        svint8_t q8bytes_1 = svld1_s8(pg256_all, q8_1);
+                        svint8_t q8bytes_2 = svld1_s8(pg256_all, q8_0+32);
+                        svint8_t q8bytes_3 = svld1_s8(pg256_all, q8_1+32);
+                        r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                        r2 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        r3 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_2), svreinterpret_s64_s8(q8bytes_3)));
+                        sumi1 = svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-1)), 24));
+                        acc_sumif1 = svmla_s32_x(pg256_all, acc_sumif1, svscales, sumi1);
+                        sumi2 = svmmla_s32(svmmla_s32(svdup_n_s32(0), r2, l2), r3, l3);
+                        svscales = svreinterpret_s32_u32(svlsr_n_u32_x(pg256_all, svlsl_n_u32_x(pg256_all, svreinterpret_u32_u64(svdup_n_u64(new_utmp.u64[j/2])), 8*(4-2*(j%2)-2)), 24));
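+                        // svscales broadcasts the per-sub-block scale bytes of the re-packed
+                        // utmp as {s_x0, s_x1, s_x0, s_x1}, matching the tile columns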
+                        acc_sumif2 = svmla_s32_x(pg256_all, acc_sumif2, svscales, sumi2);
+                        q4_0 += 32; q4_1 += 32; q8_0 += 64; q8_1 += 64;
+                    }
+                    svint32_t acc_sumif = svadd_s32_x(pg256_all, acc_sumif1, acc_sumif2);
+                    svint32_t swap_acc_sumif = svext_s32(acc_sumif, acc_sumif, 4);
+                    acc_sumif = svadd_s32_x(pg32_4, acc_sumif, swap_acc_sumif);
+                    sumf1 = svmla_f32_x(pg32_4,
+                                        svmla_f32_x(pg32_4,
+                                                    sumf1,
+                                                    svcvt_f32_s32_x(pg32_4, acc_sumif),
+                                                    svsuper_block_scales),
+                                        svdmins,
+                                        svsumfs_tmp);
+                } // end of for nb
+            } // end of case 256-512
+            break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        }
+        svst1_f32(pg32_2, s, sumf1);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sumf1), svdup_n_u8(0), 8)));
+        return;
+    }
+#else
     if (nrc == 2) {
         const block_q4_K * GGML_RESTRICT x0 = x;
         const block_q4_K * GGML_RESTRICT x1 = (const block_q4_K *) ((const uint8_t *)vx + bx);
@@ -2206,6 +2433,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         return;
     }
 #endif
+#endif
 
 #ifdef __ARM_FEATURE_SVE
     float sumf = 0;
@@ -2235,7 +2463,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         const uint8_t * GGML_RESTRICT q4 = x[i].qs;
         const int8_t  * GGML_RESTRICT q8 = y[i].qs;
 
-        const int vector_length = ggml_cpu_get_sve_cnt()*8;
         const svuint8_t m4b = svdup_n_u8(0xf);
         const svint32_t mzero = svdup_n_s32(0);
         svint32_t sumi1 = svdup_n_s32(0);
@@ -2480,7 +2707,197 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
 
     const int nb = n / QK_K;
 
+#ifdef __ARM_FEATURE_SVE
+    const int vector_length = ggml_cpu_get_sve_cnt()*8;
+#endif
 #if defined(__ARM_FEATURE_MATMUL_INT8)
+#ifdef __ARM_FEATURE_SVE
+    if (nrc == 2) {
+        const svbool_t pg32_2 = svptrue_pat_b32(SV_VL2);
+        svfloat32_t sum = svdup_n_f32(0);
+        const block_q6_K * GGML_RESTRICT vx0 = vx;
+        const block_q8_K * GGML_RESTRICT vy0 = vy;
+        const block_q6_K * GGML_RESTRICT vx1 = (const block_q6_K *) ((const uint8_t*)vx + bx);
+        const block_q8_K * GGML_RESTRICT vy1 = (const block_q8_K *) ((const uint8_t*)vy + by);
+        switch (vector_length) {
+            case 128:
+            {
+                const svbool_t pg128_all = svptrue_pat_b8(SV_ALL);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+
+                    svfloat32_t vy_d = svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d));
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg128_all, vy_d, vx_d);
+                    // process q8sum summation 128 bit route
+                    const svint16_t q8sums_01 = svld1_s16(pg128_all, vy0[i].bsums);
+                    const svint16_t q8sums_02 = svld1_s16(pg128_all, vy0[i].bsums + 8);
+                    const svint16_t q8sums_11 = svld1_s16(pg128_all, vy1[i].bsums);
+                    const svint16_t q8sums_12 = svld1_s16(pg128_all, vy1[i].bsums + 8);
+                    const svint64x2_t q6scales_0_tmp = svld2_s64(pg128_all, (const int64_t *)scale0);
+                    const svint16_t q6scales_01 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 0)));
+                    const svint16_t q6scales_02 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_0_tmp, 1)));
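+                    // svld2 de-interleaves the 16 int8 scales into two 8-byte groups
+                    // (scales[0..7] / scales[8..15]); svunpklo widens them to int16 for svdot_s64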
+                    const svint64x2_t q6scales_1_tmp = svld2_s64(pg128_all, (const int64_t *)scale1);
+                    const svint16_t q6scales_11 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 0)));
+                    const svint16_t q6scales_12 = svunpklo_s16(svreinterpret_s8_s64(svget2_s64(q6scales_1_tmp, 1)));
+                    const svint64_t prod = svdup_n_s64(0);
+
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_01), q8sums_02, q6scales_02));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_01, q6scales_11), q8sums_02, q6scales_12));
+                    svint32_t isum_tmp3 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_01), q8sums_12, q6scales_02));
+                    svint32_t isum_tmp5 = svreinterpret_s32_s64(svdot_s64(svdot_s64(prod, q8sums_11, q6scales_11), q8sums_12, q6scales_12));
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp4, isum_tmp5);
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp3), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t svisum_mins = svadd_s32_x(pg128_all, isum_tmp7, isum_tmp8);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; ++k) {
+                            svuint8_t qhbits_0 = svld1_u8(pg128_all, qh0+16*(k%2));
+                            svuint8_t qhbits_1 = svld1_u8(pg128_all, qh1+16*(k%2));
+                            svuint8_t q6bits_0 = svld1_u8(pg128_all, ql0+16*(k%4));
+                            svuint8_t q6bits_1 = svld1_u8(pg128_all, ql1+16*(k%4));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg128_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg128_all, q6bits_1, 4);
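+                            // merge the two high bits from qh above the low nibble; the q6_K
+                            // offset of 32 is subtracted later in one step via svisum_mins * -32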
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg128_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg128_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg128_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_0_lo, svlsr_n_u8_x(pg128_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg128_all, q6bytes_1_lo, svlsr_n_u8_x(pg128_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg128_all, q80+16*(k%8));
+                            svint8_t q8bytes_1 = svld1_s8(pg128_all, q81+16*(k%8));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            isum_tmp = svmla_s32_x(pg128_all, isum_tmp, svmmla_s32(svmmla_s32(svdup_n_s32(0), r0, l0), r1, l1), svscale);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    }
+                    sum = svmla_f32_x(pg128_all, sum,
+                                      svcvt_f32_s32_x(pg128_all, svmla_s32_x(pg128_all, isum_tmp,
+                                                                             svisum_mins, svdup_n_s32(-32))),
+                                      svsuper_block_scales);
+                }
+            } // end of case 128
+            break;
+            case 256:
+            case 512:
+            {
+                const svbool_t pg256_all = svptrue_pat_b8(SV_ALL);
+                const svbool_t pg32_4 = svptrue_pat_b32(SV_VL4);
+                for (int i = 0; i < nb; ++i) {
+                    const uint8_t * GGML_RESTRICT ql0 = vx0[i].ql;
+                    const uint8_t * GGML_RESTRICT qh0 = vx0[i].qh;
+                    const uint8_t * GGML_RESTRICT ql1 = vx1[i].ql;
+                    const uint8_t * GGML_RESTRICT qh1 = vx1[i].qh;
+                    const int8_t  * GGML_RESTRICT q80 = vy0[i].qs;
+                    const int8_t  * GGML_RESTRICT q81 = vy1[i].qs;
+
+                    const int8_t * GGML_RESTRICT scale0 = vx0[i].scales;
+                    const int8_t * GGML_RESTRICT scale1 = vx1[i].scales;
+                    svfloat32_t vx_d = svzip1_f32(svdup_n_f32(GGML_FP16_TO_FP32(vx0[i].d)), svdup_n_f32(GGML_FP16_TO_FP32(vx1[i].d)));
+                    svfloat64_t vy_d_tmp = svreinterpret_f64_f32(svuzp1_f32(svdup_n_f32(vy0[i].d), svdup_n_f32(vy1[i].d)));
+                    svfloat32_t vy_d = svreinterpret_f32_f64(svuzp1_f64(vy_d_tmp, vy_d_tmp));
+                    svfloat32_t svsuper_block_scales = svmul_f32_x(pg32_4, vy_d, vx_d);
+                    // process q8sum summation 256 bit route
+                    const svint16_t q8sums_0 = svld1_s16(pg256_all, vy0[i].bsums);
+                    const svint16_t q8sums_1 = svld1_s16(pg256_all, vy1[i].bsums);
+                    const svint16_t q6scales_0 = svunpklo_s16(svld1_s8(pg256_all, scale0));
+                    const svint16_t q6scales_1 = svunpklo_s16(svld1_s8(pg256_all, scale1));
+                    const svint64_t prod = svdup_n_s64(0);
+                    svint32_t isum_tmp1 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_0));
+                    svint32_t isum_tmp2 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_0, q6scales_1));
+                    svint32_t isum_tmp3 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_0));
+                    svint32_t isum_tmp4 = svreinterpret_s32_s64(svdot_s64(prod, q8sums_1, q6scales_1));
+                    svint32_t isum_tmp5 = svtrn1_s32(isum_tmp1, isum_tmp2);
+                    svint32_t isum_tmp6 = svtrn1_s32(isum_tmp3, isum_tmp4);
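+                    // transpose the four bias dot products into the {y0·x0, y0·x1, y1·x0, y1·x1}
+                    // lane order of the mmla accumulators, then fold in the upper 128-bit half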
+                    svint32_t isum_tmp7 = svreinterpret_s32_s64(svtrn2_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp8 = svreinterpret_s32_s64(svtrn1_s64(svreinterpret_s64_s32(isum_tmp5), svreinterpret_s64_s32(isum_tmp6)));
+                    svint32_t isum_tmp9 = svadd_s32_x(pg256_all, isum_tmp7, isum_tmp8);
+                    svint32_t isum_tmp10 = svreinterpret_s32_u8(svext_u8(svreinterpret_u8_s32(isum_tmp9), svreinterpret_u8_s32(isum_tmp9), 16));
+                    svint32_t svisum_mins = svadd_s32_z(pg32_4, isum_tmp9, isum_tmp10);
+
+                    // process mmla
+                    svint8_t l0, l1, r0, r1;
+                    svint32_t isum_tmp = svdup_n_s32(0);
+                    for (int j = 0; j < QK_K/128; ++j) {
+                        for (int k = 0; k < 8; k += 2) { // process 2 sub-blocks per iteration
+                            svuint8_t qhbits_0 = svld1_u8(pg256_all, qh0);
+                            svuint8_t qhbits_1 = svld1_u8(pg256_all, qh1);
+                            svuint8_t q6bits_0 = svld1_u8(pg256_all, ql0+32*((k%4)/2));
+                            svuint8_t q6bits_1 = svld1_u8(pg256_all, ql1+32*((k%4)/2));
+                            const int ql_pos = (k/4)*4;
+                            svuint8_t q6bytes_0_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_0, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_0, 4);
+                            svuint8_t q6bytes_1_lo = (ql_pos < 4) ? svand_n_u8_x(pg256_all, q6bits_1, 0xf) : svlsr_n_u8_x(pg256_all, q6bits_1, 4);
+                            const int qh_pos = (k/2)*2;
+                            svuint8_t q6bytes_0_hi = svand_n_u8_x(pg256_all, qhbits_0, 0x3 << qh_pos);
+                            svuint8_t q6bytes_1_hi = svand_n_u8_x(pg256_all, qhbits_1, 0x3 << qh_pos);
+                            svint8_t q6bytes_0, q6bytes_1;
+                            if (qh_pos <= 4) {
+                                q6bytes_0 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_0_lo, q6bytes_0_hi, 1 << (4 - qh_pos)));
+                                q6bytes_1 = svreinterpret_s8_u8(svmla_n_u8_x(pg256_all, q6bytes_1_lo, q6bytes_1_hi, 1 << (4 - qh_pos)));
+                            } else {
+                                q6bytes_0 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_0_lo, svlsr_n_u8_x(pg256_all, q6bytes_0_hi, (qh_pos - 4))));
+                                q6bytes_1 = svreinterpret_s8_u8(svorr_u8_x(pg256_all, q6bytes_1_lo, svlsr_n_u8_x(pg256_all, q6bytes_1_hi, (qh_pos - 4))));
+                            }
+                            svint8_t q8bytes_0 = svld1_s8(pg256_all, q80+32*(k/2));
+                            svint8_t q8bytes_1 = svld1_s8(pg256_all, q81+32*(k/2));
+                            l0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            l1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q6bytes_0), svreinterpret_s64_s8(q6bytes_1)));
+                            r0 = svreinterpret_s8_s64(svzip1_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            r1 = svreinterpret_s8_s64(svzip2_s64(svreinterpret_s64_s8(q8bytes_0), svreinterpret_s64_s8(q8bytes_1)));
+                            svint32_t svscale0 = svzip1_s32(svdup_n_s32(scale0[k]), svdup_n_s32(scale1[k]));
+                            svint32_t svscale1 = svzip1_s32(svdup_n_s32(scale0[k+1]), svdup_n_s32(scale1[k+1]));
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r0, l0), svscale0);
+                            isum_tmp = svmla_s32_x(pg256_all, isum_tmp, svmmla_s32(svdup_n_s32(0), r1, l1), svscale1);
+                        }
+                        qh0 += 32; qh1 += 32;
+                        ql0 += 64; ql1 += 64;
+                        q80 += 128; q81 += 128;
+                        scale0 += 8; scale1 += 8;
+                    } // end of for
+                    svint32_t swap_isum_tmp = svext_s32(isum_tmp, isum_tmp, 4);
+                    isum_tmp = svadd_s32_x(pg32_4, isum_tmp, swap_isum_tmp);
+                    sum = svmla_f32_x(pg32_4, sum,
+                                      svcvt_f32_s32_x(pg32_4, svmla_s32_x(pg32_4, isum_tmp,
+                                                                          svisum_mins, svdup_n_s32(-32))),
+                                      svsuper_block_scales);
+                }
+            } // end of case 256-512
+            break;
+            default:
+                assert(false && "Unsupported vector length");
+                break;
+        } // end of switch
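+        // `sum` now holds the 2x2 result tile {y0·x0, y0·x1, y1·x0, y1·x1}: store the first
+        // pair to the first output row and the 8-byte-shifted pair to the row at stride bs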
+        svst1_f32(pg32_2, s, sum);
+        svst1_f32(pg32_2, s + bs, svreinterpret_f32_u8(svext_u8(svreinterpret_u8_f32(sum), svdup_n_u8(0), 8)));
+        return;
+    }
+#else
     if (nrc == 2) {
         const block_q6_K * GGML_RESTRICT x0 = x;
         const block_q6_K * GGML_RESTRICT x1 = (const block_q6_K *) ((const uint8_t *)vx + bx);
@@ -2594,27 +3011,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         // adjust bias, apply superblock scale
         {
             int32_t bias[4];
-#ifdef __ARM_FEATURE_SVE
-            const svbool_t pg16_8 = svptrue_pat_b16(SV_VL8);
-            const svbool_t pg8_8 = svptrue_pat_b8(SV_VL8);
-            const svint16_t y0_q8sums_0 = svld1_s16(pg16_8, y0->bsums);
-            const svint16_t y0_q8sums_1 = svld1_s16(pg16_8, y0->bsums + 8);
-            const svint16_t y1_q8sums_0 = svld1_s16(pg16_8, y1->bsums);
-            const svint16_t y1_q8sums_1 = svld1_s16(pg16_8, y1->bsums + 8);
-            const svint16_t x0_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x0->scales));
-            const svint16_t x0_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x0->scales + 8));
-            const svint16_t x1_q6scales_0 = svunpklo_s16(svld1_s8(pg8_8, x1->scales));
-            const svint16_t x1_q6scales_1 = svunpklo_s16(svld1_s8(pg8_8, x1->scales + 8));
-            const svint64_t zero = svdup_n_s64(0);
-            bias[0] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x0_q6scales_0),
-                            svdot_s64(zero, y0_q8sums_1, x0_q6scales_1)));
-            bias[1] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x0_q6scales_0),
-                            svdot_s64(zero, y1_q8sums_1, x0_q6scales_1)));
-            bias[2] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y0_q8sums_0, x1_q6scales_0),
-                            svdot_s64(zero, y0_q8sums_1, x1_q6scales_1)));
-            bias[3] = svaddv_s64(svptrue_b64(), svadd_s64_x(svptrue_b64(), svdot_s64(zero, y1_q8sums_0, x1_q6scales_0),
-                            svdot_s64(zero, y1_q8sums_1, x1_q6scales_1)));
-#else
             // NEON doesn't support int16 dot product, fallback to separated mul and add
             const int16x8x2_t q8sums0 = vld1q_s16_x2(y0->bsums);
             const int16x8x2_t q8sums1 = vld1q_s16_x2(y1->bsums);
@@ -2646,7 +3042,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
                            vmull_s16(vget_high_s16(q8sums1.val[1]), vget_high_s16(q6scales1.val[1]))));
             bias[3] = vaddvq_s32(prod);
-#endif
             const int32x4_t vibias = vmulq_n_s32(vld1q_s32(bias), 32);
 
             const float32x4_t superblock_scale = {
@@ -2670,9 +3065,9 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
         return;
     }
 #endif
+#endif
 
 #ifdef __ARM_FEATURE_SVE
-    const int vector_length = ggml_cpu_get_sve_cnt()*8;
     float sum = 0;
     svuint8_t m4b = svdup_n_u8(0xf);
     svint32_t vzero = svdup_n_s32(0);