From 6f8d9794eda71741320eada29df8136378ddd684 Mon Sep 17 00:00:00 2001 From: cailinxi Date: Thu, 7 Aug 2025 09:57:06 +0800 Subject: [PATCH 1/3] ggml: add spacemit backend Change-Id: I249bdc043485d815a9c351867137bc1e27cc2e23 --- cmake/riscv64-spacemit-linux-gnu-gcc.cmake | 31 + docs/build-riscv64-spacemit.md | 87 + ggml/src/ggml-cpu/CMakeLists.txt | 11 +- ggml/src/ggml-cpu/ggml-cpu.c | 40 + ggml/src/ggml-cpu/ggml-cpu.cpp | 10 + .../ggml-cpu/spacemit/ggml_spacemit_ime.cpp | 1056 ++++++ .../src/ggml-cpu/spacemit/ggml_spacemit_ime.h | 9 + .../spacemit/ggml_spacemit_ime_kernels.cpp | 3219 +++++++++++++++++ .../spacemit/ggml_spacemit_ime_kernels.h | 30 + ggml/src/ggml-cpu/vec.cpp | 248 ++ ggml/src/ggml-cpu/vec.h | 215 +- 11 files changed, 4952 insertions(+), 4 deletions(-) create mode 100644 cmake/riscv64-spacemit-linux-gnu-gcc.cmake create mode 100644 docs/build-riscv64-spacemit.md create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp create mode 100644 ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake new file mode 100644 index 0000000000000..e1df484c7c5b8 --- /dev/null +++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake @@ -0,0 +1,31 @@ +# Copyright (c) 2023 SpacemiT. All rights reserved. +set(CMAKE_SYSTEM_NAME Linux) +SET(CMAKE_SYSTEM_PROCESSOR riscv64) +set(CMAKE_SYSTEM_VERSION 1) + +if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)") +message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}") +else() +set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple") +if(DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) +else() + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") +endif() + +set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") +set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) +set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) +set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip) +set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu") +set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot") +endif() + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}") +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic") +add_definitions(-D__fp16=_Float16) diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md new file mode 100644 index 0000000000000..87cd58d5053f6 --- /dev/null +++ b/docs/build-riscv64-spacemit.md @@ -0,0 +1,87 @@ +> [!IMPORTANT] +> This build documentation is specific only to RISC-V SpacemiT SOCs. + +## Build llama.cpp locally (for riscv64) + +1. Prepare Toolchain For RISCV +~~~ +wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz +~~~ + +2. Build +Below is the build script: it requires utilizing RISC-V vector instructions for acceleration. Ensure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`. +```bash + +cmake -B build-riscv64-spacemit \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DLLAMA_CURL=OFF \ + -DGGML_RV_ZFH=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \ + -DCMAKE_INSTALL_PREFIX=build-riscv64-spacemit/installed + +cmake --build build-riscv64-spacemit --parallel $(nproc) --config Release + +pushd build-riscv64-spacemit +make install +popd +``` + +## Simulation +You can use QEMU to perform emulation on non-RISC-V architectures. + +1. Download QEMU +~~~ +wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz +~~~ + +2. Run Simulation +After build your llama.cpp, you can run the executable file via QEMU for simulation, for example: +~~~ +export QEMU_ROOT_PATH={your QEMU file path} +export RISCV_ROOT_PATH_IME1={your RISC-V compiler path} + +${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build-riscv64-spacemit/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1 +~~~ +## Performance +#### Quantization Support For Matrix +~~~ +model name : Spacemit(R) X60 +isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt +mmu : sv39 +uarch : spacemit,x60 +mvendorid : 0x710 +marchid : 0x8000000058000001 +~~~ + +Q4_0 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02| + +Q4_1 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128|2.25 ± 0.04| + + +Q4_K +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| \ No newline at end of file diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index f188d1638dc5d..bc43a3a6003d0 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -433,7 +433,16 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp ) - if (GGML_RVV) + if (GGML_CPU_RISCV64_SPACEMIT) + list(APPEND ARCH_FLAGS -march=rv64gcv_zfh_zba_zicbop -mabi=lp64d -DGGML_RV_ZFH -D${RISCV64_SPACEMIT_IME_SPEC}) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT) + list(APPEND GGML_CPU_SOURCES + ggml-cpu/spacemit/ggml_spacemit_ime.cpp + ggml-cpu/spacemit/ggml_spacemit_ime.h + ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp + ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h + ) + elseif (GGML_RVV) if (GGML_XTHEADVECTOR) list(APPEND ARCH_FLAGS -march=rv64gc_xtheadvector -mabi=lp64d) elseif (GGML_RV_ZFH) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index d89cd8f4ef652..31802d125945d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -3209,6 +3209,26 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) { uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); vec_xst(v_y, 0, (ggml_fp16_t *)(y + i)); } +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zfh) + int64_t n_loop = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "slli t2, t0, 2 \n\t" + "vle32.v v0, (%[IN]) \n\t" + "add %[IN], %[IN], t2 \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "vfncvt.f.f.w v4, v0 \n\t" + "vse16.v v4, (%[DST]) \n\t" + "add %[DST], %[DST], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + + : [ IN ] "+r"(x), [ DST ] "+r"(y), [ n ] "+r"(n_loop) + : + : "cc", "t0", "t1", "t2"); + i += n; #endif for (; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(x[i]); @@ -3250,6 +3270,26 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) { float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0); vec_xst(v_yh, 0, (float *)(y + i)); } +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zfh) + int64_t n_loop = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 2 \n\t" + "slli t2, t0, 1 \n\t" + "vle16.v v0, (%[IN]) \n\t" + "add %[IN], %[IN], t2 \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vse32.v v4, (%[DST]) \n\t" + "add %[DST], %[DST], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + + : [ IN ] "+r"(x), [ DST ] "+r"(y), [ n ] "+r"(n_loop) + : + : "cc", "t0", "t1", "t2"); + i += n; #endif for (; i < n; ++i) { diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 8dacd36714b4c..17da22cfe1533 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -18,6 +18,10 @@ # include "kleidiai/kleidiai.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ggml_spacemit_ime.h" +#endif + #if defined(_WIN32) # define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX @@ -45,6 +49,12 @@ std::vector & ggml_backend_cpu_get_extra_buffer_type } #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp new file mode 100644 index 0000000000000..24f6d328be3e3 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp @@ -0,0 +1,1056 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "ggml-common.h" +#include "ggml-backend-impl.h" + +#include "ggml-quants.h" +#include "ggml-impl.h" +#include "ggml-cpu.h" +#include "ggml-cpu-impl.h" +#include "traits.h" + +#include +#include +#include +#include +#include // for qsort +#include // for GGML_ASSERT +#include + +#include "ggml_spacemit_ime.h" +#include "ggml_spacemit_ime_kernels.h" +#include "vec.h" + + +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#endif + +#if !defined(__riscv_zfh) || !defined(GGML_RV_ZFH) +#error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +#error "RISCV64_SPACEMIT_IME1 not defined" +#endif + +#else + +#error "riscv not enabled in this build" + +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + + +#if defined(RISCV64_SPACEMIT_IME1) +#define QGEMM_STRIDEN_THREAD_ALIGN 16 +#else +#define QGEMM_STRIDEN_THREAD_ALIGN 32 +#endif + +typedef enum { + ScaleFp32 = 0, + ScaleFp16, +} QNBIT_GEMM_SCALE_TYPE; + +template +struct QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C + + QNBIT_GEMM_SCALE_TYPE ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp32; ///< datatype of B scale(FP32 or FP16). +}; + +constexpr +size_t +DivRoundup(size_t up, size_t down) +{ + return (up + down - 1) / down; +} +constexpr size_t +Q8BlkSize(size_t BlkLen) +{ + const size_t BlkSize = sizeof(float) + BlkLen * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(BlkSize % alignof(float) == 0); + return BlkSize; +} + +namespace ggml::cpu::riscv64_spacemit { + +const int num_ai_cores = std::thread::hardware_concurrency() / 2; + +} // namespace ggml::cpu::riscv64_spacemit + +static void SQ4BitGemm_CompInt8( + const size_t BlkLen, const size_t K, + const QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, const size_t RangeStartM, + const size_t RangeCountM, const size_t RangeStartN, + const size_t RangeCountN) { + const size_t scale_stride = + DataParams->ScaleType == QNBIT_GEMM_SCALE_TYPE::ScaleFp16 + ? sizeof(uint16_t) + : sizeof(float); + + constexpr size_t BlkBitWidth = 4; + + const size_t k_blks = DivRoundup(K, BlkLen); + + const size_t lda = k_blks * Q8BlkSize(BlkLen); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * (BlkLen * BlkBitWidth / 8); + const std::byte* QuantA = + static_cast(PerGemmWorkspace) + RangeStartM * lda; + + const size_t zero_point_stride = DataParams->QuantBZeroPoint != nullptr ? sizeof(uint8_t) : 0; + const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); + const std::byte* QuantBData = + static_cast(DataParams->PackedQuantBData) + + RangeStartN * packed_b_stride; + + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + size_t CountN; + const size_t ComputeBlockCountN = RangeCountM == 1 ? RangeCountN : 16; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, ComputeBlockCountN); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * packed_b_stride; + const std::byte* b_col_zp = (zero_point_stride != 0) + ? b_col + : nullptr; + float* c_blk = C + n; + + size_t RowsRemaining = RangeCountM; + + while (RowsRemaining > 0) { + const auto RowsHandled = sqnbitgemm_spacemit_ime::SQ4BitGemmKernel_CompInt8( + BlkLen, a_row, b_col, nullptr, b_col_zp, c_blk, + RowsRemaining, CountN, K, k_blks, ldc, nullptr, scale_stride); + + c_blk += RowsHandled * ldc; + a_row += RowsHandled * lda; + + RowsRemaining -= RowsHandled; + } + } +} + + + +template +constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template +struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template +struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), "wrong block_with_zp<4,16> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0* in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1* in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16* dst = (block_q4_0x16*)t->data; + const block_q4_0* src = (const block_q4_0*)data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16* dst = (block_q4_1x16*)t->data; + const block_q4_1* src = (const block_q4_1*)data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, const uint8_t* GGML_RESTRICT q, uint8_t* GGML_RESTRICT d, uint8_t* GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor* t, int interleave_block, const void* GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16* dst = (block_q4_1x16*)t->data; + const block_q4_K* src = (const block_q4_K*)data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t* q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(struct ggml_tensor*, const void*, size_t); + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> +int repack(struct ggml_tensor* t, const void* data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor* t, const void* data, size_t data_size) = 0; +}; + +template +class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; + size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_Q4_0 || // + op->src[0]->type == GGML_TYPE_Q4_1 || // + op->src[0]->type == GGML_TYPE_Q4_K) { + forward_mul_mat_q4(params, op); + return true; + } + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat_q4(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + const ggml_tensor* src1 = op->src[1]; + ggml_tensor* dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + int ith = params->ith; + int nth = params->nth; + + [[maybe_unused]] const enum ggml_type type = src0->type; + + void* w_data = (void*)src0->data; + const float* feature = (const float*)src1->data; + float* output = (float*)dst->data; + + const auto BatchN = ne12 * ne13; + [[maybe_unused]] const auto BatchWeight = ne02 * ne03; + const auto M = ne11; + const auto K = ne10; + const auto N = ne01; + + assert(BatchWeight == 1); + // constexpr size_t BlkBitWidth = 4; + constexpr size_t BlkLen = QK4_0; + + size_t BlockCountK = DivRoundup(K, BlkLen); + const size_t Size = M * BlockCountK * Q8BlkSize(BlkLen); + auto Alignment = alignof(double); + size_t PerGemmWorkspaceStride = DivRoundup(Size, Alignment) * Alignment; + const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; + const size_t desired_wsize = WorkspaceSize + Alignment - 1; + + if (params->wsize < desired_wsize && ith == 0) { + throw std::runtime_error( + "wsize less than MlasSQNBitGemmBatchWorkspaceSize"); + } + + std::vector> DataParams(BatchN); + + for (int i = 0; i < BatchN; i++) { + DataParams[i].A = feature + M * K * i; + DataParams[i].lda = K; + DataParams[i].QuantBDataWorkspace = w_data; + DataParams[i].PackedQuantBData = (const std::byte*)w_data; + DataParams[i].QuantBScale = nullptr; + + if constexpr (std::is_same_v) { + DataParams[i].QuantBZeroPoint = nullptr; + } else { + DataParams[i].QuantBZeroPoint = (const uint8_t*)w_data; + } + + DataParams[i].Bias = nullptr; + DataParams[i].C = output + M * N * i; + DataParams[i].ldc = N; + DataParams[i].ScaleType = QNBIT_GEMM_SCALE_TYPE::ScaleFp16; + } + Alignment = alignof(double);; + const uintptr_t WorkspaceAddress = reinterpret_cast(params->wdata); + void* Workspace = reinterpret_cast( + (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + + BlockCountK = DivRoundup(K, BlkLen); + const auto PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + PerGemmWorkspaceStride = + DivRoundup(PerGemmWorkspaceSize, Alignment) * Alignment; + + const auto QuantizeARow = sqnbitgemm_spacemit_ime::QuantizeARow_CompInt8; + const auto QuantizeAM4Row = sqnbitgemm_spacemit_ime::QuantizeAM4Row_CompInt8; + + BlockCountK = DivRoundup(K, BlkLen); + const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); + + { + const size_t BlockSizeM = 4; + size_t BlockCountM = DivRoundup(M, BlockSizeM); + int task_count = BatchN * BlockCountM; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + auto gemm_idx = compute_idx / BlockCountM; + auto m_idx = compute_idx % BlockCountM * BlockSizeM; + const auto& data = DataParams[gemm_idx]; + auto RowsTobeHandled = (M - m_idx) > 4 ? 4 : (M - m_idx); + if (RowsTobeHandled == 4) { + const float* ARowPtr = data.A + m_idx * data.lda; + std::byte* QuantARowPtr = static_cast(Workspace) + + gemm_idx * PerGemmWorkspaceStride + + m_idx * QuantAStride; + QuantizeAM4Row(BlkLen, ARowPtr, K, QuantARowPtr); + + } else { + while (RowsTobeHandled) { + const float* ARowPtr = data.A + m_idx * data.lda; + std::byte* QuantARowPtr = static_cast(Workspace) + + gemm_idx * PerGemmWorkspaceStride + + m_idx * QuantAStride; + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + RowsTobeHandled -= 1; + m_idx += 1; + } + } + } + } + + ggml_barrier(params->threadpool); + + if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) + return; + + nth = std::min(nth, int{ggml::cpu::riscv64_spacemit::num_ai_cores}); + + int ThreadsPerGemm = nth / BatchN; + constexpr size_t StrideM = 128; + + size_t nc = N; + const size_t BlockedM = DivRoundup(M, StrideM); + const size_t max_nc = DivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min(nc, DivRoundup(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * + QGEMM_STRIDEN_THREAD_ALIGN); + } + + const size_t StrideN = nc; + const size_t ThreadCountM = DivRoundup(M, StrideM); + const size_t ThreadCountN = DivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + { + int task_count = BatchN * ThreadsPerGemm; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + const auto gemm_i = compute_idx / ThreadsPerGemm; + const auto blk_i = compute_idx % ThreadsPerGemm; + const auto* Data = &DataParams[gemm_i]; + + const auto ThreadIdN = blk_i / ThreadCountM; + const auto ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + void* PerGemmWorkspace = reinterpret_cast(Workspace) + + gemm_i * PerGemmWorkspaceStride; + + SQ4BitGemm_CompInt8(BlkLen, K, Data, PerGemmWorkspace, RangeStartM, + RangeCountM, RangeStartN, RangeCountN); + } + } + } + + int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int)NB_COLS, (int)INTER_SIZE); + return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); + } +}; + +class tensor_traits_common : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor* op, size_t& size) override { + switch (op->op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + size = 0; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params* params, struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_NORM: + forward_norm_f32(params, op); + return true; + case GGML_OP_RMS_NORM: + forward_rms_norm_f32(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_norm_f32(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + ggml_tensor* dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto* input = (float*)src0->data; + auto* output = (float*)dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto* p_input = const_cast(input + offset); + + auto* p_output = output + offset; + auto* p_temp_output = p_output; + auto* p_gamma_data = (const float*) nullptr; + auto* p_beta_data = (const float*) nullptr; + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + void forward_rms_norm_f32(ggml_compute_params* params, ggml_tensor* op) { + const ggml_tensor* src0 = op->src[0]; + ggml_tensor* dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto* input = (float*)src0->data; + auto* output = (float*)dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto* p_input = const_cast(input + offset); + auto* p_output = output + offset; + auto* p_temp_output = p_output; + auto* p_gamma_data = (const float*)nullptr; + auto* p_beta_data = (const float*)nullptr; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + // float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + int repack(struct ggml_tensor* t, const void* data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; + +static const tensor_traits q4_0_16x8_q8_0; +static const tensor_traits q4_1_16x8_q8_0; +static const tensor_traits q4_k_16x8_q8_0; +static const tensor_traits_common rvv_impl; + +} // namespace ggml::cpu::riscv64_spacemit + +static const ggml::cpu::tensor_traits* ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor* cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_1) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_F32) { + return &ggml::cpu::riscv64_spacemit::rvv_impl; + } + + return nullptr; +} + +static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor) { + tensor->extra = (void*)const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor* tensor, + const void* data, size_t offset, size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base*)tensor->extra; + if (tensor_traits) { + auto OK = tensor_traits->repack(tensor, data, size); + GGML_ASSERT(OK == 0); + } + + GGML_UNUSED(buffer); +} + +static const char* ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_RISCV64_SPACEMIT"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 64; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const struct ggml_tensor* tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } else { + nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + if (tensor->type == GGML_TYPE_Q4_K) { + GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); + nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + } + } else { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } + } + + GGML_UNUSED(buft); + return nbytes; +} + +namespace ggml::cpu::riscv64_spacemit { + +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && + (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + if (op->src[0]->type == GGML_TYPE_F32) { + return true; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + ggml::cpu::tensor_traits* get_tensor_traits(const struct ggml_tensor* op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + return (ggml::cpu::tensor_traits*)op->src[0]->extra; + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return (ggml::cpu::tensor_traits*)(&ggml::cpu::riscv64_spacemit::rvv_impl); + default: + // GGML_ABORT("fatal error"); + break; + } + + return nullptr; + } +}; + +} // namespace ggml::cpu::riscv64_spacemit + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + /* .iface = */ { + /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, // defaults to SIZE_MAX + /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, + /* .is_host = */ nullptr, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ new ggml::cpu::riscv64_spacemit::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_riscv64_spacemit; +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h new file mode 100644 index 0000000000000..8020508ac2eef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h @@ -0,0 +1,9 @@ +#pragma once + +#include "traits.h" +#include "ggml.h" +#include + +// #include +// GGML internal header +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp new file mode 100644 index 0000000000000..947490c8d9ab0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp @@ -0,0 +1,3219 @@ +#include "ggml_spacemit_ime_kernels.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// HPMADOT_OFF = 0 is not supported when ref_C used +#define HPMADOT_OFF 0 +namespace sqnbitgemm_spacemit_ime +{ +template +struct Q4_0X16 { + T scale[16]; + // uint8_t zp[8]; + // blklen = 16 + // uint8_t[16][8] + // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 + + // blklen = 32 + // uint8_t[16][2][8] + // b0 b8, b1 b9, b2 b10, b3 b11, b4 b12, b5 b13, b6 b14, b7 b15 + // b16 b24, b17 b25, b18 b26, b19 b27, b20 b28, b21 b29, b22 b30, b23 b31 +}; + +#define QUANTIZEM4ROW_KERNEL \ + "vmv.s.x v16, zero \n\t" \ + "vfabs.v v8, v0 \n\t" \ + "vfredmax.vs v16, v8, v16 \n\t" \ + "vfmv.f.s f10, v16 \n\t" \ + "fmul.s f10, f10, %[RMAXREC] \n\t" \ + "fsw f10, (a1) \n\t" \ + "fdiv.s f11, %[FONE], f10 \n\t" \ + "vfmul.vf v16, v0, f11 \n\t" \ + "vfcvt.x.f.v v16, v16 \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vnclip.wx v16, v16, zero \n\t" \ + "vnclip.wx v17, v17, zero \n\t" \ + "vnclip.wx v18, v18, zero \n\t" \ + "vnclip.wx v19, v19, zero \n\t" \ + "vnclip.wx v20, v20, zero \n\t" \ + "vnclip.wx v21, v21, zero \n\t" \ + "vnclip.wx v22, v22, zero \n\t" \ + "vnclip.wx v23, v23, zero \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vnclip.wx v24, v16, zero \n\t" \ + "vnclip.wx v25, v17, zero \n\t" \ + "vnclip.wx v26, v18, zero \n\t" \ + "vnclip.wx v27, v19, zero \n\t" \ + "vnclip.wx v28, v20, zero \n\t" \ + "vnclip.wx v29, v21, zero \n\t" \ + "vnclip.wx v30, v22, zero \n\t" \ + "vnclip.wx v31, v23, zero \n\t" + +#define QUANTIZEM4ROW_STORE \ + "addi t1, %[BlkLen], 0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v24, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v25, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v26, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v27, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v28, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v29, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v30, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v31, (s1) \n\t" + +void +QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) +{ + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index * CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 128) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index * CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "addi t2, t2, -128 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v16, v24, v16 \n\t" + "vfredmax.vs v24, v16, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e64, m4 \n\t" + "vsse64.v v16, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "sub t2, t2, t2 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "jal x0, QUANTIZE%= \n\t" + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 256) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float* SRC = A + row_index * CountK; + std::byte* DST = QuantA + row_index * sizeof(float); + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "addi t2, t2, -256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v24, v16 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t1, t2, 0 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + + "TAIL_LOOP%=: \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t2, e32, m1 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 32 \n\t" + "vfmul.vf v1, v0, f11 \n\t" + "vfcvt.x.f.v v2, v1 \n\t" + "vsetvli t0, zero, e16, mf2 \n\t" + "vnclip.wx v3, v2, zero \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vnclip.wx v3, v3, zero \n\t" + "vse8.v v3, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bnez t2, TAIL_LOOP%= \n\t" + + "QUIT%=: \n\t" + : [ SRC ] "+r"(SRC) + : [ DST ] "r"(DST), [ BlkLen ] "r"(BlkLen), [ OFFSET ] "r"(offset), [ STRIDE ] "r"(stride), + [ CountK ] "r"(CountK), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } +} + +void +QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA) +{ + const float* SRC = A; + std::byte* DST = QuantA; + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + std::byte* QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; + + if (CountK <= BlkLen) { + float max_abs_A = 0.0f; + for (size_t k = 0; k < CountK; k++) { + max_abs_A = std::max(max_abs_A, fabsf(A[k])); + } + float scale_A = max_abs_A * range_max_reciprocal; + + ((float*)QuantA)[0] = scale_A; + + auto* QuantAData_offset = (int8_t*)(QuantA + sizeof(float)); + + for (size_t k = 0; k < CountK; k++) { + QuantAData_offset[k] = + (int8_t)std::clamp(roundf(A[k] / scale_A), (float)std::numeric_limits::lowest(), + (float)std::numeric_limits::max()); + } + for (size_t k = CountK; k < BlkLen; k++) { + QuantAData_offset[k] = 0; + } + + return; + } + + if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { + __asm__ volatile( + "vsetvli t0, zero, e8, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[CNT], e8, m8 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "sub %[CNT], %[CNT], t0 \n\t" + "bnez %[CNT], LOOP%= \n\t" + : [ DST ] "+r"(QuantA_offset), [ CNT ] "+r"(offset) + : + : "cc", "t0"); + } + if (BlkLen == 16) { + float buffer[64] = {0.0f}; + __asm__ volatile( + "addi t3, zero, 16*8 \n\t" + "addi t2, zero, 16 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m2 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v2, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v4, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v6, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v10, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v12, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v14, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v18, v2 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v22, v6 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v26, v10 \n\t" + "vfabs.v v28, v12 \n\t" + "vfabs.v v30, v14 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v18, v18, v19 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v22, v22, v23 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v26, v26, v27 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vse32.v v16, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v18, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v20, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v22, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v24, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v26, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v28, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v30, (a1) \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f11, f3, f7 \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f11, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f12, f3, f7 \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fsw f12, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f13, f3, f7 \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f13, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f14, f3, f7 \n\t" + "fmul.s f14, f14, %[RMAXREC] \n\t" + "fsw f14, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f14, %[FONE], f14 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f15, f3, f7 \n\t" + "fmul.s f15, f15, %[RMAXREC] \n\t" + "fsw f15, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f15, %[FONE], f15 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f16, f3, f7 \n\t" + "fmul.s f16, f16, %[RMAXREC] \n\t" + "fsw f16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f16, %[FONE], f16 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f17, f3, f7 \n\t" + "fmul.s f17, f17, %[RMAXREC] \n\t" + "fsw f17, (%[DST]) \n\t" + "addi %[DST], %[DST], -136 \n\t" + "fdiv.s f17, %[FONE], f17 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v18, v2, f11 \n\t" + "vfmul.vf v20, v4, f12 \n\t" + "vfmul.vf v22, v6, f13 \n\t" + "vfmul.vf v24, v8, f14 \n\t" + "vfmul.vf v26, v10, f15 \n\t" + "vfmul.vf v28, v12, f16 \n\t" + "vfmul.vf v30, v14, f17 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v18, v18 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v22, v22 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v26, v26 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vfcvt.x.f.v v30, v30 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v18, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v20, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v22, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v26, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v28, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v30, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vse32.v v16, (%[BUFFER]) \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m2 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BUFFER ] "r"(buffer) + : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", + "f13", "f14", "f15", "f16", "f17"); + } else if (BlkLen == 32) { + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC), [ DST ] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); + } else if (BlkLen == 64) { + __asm__ volatile( + "addi t3, zero, 64*2 \n\t" + "addi t2, zero, 64 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 68 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vfmax.vv v24, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v25 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 132 \n\t" + "vse8.v v24, (s2) \n\t" + "addi s2, s2, 132 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 256 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [ K ] "+r"(CountK) + : [ SRC ] "r"(SRC), [ DST ] "r"(DST), [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); + } else if (BlkLen == 128) { + __asm__ volatile( + "addi t2, zero, 128 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "sub %[K], %[K], t2 \n\t" + "QUANT%=: \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "vfmax.vv v28, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v30, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vfredmax.vs v31, v30, v31 \n\t" + "vfmv.f.s f10, v31 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v8, (a2) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "jal x0, QUANT%= \n\t" + "END%=: \n\t" + + : [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ SRC ] "r"(SRC) + : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); + } else { + float buffer[8] = {0.0f}; + size_t cnt = BlkLen / 256; + + __asm__ volatile( + "slli t3, %[BLK], 2 \n\t" + "blt %[K], %[BLK], LOOP_TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_CMP%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v16, v16, v24 \n\t" + "vfmax.vv v0, v0, v16 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, LOOP_CMP%= \n\t" + "sub %[SRC], %[SRC], t3 \n\t" + "addi t6, %[CNT], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_QUANT%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "vse8.v v4, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bnez t6, LOOP_QUANT%= \n\t" + "sub %[K], %[K], %[BLK] \n\t" + "bge %[K], %[BLK], LOOP_MAIN%= \n\t" + "blez %[K], END%= \n\t" + "LOOP_TAIL%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[K], 0 \n\t" + "addi s1, %[SRC], 0 \n\t" + "TAIL_CMP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t6, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "sub t6, t6, t0 \n\t" + "vfabs.v v0, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, TAIL_CMP%= \n\t" + "addi t6, %[K], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[K], 0 \n\t" + "TAIL_QUANT%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t1, t6, e32, m8 \n\t" + "vle32.v v0, (s1) \n\t" + "addi s1, s1, 256 \n\t" + "sub t6, t6, t1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 64 \n\t" + "bnez t6, TAIL_QUANT%= \n\t" + "END%=: \n\t" + : [ SRC ] "+r"(SRC), [ DST ] "+r"(DST), [ K ] "+r"(CountK) + : [ FONE ] "f"(fone), [ RMAXREC ] "f"(range_max_reciprocal), [ BLK ] "r"(BlkLen), [ BUFFER ] "r"(buffer), + [ CNT ] "r"(cnt) + : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); + } +} + +namespace +{ +#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \ + "vmadot v16, v14, v0 \n\t" \ + "vmadot v18, v14, v1 \n\t" \ + "vmadot v20, v14, v2 \n\t" \ + "vmadot v22, v14, v3 \n\t" \ + "vmadot v16, v15, v4 \n\t" \ + "vmadot v18, v15, v5 \n\t" \ + "vmadot v20, v15, v6 \n\t" \ + "vmadot v22, v15, v7 \n\t" + +#define SQ4BIT_KERNEL_ACC_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 16 \n\t" \ + "addi s3, s1, 32 \n\t" \ + "addi s4, s1, 48 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 8 \n\t" \ + "addi s3, s1, 16 \n\t" \ + "addi s4, s1, 24 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \ + "vle8.v v4, (s1) \n\t" \ + "addi s1, s1, 128 \n\t" \ + "vle8.v v5, (s2) \n\t" \ + "addi s2, s2, 128 \n\t" \ + "vle8.v v6, (s3) \n\t" \ + "addi s3, s3, 128 \n\t" \ + "vle8.v v7, (s4) \n\t" \ + "addi s4, s4, 128 \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vle8.v v14, (s5) \n\t" \ + "addi s5, s5, 16 \n\t" \ + "vle8.v v15, (s6) \n\t" \ + "addi s6, s6, 16 \n\t" \ + "addi t5, t5, -1 \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vand.vi v0, v4, 15 \n\t" \ + "vand.vi v1, v5, 15 \n\t" \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vsrl.vi v4, v4, 4 \n\t" \ + "vsrl.vi v5, v5, 4 \n\t" \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v1, (s7) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v8, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v9, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v10, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v11, v1, v13 \n\t" \ + "vadd.vi v13, v13, -12 \n\t" + +// using for M4Kernel +#define LOAD_B_16x8x2 \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vle8.v v6, (s1) \n\t" \ + "addi s1, s1, 32*4 \n\t" \ + "vle8.v v7, (s2) \n\t" \ + "addi s2, s2, 32*4 \n\t" \ + "vle8.v v8, (s3) \n\t" \ + "addi s3, s3, 32*4 \n\t" \ + "vle8.v v9, (s4) \n\t" \ + "addi s4, s4, 32*4 \n\t" \ + \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vand.vi v4, v8, 15 \n\t" \ + "vand.vi v5, v9, 15 \n\t" \ + \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" \ + "vsrl.vi v8, v8, 4 \n\t" \ + "vsrl.vi v9, v9, 4 \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16_FP16 \ + "addi s2, s5, -8 \n\t" \ + "addi s3, s5, 8 \n\t" \ + "addi s4, s5, 16 \n\t" \ + "addi s6, s5, 24 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e16, mf4 \n\t" \ + "vle16.v v9, (s5) \n\t" \ + "vle16.v v11, (s3) \n\t" \ + "vle16.v v13, (s4) \n\t" \ + "vle16.v v15, (s6) \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vle16.v v9, (s2), v0.t \n\t" \ + "vle16.v v11, (s5), v0.t \n\t" \ + "vle16.v v13, (s3), v0.t \n\t" \ + "vle16.v v15, (s4), v0.t \n\t" \ + "vfwcvt.f.f.v v8, v9 \n\t" \ + "vfwcvt.f.f.v v10, v11 \n\t" \ + "vfwcvt.f.f.v v12, v13 \n\t" \ + "vfwcvt.f.f.v v14, v15 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16 \ + "addi s2, s5, -16 \n\t" \ + "addi s3, s5, 16 \n\t" \ + "addi s4, s5, 32 \n\t" \ + "addi s6, s5, 48 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vle32.v v8, (s5) \n\t" \ + "vle32.v v10, (s3) \n\t" \ + "vle32.v v12, (s4) \n\t" \ + "vle32.v v14, (s6) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v8, (s2), v0.t \n\t" \ + "vle32.v v10, (s5), v0.t \n\t" \ + "vle32.v v12, (s3), v0.t \n\t" \ + "vle32.v v14, (s4), v0.t \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +//[s1| BIAS, s2, s3, s4] +#define LOAD_BIAS \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "addi s1, %[BIAS], -16 \n\t" \ + "addi s2, %[BIAS], 16 \n\t" \ + "addi s3, %[BIAS], 32 \n\t" \ + "addi s4, %[BIAS], 48 \n\t" \ + \ + "vle32.v v24, (%[BIAS]) \n\t" \ + "vle32.v v26, (s2) \n\t" \ + "vle32.v v28, (s3) \n\t" \ + "vle32.v v30, (s4) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v24, (s1), v0.t \n\t" \ + "vle32.v v26, (%[BIAS]), v0.t \n\t" \ + "vle32.v v28, (s2), v0.t \n\t" \ + "vle32.v v30, (s3), v0.t \n\t" \ + "vmv.v.v v25, v24 \n\t" \ + "vmv.v.v v27, v26 \n\t" \ + "vmv.v.v v29, v28 \n\t" \ + "vmv.v.v v31, v30 \n\t" + +#define SQ4BIT_KERNEL_COMP_4x16x16 \ + "vmadot v16, v10, v2 \n\t" \ + "vmadot v18, v10, v3 \n\t" \ + "vmadot v20, v10, v4 \n\t" \ + "vmadot v22, v10, v5 \n\t" \ + "vmadot v16, v11, v6 \n\t" \ + "vmadot v18, v11, v7 \n\t" \ + "vmadot v20, v11, v8 \n\t" \ + "vmadot v22, v11, v9 \n\t" + +#define SAVE_RESULT_4x16 \ + "addi a1, %[C], 0 \n\t" \ + "add a2, %[C], %[LDC] \n\t" \ + "add a3, a2, %[LDC] \n\t" \ + "add a4, a3, %[LDC] \n\t" \ + "addi a2, a2, -16 \n\t" \ + "addi a4, a4, -16 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + \ + "vse32.v v24, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v25, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v26, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v27, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v28, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v29, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v30, (a1) \n\t" \ + "vse32.v v31, (a3) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + \ + "vse32.v v24, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v25, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v26, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v27, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v28, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v29, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v30, (a2), v0.t \n\t" \ + "vse32.v v31, (a4), v0.t \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v11, (s6) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v12, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v13, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v14, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v15, v11, v1 \n\t" \ + "vadd.vi v1, v1, -12 \n\t" + +template +void +SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias, + const size_t ldc) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float* CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} +template +void +SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias, + const size_t ldc) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float* bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [ SRC ] "r"(bias), [ DST ] "r"(tmp), [ N ] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr), [ BIAS ] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ LDC ] "r"(LDC), + [ BlockCountK ] "r"(BlockCountK), [ C ] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float* CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [ N ] "r"(N), [ SRC ] "r"(tmp), [ DST ] "r"(CPtr), [ LDC ] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} +template +void +SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t INNER = BlkLen / 16; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + // zp offset + "addi s7, %[B], 32 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(__fp16); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +void +SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t BlockCountK, + const float* Bias) +{ + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + const size_t INNER = BlkLen / 16; + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + // zp offset + "addi s7, %[B], 64 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + + // load scale + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load a scale + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + // a scale * b scale + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s7, %[B], 64 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte* QuantBDataPtr = (std::byte*)QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float* CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float* bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks), [ BIAS ] "+r"(bias) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [ CNT ] "+r"(cnt), [ NBLKS ] "+r"(nblks) + : [ INNER ] "r"(INNER), [ A ] "r"(QuantA), [ B ] "r"(QuantBDataPtr), [ C ] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +inline void +SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias, + const size_t ldc, + const size_t scalestride) +{ + if (scalestride == 4) { + SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias, ldc); + + } else if (scalestride == 2) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( + BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); + } +} + +template +inline void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float* Bias, + const size_t ldc, + const size_t scalestride) +{ + if (scalestride == 4) { + SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias); + } else if (scalestride == 2) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + } +} + +} // namespace + +size_t +SQ4BitGemmKernel_CompInt8(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias, + const size_t ScaleStride) +{ + GGML_UNUSED(CountM); + GGML_UNUSED(CountK); + GGML_UNUSED(ldc); + if (CountM >= 4) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 4; + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 1; + } +} +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h new file mode 100644 index 0000000000000..e112e7fd7d050 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include "traits.h" +#include "ggml.h" + +namespace sqnbitgemm_spacemit_ime +{ +size_t +SQ4BitGemmKernel_CompInt8(size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias, + const size_t ScaleStride); +void +QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); + +void +QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); + +} \ No newline at end of file diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 07b377bdd82a7..5130b5dfcb6ff 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -109,6 +109,28 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G sumf += x[i]*y[i]; } #endif +#elif defined(__riscv) && defined(__riscv_v) + float sumf = 0.0f; + __asm__ volatile( + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v16, v16 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "slli t1, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle32.v v8, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "vfmacc.vv v16, v0, v8 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vfredusum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[res], v24 \n\t" + : [ n ] "+r"(n), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ res ] "=f"(sumf) + : + : "cc", "t0", "t1"); #else // scalar ggml_float sumf = 0.0; @@ -224,6 +246,32 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G // if you hit this, you are likely running outside the FP range assert(!isnan(sumf) && !isinf(sumf)); +#elif defined(__riscv) && defined(__riscv_v) + float result = 0.0f; + __asm__ volatile( + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v16, v16 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "vle16.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle16.v v2, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vfwcvt.f.f.v v8, v2 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vfmacc.vv v16, v4, v8 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vfredusum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[res], v24 \n\t" + : [ n ] "+r"(n), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ res ] "=f"(result) + : + : "cc", "t0", "t1"); + sumf += result; #else for (int i = 0; i < n; ++i) { sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); @@ -251,6 +299,93 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) { for (; i + 3 < n; i += 4) { vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); } +#elif defined(__riscv) && defined(__riscv_v) + int N = n; + i += n; + constexpr struct { + float LowerRange; + float UpperRange; + float alpha_9; + float alpha_7; + float alpha_5; + float alpha_3; + float alpha_1; + float beta_10; + float beta_8; + float beta_6; + float beta_4; + float beta_2; + float beta_0; + float one_half; + } LogisticConstants = { + -18.0f, + 18.0f, + 4.37031012579801e-11f, + 1.15627324459942e-07f, + 6.08574864600143e-05f, + 8.51377133304701e-03f, + 2.48287947061529e-01f, + 6.10247389755681e-13f, + 5.76102136993427e-09f, + 6.29106785017040e-06f, + 1.70198817374094e-03f, + 1.16817656904453e-01f, + 9.93151921023180e-01f, + 0.5f, + }; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m1,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vfmv.v.f v20, %[b0] \n\t" + "vfmv.v.f v21, %[a1] \n\t" + "vfmv.v.f v22, %[b2] \n\t" + "vfmv.v.f v23, %[a3] \n\t" + "vfmv.v.f v24, %[b4] \n\t" + "vfmv.v.f v25, %[a5] \n\t" + "vfmv.v.f v26, %[b6] \n\t" + "vfmv.v.f v27, %[a7] \n\t" + "vfmv.v.f v28, %[b8] \n\t" + "vle32.v v0, (%[x]) \n\t" + "add %[x], %[x], t0 \n\t" + "vfmax.vf v1, v0, %[lr] \n\t" + "vfmin.vf v1, v1, %[ur] \n\t" + "vfmul.vv v4, v1, v1 \n\t" + "vmv.v.v v8, v4 \n\t" + "vfmadd.vf v8, %[a9], v27 \n\t" + "vfmadd.vv v8, v4, v25 \n\t" + "vfmadd.vv v8, v4, v23 \n\t" + "vfmadd.vv v8, v4, v21 \n\t" + "vfmul.vv v8, v8, v1 \n\t" + "vmv.v.v v12, v4 \n\t" + "vfmadd.vf v12, %[b10], v28 \n\t" + "vfmadd.vv v12, v4, v26 \n\t" + "vfmadd.vv v12, v4, v24 \n\t" + "vfmadd.vv v12, v4, v22 \n\t" + "vfmadd.vv v12, v4, v20 \n\t" + "vfdiv.vv v12, v8, v12 \n\t" + "vfadd.vf v12, v12, %[onehalf] \n\t" + "vfmul.vv v12, v12, v0 \n\t" // sigmo + "vse32.v v12, (%[y]) \n\t" + "add %[y], %[y], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ x ] "+r"(x), [ y ] "+r"(y) + : [ lr ] "f"(LogisticConstants.LowerRange), + [ ur ] "f"(LogisticConstants.UpperRange), + [ a1 ] "f"(LogisticConstants.alpha_1), + [ a3 ] "f"(LogisticConstants.alpha_3), + [ a5 ] "f"(LogisticConstants.alpha_5), + [ a7 ] "f"(LogisticConstants.alpha_7), + [ a9 ] "f"(LogisticConstants.alpha_9), + [ b0 ] "f"(LogisticConstants.beta_0), + [ b2 ] "f"(LogisticConstants.beta_2), + [ b4 ] "f"(LogisticConstants.beta_4), + [ b6 ] "f"(LogisticConstants.beta_6), + [ b8 ] "f"(LogisticConstants.beta_8), + [ b10 ] "f"(LogisticConstants.beta_10), + [ onehalf ] "f"(LogisticConstants.one_half) + : "cc", "t0"); #endif for (; i < n; ++i) { y[i] = ggml_silu_f32(x[i]); @@ -325,6 +460,119 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } +#elif defined(__riscv) && defined(__riscv_v) + int N = n; + i += n; + float* src = const_cast(reinterpret_cast(x)); + float* dst = reinterpret_cast(y); + float Accumulator = 0.0f; + const float Neg_Max = -max; + + const float LowerRange = -103.9720840454f; + const float UpperRange = 88.7762626647950f; + const float LowerRangeSumExp = -88.3762626647949f; + const float UpperRangeSumExp = 88.3762626647949f; + const float RoundingBias = 12582912.f; + const float Log2Reciprocal = 1.44269504088896341f; + const float Log2High = -6.93145752e-1f; + const float Log2Low = -1.42860677e-6f; + const float poly_0 = 0x1.694000p-10; + const float poly_1 = 0x1.125edcp-7; + const float poly_2 = 0x1.555b5ap-5; + const float poly_3 = 0x1.555450p-3; + const float poly_4 = 0x1.fffff6p-2; + const float poly_56 = 0x1.000000p+0; + // int32_t MinimumExponent = int32_t(0xC1000000); //unused + const int32_t MaximumExponent = int32_t(0x3F800000); + + __asm__ volatile( + "mv t3, %[LEN] \n\t" + "mv s1, %[SRC] \n\t" + "mv s2, %[DST] \n\t" + + /* 2.0 Compute exp() and accumulate and store to cache_buffer */ + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v16, v8, v8 \n\t" + "vxor.vv v0, v8, v8 \n\t" + + ".align 4 \n\t" + "_EXPACC_LEN_LPST: \n\t" + "vsetvli t0, t3, e32, m4,tu,mu \n\t" + + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + + /* 2.1 START exp() */ + "vfadd.vf v0, v0, %[NEG_MAX] \n\t" // v4 = x - max + + // Ensure that q = RN(x/log(2)) >= e_min, so that 2^q can be computed + // safely with a simple shift into the exponent field. xmin = + // round(-126.5 * log(2), single, RU) ~ -87.68311309814453125 const + // float xmin = -0x1.5ebb82p6; + "vfmax.vf v0, v0, %[LowerRangeSumExp] \n\t" + + // 2.1.0. Reduction x = s * q ln(2) + // const float r_ln2f = 0x1.715476p0f; // single(1/log(2)); + // const float l2uf = 0x1.62e4p-1f; // round(log(2), 24-8, RN); + // const float l2lf = 0x1.7f7d1cp-20f; // round(log(2) - l2uf, single, + // RN); + "vfmv.v.f v4, %[RoundingBias] \n\t" + "vfmacc.vf v4, %[Log2Reciprocal], v0 \n\t" // biased in mlas + "vfsub.vf v8, v4, %[RoundingBias] \n\t" // v12_a = float(x - n); + + // Use Cody-Waite range reduction method (note two constants to + // represent log(2)) to improve accuracy. + "vfmacc.vf v0, %[Log2High], v8 \n\t" + "vfmacc.vf v0, %[Log2Low], v8 \n\t" + "vfcvt.x.f.V v8, v4 \n\t" + + // 2.1.1. Approximate e^s by degree-6 polynomial approximation + "vfmv.v.f v4, %[poly_0] \n\t" + "vfmv.v.f v12, %[poly_1] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_2] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_3] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_4] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_56] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" + "vfmv.v.f v12, %[poly_56] \n\t" + "vfmadd.vv v4, v0, v12 \n\t" // v8 = poly(input - max) + + // 2.1.2. Reconstruction: compute u = u*2^q + // const int16_t p = (24 - 1); + // const int16_t bias = (128 - 1); + "vsll.vi v8, v8, 23 \n\t" + "vadd.vx v8, v8, %[MaximumExponent] \n\t" + //"vfcvt.f.x.v v12, v8 \n\t" + + "vfmul.vv v0, v4, v8 \n\t" + /* 2.1 END exp() */ + + "vse32.v v0, (s2) \n\t" // exp(输入-max)输出 + "sh2add s2, t0, s2 \n\t" + "vfadd.vv v16, v16, v0 \n\t" + "sub t3, t3, t0 \n\t" + "bgtz t3, _EXPACC_LEN_LPST \n\t" + + "_EXPACC_LEN_LPND: \n\t" + + "vsetvli t0, zero, e32, m4,tu,mu \n\t" + "vxor.vv v24, v8, v8 \n\t" + "vfredosum.vs v24, v16, v24 \n\t" + "vfmv.f.s %[RTN], v24 \n\t" // ft2 = sum(exp( )) + + : [ RTN ] "=f"(Accumulator), [ SRC ] "+r"(src), [ DST ] "+r"(dst) + : [ LEN ] "r"(N), [ NEG_MAX ] "f"(Neg_Max), [ LowerRange ] "f"(LowerRange), [ UpperRange ] "f"(UpperRange), + [ LowerRangeSumExp ] "f"(LowerRangeSumExp), [ UpperRangeSumExp ] "f"(UpperRangeSumExp), + [ RoundingBias ] "f"(RoundingBias), [ Log2Reciprocal ] "f"(Log2Reciprocal), [ Log2High ] "f"(Log2High), + [ Log2Low ] "f"(Log2Low), [ poly_0 ] "f"(poly_0), [ poly_1 ] "f"(poly_1), [ poly_2 ] "f"(poly_2), + [ poly_3 ] "f"(poly_3), [ poly_4 ] "f"(poly_4), [ poly_56 ] "f"(poly_56), + [ MaximumExponent ] "r"(MaximumExponent) + : "cc", "s1", "s2", "t0", "t3"); + sum += (ggml_float)Accumulator; #endif for (; i < n; ++i) { float val = expf(x[i] - max); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 2250d93cb00d1..c2b062b0b6fdd 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -51,7 +51,108 @@ inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { fo inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_cpy_i8(const int n, int8_t * y, const int8_t * x) { +#if defined(__riscv) && defined(__riscv_v) && defined(__riscv_v_intrinsic) + size_t vlenb = __riscv_vlenb(); + if (vlenb == 32) { + // 1024 bytes + __asm__ volatile( // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8,tu,mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } else if (vlenb == 128) { + // 2048 bytes + __asm__ volatile( // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8,tu,mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 1024 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 1024 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 1024 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 1024 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } else { + __asm__ volatile( // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8,tu,mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [ s ] "+r"(x), [ d ] "+r"(y) + : [ size ] "r"(n) + : "cc", "t0", "t1"); + } +#else + memcpy(y, x, n); +#endif +} + +inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { +#if defined(__riscv) && defined(__riscv_v) && defined(__riscv_v_intrinsic) + ggml_vec_cpy_i8(n * sizeof(int32_t), (int8_t *)y, (const int8_t *)x); +#else + for (int i = 0; i < n; ++i) y[i] = x[i]; +#endif +} inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } @@ -65,6 +166,25 @@ inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, co __m256 vz = _mm256_add_ps(vx, vy); _mm256_storeu_ps(z + i, vz); } +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + i += n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t0 \n\t" + "vle32.v v8, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t0 \n\t" + "vfadd.vv v0, v0, v8 \n\t" + "vse32.v v0, (%[z]) \n\t" + "add %[z], %[z], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ z ] "+r"(z) + : + : "cc", "t0"); #endif for (; i < n; ++i) { z[i] = x[i] + y[i]; @@ -86,7 +206,13 @@ inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp } } inline static void ggml_vec_set_f32 (const int n, float * x, const float v) { for (int i = 0; i < n; ++i) x[i] = v; } -inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } +inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x) { +#if defined(__riscv) && defined(__riscv_v) + ggml_vec_cpy_i8(n * sizeof(float), (int8_t *)y, (const int8_t *)x); +#else + for (int i = 0; i < n; ++i) y[i] = x[i]; +#endif +} inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = -x[i]; } inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { @@ -94,7 +220,29 @@ inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp } } -inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; } +inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { +#if defined(__riscv) && defined(__riscv_v) + int N = n; + __asm__ volatile( + "LOOP%=: \t\n" + "vsetvli t0, %[n], e32, m4\t\n" + "sub %[n], %[n], t0 \t\n" + "slli t0, t0, 2 \t\n" + "vle32.v v0, (%[lhs]) \t\n" + "add %[lhs], %[lhs], t0 \t\n" + "vle32.v v8, (%[rhs]) \t\n" + "add %[rhs], %[rhs], t0 \t\n" + "vfmul.vv v0, v0, v8 \t\n" + "vse32.v v0, (%[z]) \t\n" + "add %[z], %[z], t0 \t\n" + "bnez %[n], LOOP%= \t\n" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y), [ z ] "+r"(z) + : + : "cc", "t0"); +#else + for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; +#endif +} inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { for (int i = 0; i < n; ++i) { z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) * GGML_CPU_FP16_TO_FP32(y[i])); @@ -297,6 +445,28 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "slli t1, t0, 1 \n\t" + "vle16.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t1 \n\t" + "vle16.v v2, (%[rhs]) \n\t" + "vfwcvt.f.f.v v4, v0 \n\t" + "vfwcvt.f.f.v v8, v2 \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "vfmacc.vf v8, %[fs], v4 \n\t" + "vsetvli t0, %[n], e16, m2,tu,mu \n\t" + "vfncvt.f.f.w v12, v8 \n\t" + "vse16.v v12, (%[rhs]) \n\t" + "add %[rhs], %[rhs], t1 \n\t" + "sub %[n], %[n], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(x), [ rhs ] "+r"(y) + : [ fs ] "f"(v) + : "cc", "t0", "t1"); #else // scalar for (int i = 0; i < n; ++i) { @@ -457,6 +627,23 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { y[i] *= v; } #endif +#elif defined(__riscv) && defined(__riscv_v) + size_t N = n; + float* out = y; + __asm__ volatile( + "LOOP%=: \n\t" + "vsetvli t0, %[n], e32, m4,tu,mu \n\t" + "sub %[n], %[n], t0 \n\t" + "slli t0, t0, 2 \n\t" + "vle32.v v0, (%[lhs]) \n\t" + "add %[lhs], %[lhs], t0 \n\t" + "vfmul.vf v0, v0, %[rhs] \n\t" + "vse32.v v0, (%[out]) \n\t" + "add %[out], %[out], t0 \n\t" + "bnez %[n], LOOP%= \n\t" + : [ n ] "+r"(N), [ lhs ] "+r"(y), [ out ] "+r"(out) + : [ rhs ] "f"(v) + : "cc", "t0"); #else // scalar for (int i = 0; i < n; ++i) { @@ -1090,6 +1277,27 @@ inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16 } inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { +#if defined(__riscv_v_intrinsic) +#ifdef __cplusplus + float * src = const_cast(reinterpret_cast(x)); +#else + float * src = (float *) x; +#endif + float Maximum = -INFINITY; + int64_t N = n; + size_t vl; + vfloat32m8_t vmaxf = __riscv_vfmv_v_f_f32m8(Maximum, __riscv_vsetvlmax_e32m8()); + for (; N > 0; N -= vl, src += vl) { + vl = __riscv_vsetvl_e32m8(N); + vfloat32m8_t vsrc = __riscv_vle32_v_f32m8(src, vl); + vmaxf = __riscv_vfmax_vv_f32m8(vmaxf, vsrc, vl); + } + vfloat32m1_t vmaxf_init = __riscv_vfmv_v_f_f32m1(Maximum, __riscv_vsetvlmax_e32m1()); + vl = __riscv_vsetvlmax_e32m8(); + vmaxf_init = __riscv_vfredmax_vs_f32m8_f32m1(vmaxf, vmaxf_init, vl); + float riscv_max = __riscv_vfmv_f_s_f32m1_f32(vmaxf_init); + *s = MAX(riscv_max, *s); +#else #ifndef GGML_USE_ACCELERATE float max = -INFINITY; for (int i = 0; i < n; ++i) { @@ -1099,6 +1307,7 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { #else vDSP_maxv(x, 1, s, n); #endif +#endif } inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { From 0f8d88f954b42ee0d2f73c49657f6716cb5a6357 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 14 Aug 2025 06:38:22 +0000 Subject: [PATCH 2/3] add new line at end of file Change-Id: I889ed1c85fb45e62350ecde0c06f70450cadfbe2 --- docs/build-riscv64-spacemit.md | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp | 2 +- ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md index 87cd58d5053f6..b1356b8432097 100644 --- a/docs/build-riscv64-spacemit.md +++ b/docs/build-riscv64-spacemit.md @@ -84,4 +84,4 @@ Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10| Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08| Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04| -Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| \ No newline at end of file +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp index 24f6d328be3e3..3224b61306265 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.cpp @@ -1053,4 +1053,4 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { }; return &ggml_backend_cpu_buffer_type_riscv64_spacemit; -} \ No newline at end of file +} diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h index 8020508ac2eef..ab71bfae4e7b0 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime.h @@ -6,4 +6,4 @@ // #include // GGML internal header -ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); \ No newline at end of file +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp index 947490c8d9ab0..c51a9bfb355d4 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.cpp @@ -3216,4 +3216,4 @@ SQ4BitGemmKernel_CompInt8(size_t BlkLen, return 1; } } -} \ No newline at end of file +} diff --git a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h index e112e7fd7d050..af62472aabf84 100644 --- a/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ggml_spacemit_ime_kernels.h @@ -27,4 +27,4 @@ QuantizeARow_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* Q void QuantizeAM4Row_CompInt8(size_t BlkLen, const float* A, size_t CountK, std::byte* QuantA); -} \ No newline at end of file +} From 2ad4450e5f63b61bbec47999cdfbf41338e60c55 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 14 Aug 2025 06:57:14 +0000 Subject: [PATCH 3/3] add riscv zba extension limit Change-Id: I321eb200f859751727afe5cae13074dfce2bb0ce --- ggml/src/ggml-cpu/vec.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 5130b5dfcb6ff..982cf5497ccfb 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -460,7 +460,7 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float vst1q_f32(y + i, val); sum += (ggml_float)vaddvq_f32(val); } -#elif defined(__riscv) && defined(__riscv_v) +#elif defined(__riscv) && defined(__riscv_v) && defined(__riscv_zba) int N = n; i += n; float* src = const_cast(reinterpret_cast(x));