Skip to content

Commit d02f2ec

Browse files
ggml-cpu: add vec_dot_f16_unroll
1 parent 49afba6 commit d02f2ec

File tree

1 file changed

+83
-7
lines changed

1 file changed

+83
-7
lines changed

ggml/src/ggml-cpu/vec.h

Lines changed: 83 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,89 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
224224
}
225225
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
226226
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
227-
#elif defined(__riscv_v_intrinsic)
228-
// todo: RVV impl
229-
for (int i = 0; i < n; ++i) {
230-
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
231-
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
232-
}
233-
}
227+
228+
#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh)
229+
size_t vl = __riscv_vsetvlmax_e32m2();
230+
231+
// initialize accumulators to all zeroes
232+
vfloat32m2_t vsum0_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
233+
vfloat32m2_t vsum0_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
234+
vfloat32m2_t vsum0_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
235+
vfloat32m2_t vsum0_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
236+
vfloat32m2_t vsum1_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
237+
vfloat32m2_t vsum1_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
238+
vfloat32m2_t vsum1_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
239+
vfloat32m2_t vsum1_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl);
240+
241+
// calculate step size
242+
const size_t epr = __riscv_vsetvlmax_e16m1();
243+
const size_t step = epr * 4;
244+
const int np = (n & ~(step - 1));
245+
246+
// unroll by 4
247+
for (int i = 0; i < np; i += step) {
248+
vfloat16m1_t ay0 = __riscv_vle16_v_f16m1((const _Float16 *)(y + i), epr);
249+
vfloat16m1_t ax0_0 = __riscv_vle16_v_f16m1((const _Float16 *)(x[0] + i), epr);
250+
vfloat16m1_t ax1_0 = __riscv_vle16_v_f16m1((const _Float16 *)(x[1] + i), epr);
251+
vsum0_0 = __riscv_vfwmacc_vv_f32m2(vsum0_0, ax0_0, ay0, epr);
252+
vsum1_0 = __riscv_vfwmacc_vv_f32m2(vsum1_0, ax1_0, ay0, epr);
253+
__asm__ __volatile__("" ::: "memory");
254+
255+
vfloat16m1_t ay1 = __riscv_vle16_v_f16m1((const _Float16 *)(y + i + epr), epr);
256+
vfloat16m1_t ax0_1 = __riscv_vle16_v_f16m1((const _Float16 *)(x[0] + i + epr), epr);
257+
vfloat16m1_t ax1_1 = __riscv_vle16_v_f16m1((const _Float16 *)(x[1] + i + epr), epr);
258+
vsum0_1 = __riscv_vfwmacc_vv_f32m2(vsum0_1, ax0_1, ay1, epr);
259+
vsum1_1 = __riscv_vfwmacc_vv_f32m2(vsum1_1, ax1_1, ay1, epr);
260+
__asm__ __volatile__("" ::: "memory");
261+
262+
vfloat16m1_t ay2 = __riscv_vle16_v_f16m1((const _Float16 *)(y + i + 2 * epr), epr);
263+
vfloat16m1_t ax0_2 = __riscv_vle16_v_f16m1((const _Float16 *)(x[0] + i + 2 * epr), epr);
264+
vfloat16m1_t ax1_2 = __riscv_vle16_v_f16m1((const _Float16 *)(x[1] + i + 2 * epr), epr);
265+
vsum0_2 = __riscv_vfwmacc_vv_f32m2(vsum0_2, ax0_2, ay2, epr);
266+
vsum1_2 = __riscv_vfwmacc_vv_f32m2(vsum1_2, ax1_2, ay2, epr);
267+
__asm__ __volatile__("" ::: "memory");
268+
269+
vfloat16m1_t ay3 = __riscv_vle16_v_f16m1((const _Float16 *)(y + i + 3 * epr), epr);
270+
vfloat16m1_t ax0_3 = __riscv_vle16_v_f16m1((const _Float16 *)(x[0] + i + 3 * epr), epr);
271+
vfloat16m1_t ax1_3 = __riscv_vle16_v_f16m1((const _Float16 *)(x[1] + i + 3 * epr), epr);
272+
vsum0_3 = __riscv_vfwmacc_vv_f32m2(vsum0_3, ax0_3, ay3, epr);
273+
vsum1_3 = __riscv_vfwmacc_vv_f32m2(vsum1_3, ax1_3, ay3, epr);
274+
__asm__ __volatile__("" ::: "memory");
275+
}
276+
277+
vfloat32m2_t vsum0_01 = __riscv_vfadd_vv_f32m2(vsum0_0, vsum0_1, vl);
278+
vfloat32m2_t vsum0_23 = __riscv_vfadd_vv_f32m2(vsum0_2, vsum0_3, vl);
279+
vfloat32m2_t vsum0 = __riscv_vfadd_vv_f32m2(vsum0_01, vsum0_23, vl);
280+
281+
vfloat32m2_t vsum1_01 = __riscv_vfadd_vv_f32m2(vsum1_0, vsum1_1, vl);
282+
vfloat32m2_t vsum1_23 = __riscv_vfadd_vv_f32m2(vsum1_2, vsum1_3, vl);
283+
vfloat32m2_t vsum1 = __riscv_vfadd_vv_f32m2(vsum1_01, vsum1_23, vl);
284+
285+
// leftovers
286+
for (int i = np; i < n; i += vl) {
287+
vl = __riscv_vsetvl_e16m1(n - i);
288+
vfloat16m1_t ay = __riscv_vle16_v_f16m1((const _Float16 *)(y + i), vl);
289+
vfloat16m1_t ax0 = __riscv_vle16_v_f16m1((const _Float16 *)(x[0] + i), vl);
290+
vfloat16m1_t ax1 = __riscv_vle16_v_f16m1((const _Float16 *)(x[1] + i), vl);
291+
292+
vsum0 = __riscv_vfwmacc_vv_f32m2(vsum0, ax0, ay, vl);
293+
vsum1 = __riscv_vfwmacc_vv_f32m2(vsum1, ax1, ay, vl);
294+
}
295+
296+
// reduce
297+
vl = __riscv_vsetvlmax_e32m1();
298+
vfloat32m1_t acc0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum0, 0),
299+
__riscv_vget_v_f32m2_f32m1(vsum0, 1), vl);
300+
vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1(
301+
acc0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
302+
303+
vfloat32m1_t acc1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum1, 0),
304+
__riscv_vget_v_f32m2_f32m1(vsum1, 1), vl);
305+
vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1(
306+
acc1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl);
307+
sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0);
308+
sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1);
309+
234310
#else
235311
const int np = (n & ~(GGML_F16_STEP - 1));
236312

0 commit comments

Comments
 (0)