diff --git a/mllm/backends/cpu/kernels/Kernels.hpp b/mllm/backends/cpu/kernels/Kernels.hpp index e8c05dfac..ac3f70c81 100644 --- a/mllm/backends/cpu/kernels/Kernels.hpp +++ b/mllm/backends/cpu/kernels/Kernels.hpp @@ -8,12 +8,14 @@ #include "mllm/utils/CPUArchHelper.hpp" #if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) -#include "mllm/backends/cpu/kernels/x86/fill.hpp" // IWYU pragma: export -#include "mllm/backends/cpu/kernels/x86/silu.hpp" // IWYU pragma: export -#include "mllm/backends/cpu/kernels/x86/sigmoid.hpp" // IWYU pragma: export -#include "mllm/backends/cpu/kernels/x86/softmax.hpp" // IWYU pragma: export -#include "mllm/backends/cpu/kernels/x86/rmsnorm.hpp" // IWYU pragma: export -#include "mllm/backends/cpu/kernels/x86/gelu.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/fill.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/silu.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/sigmoid.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/softmax.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/rmsnorm.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/gelu.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/transpose.hpp" // IWYU pragma: export +#include "mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp" // IWYU pragma: export #endif #if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) diff --git a/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp b/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp new file mode 100644 index 000000000..c1e66d25a --- /dev/null +++ b/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.cpp @@ -0,0 +1,531 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include + +#include "mllm/utils/CPUArchHelper.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp" + +namespace mllm::cpu::x86 { + +// Optimized for decoding. +// Q: [1, D] +// K: [S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... +void __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk_baseline(const int M, const int K, const int N, + mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, + bool transpose_b, int thread_count) { + assert(M == 1 && "Q must have shape [1, D]"); + const int S = N; + const int D = K; + + for (int s = 0; s < S; ++s) { + dst[s] = C ? C[s] : 0.0f; + for (int d = 0; d < D; ++d) { dst[s] += A[d] * B[s * D + d]; } + } +} + +// Optimized for decoding using AVX/AVX2. +// Q: [1, D] +// K: [S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... 
+void __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, + int thread_count) { + assert(M == 1 && "Q (A) must have shape [1, D]"); + const int S = N; + const int D = K; + +#if defined(MLLM_HOST_FEATURE_AVX) || defined(MLLM_HOST_FEATURE_AVX2) + // AVX processes 8 floats at a time + const int DTileSize = 32; + const int DTileCount = D / DTileSize; + const int DRemainder = D % DTileSize; + + for (int s = 0; s < S; ++s) { + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + const int s_offset = s * D; + + for (int d = 0; d < DTileCount; ++d) { + const int d_offset = d * DTileSize; + + __m256 a0 = _mm256_loadu_ps(A + d_offset); + __m256 b0 = _mm256_loadu_ps(B + s_offset + d_offset); +#if defined(MLLM_HOST_FEATURE_FMA) + acc0 = _mm256_fmadd_ps(a0, b0, acc0); +#else + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(a0, b0)); +#endif + + __m256 a1 = _mm256_loadu_ps(A + d_offset + 8); + __m256 b1 = _mm256_loadu_ps(B + s_offset + d_offset + 8); +#if defined(MLLM_HOST_FEATURE_FMA) + acc1 = _mm256_fmadd_ps(a1, b1, acc1); +#else + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(a1, b1)); +#endif + + __m256 a2 = _mm256_loadu_ps(A + d_offset + 16); + __m256 b2 = _mm256_loadu_ps(B + s_offset + d_offset + 16); +#if defined(MLLM_HOST_FEATURE_FMA) + acc2 = _mm256_fmadd_ps(a2, b2, acc2); +#else + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(a2, b2)); +#endif + + __m256 a3 = _mm256_loadu_ps(A + d_offset + 24); + __m256 b3 = _mm256_loadu_ps(B + s_offset + d_offset + 24); +#if defined(MLLM_HOST_FEATURE_FMA) + acc3 = _mm256_fmadd_ps(a3, b3, acc3); +#else + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(a3, b3)); +#endif + } + + // Combine accumulators + __m256 sum01 = _mm256_add_ps(acc0, acc1); + __m256 sum23 = _mm256_add_ps(acc2, acc3); + __m256 sum0123 = _mm256_add_ps(sum01, sum23); + + // Horizontal sum of __m256 + // sum0123 = [a0, a1, a2, a3, a4, a5, a6, a7] + __m128 hi = _mm256_extractf128_ps(sum0123, 1); // [a4, a5, a6, a7] + __m128 lo = _mm256_castps256_ps128(sum0123); // [a0, a1, a2, a3] + __m128 sum128 = _mm_add_ps(lo, hi); // [a0+a4, a1+a5, a2+a6, a3+a7] + __m128 shuf = _mm_movehdup_ps(sum128); // [a1+a5, a1+a5, a3+a7, a3+a7] + __m128 sums = _mm_add_ps(sum128, shuf); // [a0+a1+a4+a5, _, a2+a3+a6+a7, _] + shuf = _mm_movehl_ps(shuf, sums); // [a2+a3+a6+a7, _, _, _] + sums = _mm_add_ss(sums, shuf); + float result = _mm_cvtss_f32(sums); + + // Handle remainder + int d_start = DTileCount * DTileSize; + for (int d = d_start; d < D; ++d) { result += A[d] * B[s_offset + d]; } + + if (C) { + dst[s] = result + C[s]; + } else { + dst[s] = result; + } + } +#else + // Fallback to baseline + __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk_baseline(M, K, N, dst, A, B, C, transpose_a, transpose_b, thread_count); +#endif +} + +// Optimized for decoding. +// W: [B, H, 1, S] +// V: [B, H, S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... 
+void __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv_baseline(const int M, const int K, const int N, + mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, + bool transpose_b, int thread_count) { + for (int n = 0; n < N; ++n) { + float sum = 0.0f; + if (C != nullptr) { sum = C[n]; } + for (int k = 0; k < K; ++k) { sum += A[k] * B[k * N + n]; } + dst[n] = sum; + } +} + +// Optimized for decoding using AVX/AVX2. +// W: [B, H, 1, S] +// V: [B, H, S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... +void __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, + int thread_count) { +#if defined(MLLM_HOST_FEATURE_AVX) || defined(MLLM_HOST_FEATURE_AVX2) + // Initialize dst with C or zeros + if (C != nullptr) { + int n = 0; + for (; n <= N - 8; n += 8) { + __m256 c_vec = _mm256_loadu_ps(C + n); + _mm256_storeu_ps(dst + n, c_vec); + } + for (; n < N; ++n) { dst[n] = C[n]; } + } else { + int n = 0; + for (; n <= N - 8; n += 8) { _mm256_storeu_ps(dst + n, _mm256_setzero_ps()); } + for (; n < N; ++n) { dst[n] = 0.0f; } + } + + int k = 0; + for (; k <= K - 4; k += 4) { + __m256 a_vec = + _mm256_set_ps(A[k + 3], A[k + 3], A[k + 2], A[k + 2], A[k + 1], A[k + 1], A[k], A[k]); // For broadcasting later + float a0 = A[k + 0]; + float a1 = A[k + 1]; + float a2 = A[k + 2]; + float a3 = A[k + 3]; + + int n = 0; + for (; n <= N - 8; n += 8) { + __m256 dst_vec = _mm256_loadu_ps(dst + n); + + __m256 b_vec0 = _mm256_loadu_ps(B + (k + 0) * N + n); + __m256 b_vec1 = _mm256_loadu_ps(B + (k + 1) * N + n); + __m256 b_vec2 = _mm256_loadu_ps(B + (k + 2) * N + n); + __m256 b_vec3 = _mm256_loadu_ps(B + (k + 3) * N + n); + + __m256 a0_vec = _mm256_set1_ps(a0); + __m256 a1_vec = _mm256_set1_ps(a1); + __m256 a2_vec = _mm256_set1_ps(a2); + __m256 a3_vec = _mm256_set1_ps(a3); + +#if defined(MLLM_HOST_FEATURE_FMA) + dst_vec = _mm256_fmadd_ps(b_vec0, a0_vec, dst_vec); + dst_vec = _mm256_fmadd_ps(b_vec1, a1_vec, dst_vec); + dst_vec = _mm256_fmadd_ps(b_vec2, a2_vec, dst_vec); + dst_vec = _mm256_fmadd_ps(b_vec3, a3_vec, dst_vec); +#else + dst_vec = _mm256_add_ps(dst_vec, _mm256_mul_ps(b_vec0, a0_vec)); + dst_vec = _mm256_add_ps(dst_vec, _mm256_mul_ps(b_vec1, a1_vec)); + dst_vec = _mm256_add_ps(dst_vec, _mm256_mul_ps(b_vec2, a2_vec)); + dst_vec = _mm256_add_ps(dst_vec, _mm256_mul_ps(b_vec3, a3_vec)); +#endif + + _mm256_storeu_ps(dst + n, dst_vec); + } + + // Handle remainder + for (; n < N; ++n) { + float sum = dst[n]; + sum += a0 * B[(k + 0) * N + n]; + sum += a1 * B[(k + 1) * N + n]; + sum += a2 * B[(k + 2) * N + n]; + sum += a3 * B[(k + 3) * N + n]; + dst[n] = sum; + } + } + + // Handle remaining k + for (; k < K; ++k) { + float a_val = A[k]; + __m256 a_vec = _mm256_set1_ps(a_val); + + int n = 0; + for (; n <= N - 8; n += 8) { + __m256 b_vec = _mm256_loadu_ps(B + k * N + n); + __m256 dst_vec = _mm256_loadu_ps(dst + n); +#if defined(MLLM_HOST_FEATURE_FMA) + dst_vec = _mm256_fmadd_ps(b_vec, a_vec, dst_vec); +#else + dst_vec = _mm256_add_ps(dst_vec, _mm256_mul_ps(b_vec, a_vec)); +#endif + _mm256_storeu_ps(dst + n, dst_vec); + } + for (; n < N; ++n) { dst[n] += a_val * B[k * N + n]; } + } +#else + // Fallback to baseline + __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv_baseline(M, 
K, N, dst, A, B, C, transpose_a, transpose_b, thread_count); +#endif +} + +void __mllm_blas_matmul_fp32_gemv(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count) { + if (!transpose_a && transpose_b) { + __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk(M, K, N, dst, A, B, C, transpose_a, transpose_b, thread_count); + } else if (!transpose_a && !transpose_b) { + __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv(M, K, N, dst, A, B, C, transpose_a, transpose_b, thread_count); + } else { + NYI("transpose_a && transpose_b"); + } +} + +void __mllm_blas_batch_matmul_fp32_gemv(const int BATCH, const int M, const int K, const int N, const int Dst_batch_stride, + const int A_batch_stride, const int B_batch_stride, const int C_batch_stride, + mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a, + bool transpose_b, int thread_count) { + if (!transpose_a && transpose_b) { + if (thread_count > 1) { + __mllm_blas_batch_matmul_fp32_gemv_nt_t_decode_small_d_qk(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, + B_batch_stride, C_batch_stride, dst, A, B, C, transpose_a, + transpose_b, thread_count); + + } else { + __mllm_blas_batch_matmul_fp32_gemv_nt_t_decode_small_d_qk(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, + B_batch_stride, C_batch_stride, dst, A, B, C, + transpose_a, transpose_b, thread_count); + } + } else if (!transpose_a && !transpose_b) { + if (thread_count > 1) { + __mllm_blas_batch_matmul_fp32_gemv_nt_nt_decode_small_d_wv(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, + B_batch_stride, C_batch_stride, dst, A, B, C, + transpose_a, transpose_b, thread_count); + } else { + __mllm_blas_batch_matmul_fp32_gemv_nt_nt_decode_small_d_wv(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, + B_batch_stride, C_batch_stride, dst, A, B, C, + transpose_a, transpose_b, thread_count); + } + } else { + NYI("transpose_a && transpose_b"); + } +} + +namespace { +__MLLM_UNSAFE_OPT_BEGIN_O3_FAST_MATH +static inline void dispatch_tile(int rm, int rn, const float* a, int64_t lda, const float* b, int64_t ldb, float* c, + int64_t ldc, int64_t k) { +#if defined(MLLM_HOST_FEATURE_AVX) || defined(MLLM_HOST_FEATURE_AVX2) +#define KERNEL(__tile_m, __tile_n) \ + case (__tile_m << 8) | __tile_n: MicroKernel<__tile_m, __tile_n>::accumulate(a, lda, b, ldb, c, ldc, k); break; + + switch ((std::min(rm, 8) << 8) | std::min(rn, 8)) { + // AVX optimized kernels + KERNEL(8, 8) + KERNEL(4, 8) + KERNEL(1, 8) + // General GEMV, M = 1, decode + KERNEL(1, 1) + KERNEL(1, 2) + KERNEL(1, 3) + KERNEL(1, 4) + KERNEL(1, 5) + KERNEL(1, 6) + KERNEL(1, 7) + // Compiler Optimized Kernel + KERNEL(2, 2) + KERNEL(2, 4) + KERNEL(2, 6) + KERNEL(2, 8) + KERNEL(4, 2) + KERNEL(4, 4) + KERNEL(4, 6) + default: { + auto _rm = std::min(rm, 8); + auto _rn = std::min(rn, 8); + for (int i = 0; i < _rm; ++i) { + for (int j = 0; j < _rn; ++j) { c[i * ldc + j] = 0; } + } + for (int64_t l = 0; l < k; ++l) { + for (int i = 0; i < _rm; ++i) { + const float ai = a[i * lda + l]; + for (int j = 0; j < _rn; ++j) { c[i * ldc + j] += ai * b[l * ldb + j]; } + } + } + break; + } + } +#undef KERNEL +#else + // SSE or scalar fallback + auto _rm = std::min(rm, 8); + auto _rn = std::min(rn, 8); + for (int i = 0; i < _rm; ++i) { + for (int j = 0; j < _rn; ++j) { c[i * ldc + j] = 0; } + } + 
for (int64_t l = 0; l < k; ++l) { + for (int i = 0; i < _rm; ++i) { + const float ai = a[i * lda + l]; + for (int j = 0; j < _rn; ++j) { c[i * ldc + j] += ai * b[l * ldb + j]; } + } + } +#endif +} +__MLLM_UNSAFE_OPT_END + +__MLLM_UNSAFE_OPT_BEGIN_O3_FAST_MATH +static inline void dispatch_tile_nt_t(int rm, int rn, const float* a, int64_t lda, const float* b, int64_t ldb, float* c, + int64_t ldc, int64_t k, const float* bias) { +#define KERNEL(__tile_m, __tile_n) \ + case (__tile_m << 8) | __tile_n: \ + MicroKernel_NT_T_Bias<__tile_m, __tile_n>::accumulate(a, lda, b, ldb, c, ldc, k, bias); \ + break; + + switch ((std::min(rm, 8) << 8) | std::min(rn, 8)) { + // Compiler Optimized Kernel + KERNEL(8, 8) + KERNEL(4, 8) + KERNEL(1, 8) + // General GEMV, M = 1, decode + KERNEL(1, 1) + KERNEL(1, 2) + KERNEL(1, 3) + KERNEL(1, 4) + KERNEL(1, 5) + KERNEL(1, 6) + KERNEL(1, 7) + // Compiler Optimized Kernel + KERNEL(2, 2) + KERNEL(2, 4) + KERNEL(2, 6) + KERNEL(2, 8) + KERNEL(4, 2) + KERNEL(4, 4) + KERNEL(4, 6) + default: { + auto _rm = std::min(rm, 8); + auto _rn = std::min(rn, 8); + for (int i = 0; i < _rm; ++i) { + for (int j = 0; j < _rn; ++j) { + float sum = 0.0f; + for (int64_t l = 0; l < k; ++l) { sum += a[i * lda + l] * b[j * ldb + l]; } + c[i * ldc + j] = sum; + } + } + if (bias != nullptr) { + for (int i = 0; i < _rm; ++i) { + for (int j = 0; j < _rn; ++j) { c[i * ldc + j] += bias[j]; } + } + } + break; + } + } + +#undef KERNEL +} +__MLLM_UNSAFE_OPT_END +} // namespace + +bool __mllm_blas_sgemm_nt_nt(int64_t m, int64_t n, int64_t k, const float* A, int64_t lda, const float* B, int64_t ldb, + float* C, int64_t ldc, int ith, int thread_count) { + if (m <= 0 || n <= 0 || k <= 0) return false; + if (lda < k || ldb < n || ldc < n) return false; + if (thread_count <= 0 || ith < 0 || ith >= thread_count) return false; + + // Dynamic tiling - use 8x8 for AVX (8 floats per register) + int64_t mc = 8, nc = 8; + if (m < 8) mc = 4; + if (m < 4) mc = 1; + if (n < 8) nc = 4; + if (n < 4) nc = 1; + + int64_t yt = (m + mc - 1) / mc, xt = (n + nc - 1) / nc; + int64_t tiles = yt * xt; + + MLLM_CONDITIONAL_PARALLEL_FOR(thread_count > 1, thread_count, job, 0, tiles, 1, { + int64_t ii = (job / xt) * mc; + int64_t jj = (job % xt) * nc; + int64_t rm = std::min(mc, m - ii); + int64_t rn = std::min(nc, n - jj); + dispatch_tile(rm, rn, &A[ii * lda], lda, &B[jj], ldb, &C[ii * ldc + jj], ldc, k); + }); + return true; +} + +bool __mllm_blas_sgemm_nt_t(int64_t m, int64_t n, int64_t k, const float* A, int64_t lda, const float* B, int64_t ldb, float* C, + int64_t ldc, int ith, const float* bias, int thread_count) { + if (m <= 0 || n <= 0 || k <= 0) return false; + if (lda < k || ldb < k || ldc < n) return false; + if (thread_count <= 0 || ith < 0 || ith >= thread_count) return false; + + // Dynamic tiling - use 8x8 for AVX + int64_t mc = 8, nc = 8; + if (m < 8) mc = 4; + if (m < 4) mc = 1; + if (n < 8) nc = 4; + if (n < 4) nc = 1; + + int64_t yt = (m + mc - 1) / mc, xt = (n + nc - 1) / nc; + int64_t tiles = yt * xt; + + MLLM_CONDITIONAL_PARALLEL_FOR(thread_count > 1, thread_count, job, 0, tiles, 1, { + int64_t ii = (job / xt) * mc; + int64_t jj = (job % xt) * nc; + int64_t rm = std::min(mc, m - ii); + int64_t rn = std::min(nc, n - jj); + dispatch_tile_nt_t(rm, rn, &A[ii * lda], lda, &B[jj * ldb], ldb, &C[ii * ldc + jj], ldc, k, bias ? 
&bias[jj] : bias);
+  });
+  return true;
+}
+
+void mllm_blas_matmul_fp32(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst,
+                           const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B,
+                           const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count) {
+  // MxK, KxN
+  if (!transpose_a && !transpose_b) {
+    // gemv
+    if (M == 1) {
+      __mllm_blas_matmul_fp32_gemv(M, K, N, dst, A, B, C, transpose_a, transpose_b, thread_count);
+    } else
+    // gemm
+    {
+      if (C) { NYI("C not supported in mllm_blas_matmul_fp32::__mllm_blas_sgemm_nt_nt"); }
+      __mllm_blas_sgemm_nt_nt(M, N, K, A, K, B, N, dst, N, 0, thread_count);
+    }
+    return;
+  } else if (!transpose_a && transpose_b)
+  // MxK, NxK
+  {
+    // gemv
+    if (M == 1) {
+      __mllm_blas_matmul_fp32_gemv(M, K, N, dst, A, B, C, transpose_a, transpose_b, thread_count);
+    } else
+    // gemm
+    {
+      __mllm_blas_sgemm_nt_t(M, N, K, A, K, B, K, dst, N, 0, C, thread_count);
+    }
+    return;
+  } else {
+    NYI("transpose_a && transpose_b not supported in mllm_blas_matmul_fp32 gemm/gemv");
+  }
+}
+
+void mllm_blas_batch_matmul_fp32(const int BATCH, const int M, const int K, const int N, const int Dst_batch_stride,
+                                 const int A_batch_stride, const int B_batch_stride, const int C_batch_stride,
+                                 mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A,
+                                 const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a,
+                                 bool transpose_b, int thread_count) {
+  // MxK, KxN
+  if (!transpose_a && !transpose_b) {
+    // gemv
+    if (M == 1) {
+      __mllm_blas_batch_matmul_fp32_gemv(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, B_batch_stride, C_batch_stride, dst,
+                                         A, B, C, transpose_a, transpose_b, thread_count);
+    } else
+    // gemm
+    {
+      if (C) { NYI("C not supported in mllm_blas_batch_matmul_fp32::__mllm_blas_sgemm_nt_nt"); }
+      // Parallel is in the inner loops, not here.
+      for (int i = 0; i < BATCH; ++i) {
+        __mllm_blas_sgemm_nt_nt(M, N, K, A + i * A_batch_stride, K, B + i * B_batch_stride, N, dst + i * Dst_batch_stride, N, 0,
+                                thread_count);
+      }
+    }
+    return;
+  } else if (!transpose_a && transpose_b)
+  // MxK, NxK
+  {
+    // gemv
+    if (M == 1) {
+      __mllm_blas_batch_matmul_fp32_gemv(BATCH, M, K, N, Dst_batch_stride, A_batch_stride, B_batch_stride, C_batch_stride, dst,
+                                         A, B, C, transpose_a, transpose_b, thread_count);
+    } else
+    // gemm
+    {
+      // Parallel is in the inner loops, not here.
+      for (int i = 0; i < BATCH; ++i) {
+        __mllm_blas_sgemm_nt_t(M, N, K, A + i * A_batch_stride, K, B + i * B_batch_stride, K, dst + i * Dst_batch_stride, N, 0,
+                               C, thread_count);
+      }
+    }
+    return;
+  } else {
+    NYI("transpose_a && transpose_b not supported in mllm_blas_batch_matmul_fp32 gemm/gemv");
+  }
+}
+
+}  // namespace mllm::cpu::x86
diff --git a/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp b/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp
new file mode 100644
index 000000000..2b76de4aa
--- /dev/null
+++ b/mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp
@@ -0,0 +1,328 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+#pragma once
+
+#include "mllm/core/DataTypes.hpp"
+#include "mllm/core/Parallel.hpp"
+#include "mllm/utils/UnsafeMacros.hpp"
+
+namespace mllm::cpu::x86 {
+
+// Optimized for decoding.
+// Q: [1, D]
+// K: [S, D]
+// D is small in mllm's case (small language model).
+// D=64, 96, 128 ...
+void __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk_baseline( + const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count); + +// Optimized for decoding. +// Q: [1, D] +// K: [S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... +void __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, + int thread_count); + +// Optimized for decoding. +// Q: [B, H, 1, D] +// K: [B, H, S, D] +// D is small in mllm's case(small language model). +// D=64, 96, 128 ... +template +void __mllm_blas_batch_matmul_fp32_gemv_nt_t_decode_small_d_qk(const int BATCH, const int M, const int K, const int N, + const int Dst_batch_stride, const int A_batch_stride, + const int B_batch_stride, const int C_batch_stride, + mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, + bool transpose_b, int thread_count) { + if constexpr (__enable_thread) { + MLLM_AUTO_PARALLEL_FOR_BEGIN_NT(b, 0, BATCH, 1, thread_count) { + auto a_ptr = A + b * A_batch_stride; + auto b_ptr = B + b * B_batch_stride; + auto c_ptr = C + b * C_batch_stride; + auto d_ptr = dst + b * Dst_batch_stride; + __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk(M, K, N, d_ptr, a_ptr, b_ptr, c_ptr, transpose_a, transpose_b, 0); + } + MLLM_AUTO_PARALLEL_FOR_END_NT() + } else { + for (int b = 0; b < BATCH; ++b) { + auto a_ptr = A + b * A_batch_stride; + auto b_ptr = B + b * B_batch_stride; + auto c_ptr = C + b * C_batch_stride; + auto d_ptr = dst + b * Dst_batch_stride; + __mllm_blas_matmul_fp32_gemv_nt_t_decode_small_d_qk(M, K, N, d_ptr, a_ptr, b_ptr, c_ptr, transpose_a, transpose_b, 0); + } + } +} + +// Optimized for decoding. +// W: [B, H, 1, S] +// V: [B, H, S, D] +// D is small in mllm's case(small language model). +void __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv_baseline( + const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count); + +// Optimized for decoding. +// W: [B, H, 1, S] +// V: [B, H, S, D] +// D is small in mllm's case(small language model). +void __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, + int thread_count); + +// Optimized for decoding. +// W: [B, H, 1, S] +// V: [B, H, S, D] +// D is small in mllm's case(small language model). 
+template +void __mllm_blas_batch_matmul_fp32_gemv_nt_nt_decode_small_d_wv( + const int BATCH, const int M, const int K, const int N, const int Dst_batch_stride, const int A_batch_stride, + const int B_batch_stride, const int C_batch_stride, mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, + int thread_count) { + if constexpr (__enable_thread) { + MLLM_AUTO_PARALLEL_FOR_BEGIN_NT(b, 0, BATCH, 1, thread_count) { + auto a_ptr = A + b * A_batch_stride; + auto b_ptr = B + b * B_batch_stride; + auto c_ptr = C + b * C_batch_stride; + auto d_ptr = dst + b * Dst_batch_stride; + __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv(M, K, N, d_ptr, a_ptr, b_ptr, c_ptr, transpose_a, transpose_b, 0); + } + MLLM_AUTO_PARALLEL_FOR_END_NT() + } else { + for (int b = 0; b < BATCH; ++b) { + auto a_ptr = A + b * A_batch_stride; + auto b_ptr = B + b * B_batch_stride; + auto c_ptr = C + b * C_batch_stride; + auto d_ptr = dst + b * Dst_batch_stride; + __mllm_blas_matmul_fp32_gemv_nt_nt_decode_small_d_wv(M, K, N, d_ptr, a_ptr, b_ptr, c_ptr, transpose_a, transpose_b, 0); + } + } +} + +void __mllm_blas_matmul_fp32_gemv(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst, + const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B, + const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count); + +void __mllm_blas_batch_matmul_fp32_gemv(const int BATCH, const int M, const int K, const int N, const int Dst_batch_stride, + const int A_batch_stride, const int B_batch_stride, const int C_batch_stride, + mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A, + const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a, + bool transpose_b, int thread_count); + +#ifdef __cplusplus +extern "C" { +#endif + +// C = A * B (row-major, FP32) +// A : mxk B : kxn C : mxn +// lda = k, ldb = n, ldc = n +bool __mllm_blas_sgemm_nt_nt(int64_t m, int64_t n, int64_t k, const float* A, int64_t lda, const float* B, int64_t ldb, + float* C, int64_t ldc, int ith, int thread_count); + +#ifdef __cplusplus +} +#endif + +template +struct MicroKernel; + +#if defined(MLLM_HOST_FEATURE_AVX) || defined(MLLM_HOST_FEATURE_AVX2) +#include + +// AVX/AVX2 optimized 8x8 micro-kernel +template<> +struct MicroKernel<8, 8> { + static inline void accumulate(const float* a, int64_t lda, const float* b, int64_t ldb, float* c, int64_t ldc, int64_t k) { + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + __m256 acc4 = _mm256_setzero_ps(); + __m256 acc5 = _mm256_setzero_ps(); + __m256 acc6 = _mm256_setzero_ps(); + __m256 acc7 = _mm256_setzero_ps(); + + const float* a0_ptr = a; + const float* a1_ptr = a + lda; + const float* a2_ptr = a + 2 * lda; + const float* a3_ptr = a + 3 * lda; + const float* a4_ptr = a + 4 * lda; + const float* a5_ptr = a + 5 * lda; + const float* a6_ptr = a + 6 * lda; + const float* a7_ptr = a + 7 * lda; + + for (int64_t l = 0; l < k; ++l) { + __m256 b_vec = _mm256_loadu_ps(b + l * ldb); + + __m256 a0_vec = _mm256_set1_ps(a0_ptr[l]); + __m256 a1_vec = _mm256_set1_ps(a1_ptr[l]); + __m256 a2_vec = _mm256_set1_ps(a2_ptr[l]); + __m256 a3_vec = _mm256_set1_ps(a3_ptr[l]); + __m256 a4_vec = _mm256_set1_ps(a4_ptr[l]); + __m256 a5_vec = _mm256_set1_ps(a5_ptr[l]); + __m256 a6_vec = _mm256_set1_ps(a6_ptr[l]); + __m256 a7_vec = 
_mm256_set1_ps(a7_ptr[l]); + +#if defined(MLLM_HOST_FEATURE_FMA) + acc0 = _mm256_fmadd_ps(a0_vec, b_vec, acc0); + acc1 = _mm256_fmadd_ps(a1_vec, b_vec, acc1); + acc2 = _mm256_fmadd_ps(a2_vec, b_vec, acc2); + acc3 = _mm256_fmadd_ps(a3_vec, b_vec, acc3); + acc4 = _mm256_fmadd_ps(a4_vec, b_vec, acc4); + acc5 = _mm256_fmadd_ps(a5_vec, b_vec, acc5); + acc6 = _mm256_fmadd_ps(a6_vec, b_vec, acc6); + acc7 = _mm256_fmadd_ps(a7_vec, b_vec, acc7); +#else + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(a0_vec, b_vec)); + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(a1_vec, b_vec)); + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(a2_vec, b_vec)); + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(a3_vec, b_vec)); + acc4 = _mm256_add_ps(acc4, _mm256_mul_ps(a4_vec, b_vec)); + acc5 = _mm256_add_ps(acc5, _mm256_mul_ps(a5_vec, b_vec)); + acc6 = _mm256_add_ps(acc6, _mm256_mul_ps(a6_vec, b_vec)); + acc7 = _mm256_add_ps(acc7, _mm256_mul_ps(a7_vec, b_vec)); +#endif + } + + _mm256_storeu_ps(c + 0 * ldc, acc0); + _mm256_storeu_ps(c + 1 * ldc, acc1); + _mm256_storeu_ps(c + 2 * ldc, acc2); + _mm256_storeu_ps(c + 3 * ldc, acc3); + _mm256_storeu_ps(c + 4 * ldc, acc4); + _mm256_storeu_ps(c + 5 * ldc, acc5); + _mm256_storeu_ps(c + 6 * ldc, acc6); + _mm256_storeu_ps(c + 7 * ldc, acc7); + } +}; + +// AVX/AVX2 optimized 4x8 micro-kernel +template<> +struct MicroKernel<4, 8> { + static inline void accumulate(const float* a, int64_t lda, const float* b, int64_t ldb, float* c, int64_t ldc, int64_t k) { + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + const float* a0_ptr = a; + const float* a1_ptr = a + lda; + const float* a2_ptr = a + 2 * lda; + const float* a3_ptr = a + 3 * lda; + + for (int64_t l = 0; l < k; ++l) { + __m256 b_vec = _mm256_loadu_ps(b + l * ldb); + + __m256 a0_vec = _mm256_set1_ps(a0_ptr[l]); + __m256 a1_vec = _mm256_set1_ps(a1_ptr[l]); + __m256 a2_vec = _mm256_set1_ps(a2_ptr[l]); + __m256 a3_vec = _mm256_set1_ps(a3_ptr[l]); + +#if defined(MLLM_HOST_FEATURE_FMA) + acc0 = _mm256_fmadd_ps(a0_vec, b_vec, acc0); + acc1 = _mm256_fmadd_ps(a1_vec, b_vec, acc1); + acc2 = _mm256_fmadd_ps(a2_vec, b_vec, acc2); + acc3 = _mm256_fmadd_ps(a3_vec, b_vec, acc3); +#else + acc0 = _mm256_add_ps(acc0, _mm256_mul_ps(a0_vec, b_vec)); + acc1 = _mm256_add_ps(acc1, _mm256_mul_ps(a1_vec, b_vec)); + acc2 = _mm256_add_ps(acc2, _mm256_mul_ps(a2_vec, b_vec)); + acc3 = _mm256_add_ps(acc3, _mm256_mul_ps(a3_vec, b_vec)); +#endif + } + + _mm256_storeu_ps(c + 0 * ldc, acc0); + _mm256_storeu_ps(c + 1 * ldc, acc1); + _mm256_storeu_ps(c + 2 * ldc, acc2); + _mm256_storeu_ps(c + 3 * ldc, acc3); + } +}; + +// AVX/AVX2 optimized 1x8 micro-kernel (GEMV) +template<> +struct MicroKernel<1, 8> { + static inline void accumulate(const float* a, int64_t lda, const float* b, int64_t ldb, float* c, int64_t ldc, int64_t k) { + __m256 acc = _mm256_setzero_ps(); + + for (int64_t l = 0; l < k; ++l) { + __m256 b_vec = _mm256_loadu_ps(b + l * ldb); + __m256 a_vec = _mm256_set1_ps(a[l]); +#if defined(MLLM_HOST_FEATURE_FMA) + acc = _mm256_fmadd_ps(a_vec, b_vec, acc); +#else + acc = _mm256_add_ps(acc, _mm256_mul_ps(a_vec, b_vec)); +#endif + } + + _mm256_storeu_ps(c, acc); + } +}; + +#endif // MLLM_HOST_FEATURE_AVX || MLLM_HOST_FEATURE_AVX2 + +// Generic fallback micro-kernel +template +struct MicroKernel { + __MLLM_UNSAFE_OPT_BEGIN_O3_FAST_MATH + static inline void accumulate(const float* a, int64_t lda, const float* b, int64_t ldb, float* c, int64_t ldc, + int64_t k) noexcept { + for 
(int i = 0; i < RM; ++i) {
+      for (int j = 0; j < RN; ++j) { c[i * ldc + j] = 0; }
+    }
+    for (int64_t l = 0; l < k; ++l) {
+      for (int i = 0; i < RM; ++i) {
+        const float ai = a[i * lda + l];
+        for (int j = 0; j < RN; ++j) { c[i * ldc + j] += ai * b[l * ldb + j]; }
+      }
+    }
+  }
+  __MLLM_UNSAFE_OPT_END
+};
+
+template<int RM, int RN>
+struct MicroKernel_NT_T_Bias;
+
+template<int RM, int RN>
+struct MicroKernel_NT_T_Bias {
+  static inline void accumulate(const float* a, int64_t lda, const float* b, int64_t ldb, float* c, int64_t ldc, int64_t k,
+                                const float* bias) {
+#pragma unroll
+    for (int i = 0; i < RM; ++i) {
+#pragma unroll
+      for (int j = 0; j < RN; ++j) {
+        float sum = 0.0f;
+        for (int64_t l = 0; l < k; ++l) { sum += a[i * lda + l] * b[j * ldb + l]; }
+        c[i * ldc + j] = sum;
+      }
+    }
+    if (bias != nullptr) {
+#pragma unroll
+      for (int i = 0; i < RM; ++i) {
+#pragma unroll
+        for (int j = 0; j < RN; ++j) { c[i * ldc + j] += bias[j]; }
+      }
+    }
+  }
+};
+
+bool __mllm_blas_sgemm_nt_t(int64_t m, int64_t n, int64_t k, const float* A, int64_t lda, const float* B, int64_t ldb, float* C,
+                            int64_t ldc, int ith, const float* bias, int thread_count);
+
+void mllm_blas_matmul_fp32(const int M, const int K, const int N, mllm_fp32_t* __restrict__ dst,
+                           const mllm_fp32_t* __restrict__ A, const mllm_fp32_t* __restrict__ B,
+                           const mllm_fp32_t* __restrict__ C, bool transpose_a, bool transpose_b, int thread_count);
+
+void mllm_blas_batch_matmul_fp32(const int BATCH, const int M, const int K, const int N, const int Dst_batch_stride,
+                                 const int A_batch_stride, const int B_batch_stride, const int C_batch_stride,
+                                 mllm_fp32_t* __restrict__ dst, const mllm_fp32_t* __restrict__ A,
+                                 const mllm_fp32_t* __restrict__ B, const mllm_fp32_t* __restrict__ C, bool transpose_a,
+                                 bool transpose_b, int thread_count);
+
+}  // namespace mllm::cpu::x86
diff --git a/mllm/backends/cpu/kernels/x86/transpose.cpp b/mllm/backends/cpu/kernels/x86/transpose.cpp
new file mode 100644
index 000000000..ac936ba58
--- /dev/null
+++ b/mllm/backends/cpu/kernels/x86/transpose.cpp
@@ -0,0 +1,420 @@
+// Copyright (c) MLLM Team.
+// Licensed under the MIT License.
+ +#include "mllm/backends/cpu/kernels/x86/transpose.hpp" + +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + +#include +#include + +#if defined(MLLM_HOST_FEATURE_AVX512F) || defined(MLLM_HOST_FEATURE_AVX2) || defined(MLLM_HOST_FEATURE_AVX) +#include +#elif defined(MLLM_HOST_FEATURE_SSE2) +#include +#elif defined(MLLM_HOST_FEATURE_SSE) +#include +#endif + +namespace mllm::cpu::x86 { + +namespace { +void compute_strides(const int* shape, int ndim, int* strides) { + strides[ndim - 1] = 1; + for (int i = ndim - 2; i >= 0; --i) { strides[i] = strides[i + 1] * shape[i + 1]; } +} +} // namespace + +void transpose_hw_wh_fp32(const mllm_fp32_t* __restrict X, mllm_fp32_t* __restrict Y, size_t H, size_t W) { +#if defined(MLLM_HOST_FEATURE_SSE) + // Process 4x4 blocks using SSE + for (size_t i = 0; i + 4 <= H; i += 4) { + for (size_t j = 0; j + 4 <= W; j += 4) { + // Load 4 rows, each containing 4 floats + __m128 r0 = _mm_loadu_ps(X + i * W + j); + __m128 r1 = _mm_loadu_ps(X + (i + 1) * W + j); + __m128 r2 = _mm_loadu_ps(X + (i + 2) * W + j); + __m128 r3 = _mm_loadu_ps(X + (i + 3) * W + j); + + // Transpose 4x4 matrix using SSE intrinsics + // Step 1: Interleave low and high halves + __m128 t0 = _mm_unpacklo_ps(r0, r1); // a0 b0 a1 b1 + __m128 t1 = _mm_unpackhi_ps(r0, r1); // a2 b2 a3 b3 + __m128 t2 = _mm_unpacklo_ps(r2, r3); // c0 d0 c1 d1 + __m128 t3 = _mm_unpackhi_ps(r2, r3); // c2 d2 c3 d3 + + // Step 2: Shuffle to get final transposed rows + __m128 col0 = _mm_movelh_ps(t0, t2); // a0 b0 c0 d0 + __m128 col1 = _mm_movehl_ps(t2, t0); // a1 b1 c1 d1 + __m128 col2 = _mm_movelh_ps(t1, t3); // a2 b2 c2 d2 + __m128 col3 = _mm_movehl_ps(t3, t1); // a3 b3 c3 d3 + + // Store transposed columns as rows in output + _mm_storeu_ps(Y + j * H + i, col0); + _mm_storeu_ps(Y + (j + 1) * H + i, col1); + _mm_storeu_ps(Y + (j + 2) * H + i, col2); + _mm_storeu_ps(Y + (j + 3) * H + i, col3); + } + + // Handle remaining columns + size_t j_remain = W - (W % 4); + for (size_t j = j_remain; j < W; ++j) { + __m128 col = _mm_set_ps(X[(i + 3) * W + j], X[(i + 2) * W + j], X[(i + 1) * W + j], X[i * W + j]); + _mm_storeu_ps(Y + j * H + i, col); + } + } + + // Handle remaining rows + size_t i_remain = H - (H % 4); + for (size_t j = 0; j < W; ++j) { + for (size_t i = i_remain; i < H; ++i) { Y[j * H + i] = X[i * W + j]; } + } +#else + // Scalar fallback + for (size_t i = 0; i < H; ++i) { + for (size_t j = 0; j < W; ++j) { Y[j * H + i] = X[i * W + j]; } + } +#endif +} + +void transpose_bshd_bhsd_fp32(const mllm_fp32_t* __restrict X, mllm_fp32_t* __restrict Y, size_t B, size_t S, size_t H, + size_t D) { +#if defined(MLLM_HOST_FEATURE_SSE) + for (size_t b = 0; b < B; ++b) { + for (size_t h = 0; h < H; ++h) { + for (size_t s = 0; s < S; ++s) { + size_t d = 0; + // Process 4 elements at a time using SSE + for (; d + 4 <= D; d += 4) { + // B, S, H, D + const mllm_fp32_t* src_ptr = X + b * S * H * D + s * H * D + h * D + d; + // B, H, S, D + mllm_fp32_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + + __m128 data = _mm_loadu_ps(src_ptr); + _mm_storeu_ps(dst_ptr, data); + } + // Handle remaining elements + for (; d < D; ++d) { + const mllm_fp32_t* src_ptr = X + b * S * H * D + s * H * D + h * D + d; + mllm_fp32_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + *dst_ptr = *src_ptr; + } + } + } + } +#else + // Scalar fallback + for (size_t b = 0; b < B; ++b) { + for (size_t h = 0; h < H; ++h) { + for (size_t s = 0; s < S; ++s) { + for (size_t d = 0; d < D; ++d) { + const mllm_fp32_t* src_ptr = X + b * S * 
H * D + s * H * D + h * D + d; + mllm_fp32_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + *dst_ptr = *src_ptr; + } + } + } + } +#endif +} + +void transpose_last_dims_fp32(const mllm_fp32_t* __restrict input, mllm_fp32_t* __restrict output, size_t batch, size_t dim0, + size_t dim1) { +#if defined(MLLM_HOST_FEATURE_SSE) + for (size_t b = 0; b < batch; b++) { + const mllm_fp32_t* input_batch = input + b * dim0 * dim1; + mllm_fp32_t* output_batch = output + b * dim0 * dim1; + + // Process 4x4 blocks + for (size_t i = 0; i + 4 <= dim0; i += 4) { + for (size_t j = 0; j + 4 <= dim1; j += 4) { + __m128 r0 = _mm_loadu_ps(input_batch + i * dim1 + j); + __m128 r1 = _mm_loadu_ps(input_batch + (i + 1) * dim1 + j); + __m128 r2 = _mm_loadu_ps(input_batch + (i + 2) * dim1 + j); + __m128 r3 = _mm_loadu_ps(input_batch + (i + 3) * dim1 + j); + + // Transpose 4x4 matrix + __m128 t0 = _mm_unpacklo_ps(r0, r1); + __m128 t1 = _mm_unpackhi_ps(r0, r1); + __m128 t2 = _mm_unpacklo_ps(r2, r3); + __m128 t3 = _mm_unpackhi_ps(r2, r3); + + __m128 col0 = _mm_movelh_ps(t0, t2); + __m128 col1 = _mm_movehl_ps(t2, t0); + __m128 col2 = _mm_movelh_ps(t1, t3); + __m128 col3 = _mm_movehl_ps(t3, t1); + + _mm_storeu_ps(output_batch + j * dim0 + i, col0); + _mm_storeu_ps(output_batch + (j + 1) * dim0 + i, col1); + _mm_storeu_ps(output_batch + (j + 2) * dim0 + i, col2); + _mm_storeu_ps(output_batch + (j + 3) * dim0 + i, col3); + } + + // Handle remaining columns in the block + size_t j_remain = dim1 - (dim1 % 4); + for (size_t j = j_remain; j < dim1; ++j) { + __m128 col = _mm_set_ps(input_batch[(i + 3) * dim1 + j], input_batch[(i + 2) * dim1 + j], + input_batch[(i + 1) * dim1 + j], input_batch[i * dim1 + j]); + _mm_storeu_ps(output_batch + j * dim0 + i, col); + } + } + + // Handle remaining rows + size_t i_remain = dim0 - (dim0 % 4); + for (size_t j = 0; j < dim1; ++j) { + for (size_t i = i_remain; i < dim0; ++i) { output_batch[j * dim0 + i] = input_batch[i * dim1 + j]; } + } + } +#else + // Scalar fallback + for (size_t b = 0; b < batch; b++) { + const mllm_fp32_t* input_batch = input + b * dim0 * dim1; + mllm_fp32_t* output_batch = output + b * dim0 * dim1; + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { output_batch[j * dim0 + i] = input_batch[i * dim1 + j]; } + } + } +#endif +} + +void transpose_hw_wh_int64(const mllm_int64_t* __restrict X, mllm_int64_t* __restrict Y, size_t H, size_t W) { +#if defined(MLLM_HOST_FEATURE_SSE2) + // Process 2x2 blocks using SSE2 (128-bit registers hold 2 int64) + for (size_t i = 0; i + 2 <= H; i += 2) { + for (size_t j = 0; j + 2 <= W; j += 2) { + // Load 2 rows, each containing 2 int64s + __m128i r0 = _mm_loadu_si128(reinterpret_cast(X + i * W + j)); + __m128i r1 = _mm_loadu_si128(reinterpret_cast(X + (i + 1) * W + j)); + + // Transpose 2x2 matrix + // col0 = [r0[0], r1[0]] + // col1 = [r0[1], r1[1]] + __m128i col0 = _mm_unpacklo_epi64(r0, r1); + __m128i col1 = _mm_unpackhi_epi64(r0, r1); + + _mm_storeu_si128(reinterpret_cast<__m128i*>(Y + j * H + i), col0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(Y + (j + 1) * H + i), col1); + } + + // Handle remaining columns + size_t j_remain = W - (W % 2); + for (size_t j = j_remain; j < W; ++j) { + __m128i col = _mm_set_epi64x(X[(i + 1) * W + j], X[i * W + j]); + _mm_storeu_si128(reinterpret_cast<__m128i*>(Y + j * H + i), col); + } + } + + // Handle remaining rows + size_t i_remain = H - (H % 2); + for (size_t j = 0; j < W; ++j) { + for (size_t i = i_remain; i < H; ++i) { Y[j * H + i] = X[i * W + j]; } + } +#else + 
// Scalar fallback + for (size_t i = 0; i < H; ++i) { + for (size_t j = 0; j < W; ++j) { Y[j * H + i] = X[i * W + j]; } + } +#endif +} + +void transpose_bshd_bhsd_int64(const mllm_int64_t* __restrict X, mllm_int64_t* __restrict Y, size_t B, size_t S, size_t H, + size_t D) { +#if defined(MLLM_HOST_FEATURE_SSE2) + for (size_t b = 0; b < B; ++b) { + for (size_t h = 0; h < H; ++h) { + for (size_t s = 0; s < S; ++s) { + size_t d = 0; + // Process 2 elements at a time using SSE2 + for (; d + 2 <= D; d += 2) { + // B, S, H, D + const mllm_int64_t* src_ptr = X + b * S * H * D + s * H * D + h * D + d; + // B, H, S, D + mllm_int64_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + + __m128i data = _mm_loadu_si128(reinterpret_cast(src_ptr)); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_ptr), data); + } + // Handle remaining element + for (; d < D; ++d) { + const mllm_int64_t* src_ptr = X + b * S * H * D + s * H * D + h * D + d; + mllm_int64_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + *dst_ptr = *src_ptr; + } + } + } + } +#else + // Scalar fallback + for (size_t b = 0; b < B; ++b) { + for (size_t h = 0; h < H; ++h) { + for (size_t s = 0; s < S; ++s) { + for (size_t d = 0; d < D; ++d) { + const mllm_int64_t* src_ptr = X + b * S * H * D + s * H * D + h * D + d; + mllm_int64_t* dst_ptr = Y + b * H * S * D + h * S * D + s * D + d; + *dst_ptr = *src_ptr; + } + } + } + } +#endif +} + +void transpose_last_dims_int64(const mllm_int64_t* __restrict input, mllm_int64_t* __restrict output, size_t batch, size_t dim0, + size_t dim1) { +#if defined(MLLM_HOST_FEATURE_SSE2) + for (size_t b = 0; b < batch; b++) { + const mllm_int64_t* input_batch = input + b * dim0 * dim1; + mllm_int64_t* output_batch = output + b * dim0 * dim1; + + // Process 2x2 blocks + for (size_t i = 0; i + 2 <= dim0; i += 2) { + for (size_t j = 0; j + 2 <= dim1; j += 2) { + __m128i r0 = _mm_loadu_si128(reinterpret_cast(input_batch + i * dim1 + j)); + __m128i r1 = _mm_loadu_si128(reinterpret_cast(input_batch + (i + 1) * dim1 + j)); + + __m128i col0 = _mm_unpacklo_epi64(r0, r1); + __m128i col1 = _mm_unpackhi_epi64(r0, r1); + + _mm_storeu_si128(reinterpret_cast<__m128i*>(output_batch + j * dim0 + i), col0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(output_batch + (j + 1) * dim0 + i), col1); + } + + // Handle remaining columns + size_t j_remain = dim1 - (dim1 % 2); + for (size_t j = j_remain; j < dim1; ++j) { + __m128i col = _mm_set_epi64x(input_batch[(i + 1) * dim1 + j], input_batch[i * dim1 + j]); + _mm_storeu_si128(reinterpret_cast<__m128i*>(output_batch + j * dim0 + i), col); + } + } + + // Handle remaining rows + size_t i_remain = dim0 - (dim0 % 2); + for (size_t j = 0; j < dim1; ++j) { + for (size_t i = i_remain; i < dim0; ++i) { output_batch[j * dim0 + i] = input_batch[i * dim1 + j]; } + } + } +#else + // Scalar fallback + for (size_t b = 0; b < batch; b++) { + const mllm_int64_t* input_batch = input + b * dim0 * dim1; + mllm_int64_t* output_batch = output + b * dim0 * dim1; + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { output_batch[j * dim0 + i] = input_batch[i * dim1 + j]; } + } + } +#endif +} + +void permute_fp32(const mllm_fp32_t* __restrict input, mllm_fp32_t* __restrict output, const int* __restrict in_shape, + const int* __restrict perm, int ndim) { + std::vector out_shape(ndim); + for (int i = 0; i < ndim; ++i) { out_shape[i] = in_shape[perm[i]]; } + std::vector in_strides(ndim), out_strides(ndim); + compute_strides(in_shape, ndim, in_strides.data()); + 
compute_strides(out_shape.data(), ndim, out_strides.data()); + int total_elements = 1; + for (int i = 0; i < ndim; ++i) { total_elements *= in_shape[i]; } + bool inner_dim_contiguous = (perm[ndim - 1] == ndim - 1); + int inner_dim_size = out_shape[ndim - 1]; + if (inner_dim_contiguous && inner_dim_size >= 4) { + int outer_elements = total_elements / inner_dim_size; +#if defined(MLLM_HOST_FEATURE_SSE) + const int chunk_size = 4; +#else + const int chunk_size = 1; +#endif + for (int outer_idx = 0; outer_idx < outer_elements; ++outer_idx) { + std::vector coord(ndim - 1); + int temp = outer_idx; + for (int i = ndim - 2; i >= 0; --i) { + coord[i] = temp % out_shape[i]; + temp /= out_shape[i]; + } + int in_offset = 0; + int out_offset = 0; + for (int i = 0; i < ndim - 1; ++i) { + int orig_dim = perm[i]; + in_offset += coord[i] * in_strides[orig_dim]; + out_offset += coord[i] * out_strides[i]; + } + const float* in_ptr = input + in_offset; + float* out_ptr = output + out_offset; + int j = 0; +#if defined(MLLM_HOST_FEATURE_SSE) + for (; j <= inner_dim_size - chunk_size; j += chunk_size) { + __m128 vec = _mm_loadu_ps(in_ptr + j); + _mm_storeu_ps(out_ptr + j, vec); + } +#endif + for (; j < inner_dim_size; ++j) { out_ptr[j] = in_ptr[j]; } + } + } else { + std::vector out_coord(ndim); + std::vector in_coord(ndim); + for (int i = 0; i < total_elements; ++i) { + int temp_idx = i; + for (int d = ndim - 1; d >= 0; --d) { + out_coord[d] = temp_idx % out_shape[d]; + temp_idx /= out_shape[d]; + } + for (int d = 0; d < ndim; ++d) { in_coord[perm[d]] = out_coord[d]; } + int in_offset = 0; + for (int d = 0; d < ndim; ++d) { in_offset += in_coord[d] * in_strides[d]; } + + output[i] = input[in_offset]; + } + } +} + +template +void permute_generic(const T* __restrict input, T* __restrict output, const int* __restrict in_shape, + const int* __restrict perm, int ndim) { + std::vector out_shape(ndim); + for (int i = 0; i < ndim; ++i) { out_shape[i] = in_shape[perm[i]]; } + + std::vector in_strides(ndim), out_strides(ndim); + compute_strides(in_shape, ndim, in_strides.data()); + compute_strides(out_shape.data(), ndim, out_strides.data()); + + int total_elements = 1; + for (int i = 0; i < ndim; ++i) { total_elements *= in_shape[i]; } + + // Use simple element-by-element copy for generic types + std::vector out_coord(ndim); + std::vector in_coord(ndim); + for (int i = 0; i < total_elements; ++i) { + int temp_idx = i; + for (int d = ndim - 1; d >= 0; --d) { + out_coord[d] = temp_idx % out_shape[d]; + temp_idx /= out_shape[d]; + } + for (int d = 0; d < ndim; ++d) { in_coord[perm[d]] = out_coord[d]; } + int in_offset = 0; + for (int d = 0; d < ndim; ++d) { in_offset += in_coord[d] * in_strides[d]; } + + output[i] = input[in_offset]; + } +} + +// Explicit template instantiations for commonly used types +template void permute_generic(const mllm_int8_t* __restrict input, mllm_int8_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_uint8_t* __restrict input, mllm_uint8_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_int16_t* __restrict input, mllm_int16_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_uint16_t* __restrict input, mllm_uint16_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void 
permute_generic(const mllm_int32_t* __restrict input, mllm_int32_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_uint32_t* __restrict input, mllm_uint32_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_int64_t* __restrict input, mllm_int64_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); +template void permute_generic(const mllm_uint64_t* __restrict input, mllm_uint64_t* __restrict output, + const int* __restrict in_shape, const int* __restrict perm, int ndim); + +} // namespace mllm::cpu::x86 + +#endif diff --git a/mllm/backends/cpu/kernels/x86/transpose.hpp b/mllm/backends/cpu/kernels/x86/transpose.hpp new file mode 100644 index 000000000..06a30f814 --- /dev/null +++ b/mllm/backends/cpu/kernels/x86/transpose.hpp @@ -0,0 +1,40 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/DataTypes.hpp" +#include "mllm/utils/CPUArchHelper.hpp" + +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + +#include + +namespace mllm::cpu::x86 { + +void transpose_hw_wh_fp32(const mllm_fp32_t* __restrict X, mllm_fp32_t* __restrict Y, size_t H, size_t W); + +void transpose_bshd_bhsd_fp32(const mllm_fp32_t* __restrict X, mllm_fp32_t* __restrict Y, size_t B, size_t S, size_t H, + size_t D); + +void transpose_last_dims_fp32(const mllm_fp32_t* __restrict input, mllm_fp32_t* __restrict output, size_t batch, size_t dim0, + size_t dim1); + +void transpose_hw_wh_int64(const mllm_int64_t* __restrict X, mllm_int64_t* __restrict Y, size_t H, size_t W); + +void transpose_bshd_bhsd_int64(const mllm_int64_t* __restrict X, mllm_int64_t* __restrict Y, size_t B, size_t S, size_t H, + size_t D); + +void transpose_last_dims_int64(const mllm_int64_t* __restrict input, mllm_int64_t* __restrict output, size_t batch, size_t dim0, + size_t dim1); + +void permute_fp32(const mllm_fp32_t* __restrict input, mllm_fp32_t* __restrict output, const int* __restrict in_shape, + const int* __restrict perm, int ndim); + +template +void permute_generic(const T* __restrict input, T* __restrict output, const int* __restrict in_shape, + const int* __restrict perm, int ndim); + +} // namespace mllm::cpu::x86 + +#endif diff --git a/mllm/backends/cpu/ops/EmbeddingOp.cpp b/mllm/backends/cpu/ops/EmbeddingOp.cpp index da530a8dc..71af75f68 100644 --- a/mllm/backends/cpu/ops/EmbeddingOp.cpp +++ b/mllm/backends/cpu/ops/EmbeddingOp.cpp @@ -46,6 +46,14 @@ void CPUEmbeddingOp::forward(const std::vector& inputs, std::vector({b, (int)s}); + if (token_idx >= 0) { + dequantize_row_q4_0(weight_.ptr() + token_idx * options_.hidden_size / QK4_0, + ous.coffsettedPtr({b, (int)s, 0}), options_.hidden_size); + } + break; + } default: NYI("Not supported weight dtype for arm llm embedding token op"); } }); diff --git a/mllm/backends/cpu/ops/MatMulOp.cpp b/mllm/backends/cpu/ops/MatMulOp.cpp index 4f4cc0efa..312bed3a3 100644 --- a/mllm/backends/cpu/ops/MatMulOp.cpp +++ b/mllm/backends/cpu/ops/MatMulOp.cpp @@ -49,8 +49,8 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector #if defined(MLLM_USE_BLAS) mt = aops::MatMulOpType::kBLAS; #else - if (!transpose_a && transpose_b) { - // TODO: kGGUF still buggy !!! 
+ if (!transpose_a && transpose_b && M >= 4) { + // TODO: GGUF matmul should be correct when M < 4 mt = aops::MatMulOpType::kGGUF; } else // All fallback to mllm blas @@ -66,10 +66,6 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector break; } case aops::MatMulOpType::kGGUF: { - // llamafile implementation - // only supports specific transpose options - MLLM_RT_ASSERT(transpose_a == false && transpose_b == true); - // llamafile uses column-major order, so we actually perform K^T x Q if (lhs.isContiguousN(0)) { mllm::cpu::ggml::mat_mul(lhs, rhs, o, false, nullptr, transpose_a, transpose_b, options_.getThreads()); @@ -110,20 +106,21 @@ void CPUMatMulOp::forward(const std::vector& inputs, std::vector transpose_a, transpose_b, thread_count); } } -// #elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) -// if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && o.dtype() == kFloat32) { -// if (batch_count == 1) { -// x86::mllm_blas_matmul_fp32(M, K, N, o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, -// transpose_a, transpose_b); -// } else { -// x86::mllm_blas_batch_matmul_fp32(batch_count, M, K, N, o.stride()[o.shape().size() - 3], -// lhs.stride()[lhs_shape.size() - 3], rhs.stride()[rhs_shape.size() - 3], 0, -// o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, -// transpose_a, transpose_b); -// } -// } +#elif defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + if (lhs.dtype() == kFloat32 && rhs.dtype() == kFloat32 && o.dtype() == kFloat32) { + if (batch_count == 1) { + x86::mllm_blas_matmul_fp32(M, K, N, o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, + transpose_a, transpose_b, thread_count); + } else { + x86::mllm_blas_batch_matmul_fp32(batch_count, M, K, N, o.stride()[o.shape().size() - 3], + lhs.stride()[lhs_shape.size() - 3], rhs.stride()[rhs_shape.size() - 3], 0, + o.ptr(), lhs.ptr(), rhs.ptr(), nullptr, + transpose_a, transpose_b, thread_count); + } + } #else - NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64 or MLLM_HOST_ARCH_ARM right now.") + NYI("MllmBlas only support MLLM_HOST_ARCH_ARM64, MLLM_HOST_ARCH_ARM, MLLM_HOST_ARCH_X86_64 or MLLM_HOST_ARCH_X86 right " + "now.") #endif break; } diff --git a/mllm/backends/cpu/ops/TransposeOp.cpp b/mllm/backends/cpu/ops/TransposeOp.cpp index b5e1ae939..d9817f82c 100644 --- a/mllm/backends/cpu/ops/TransposeOp.cpp +++ b/mllm/backends/cpu/ops/TransposeOp.cpp @@ -28,7 +28,7 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector WH) fp32 not supported in x86"); + x86::transpose_hw_wh_fp32(input.ptr(), output.ptr(), input_shape[0], input_shape[1]); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_hw_wh_fp32(input.ptr(), output.ptr(), input_shape[0], input_shape[1]); #endif @@ -43,7 +43,9 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector(), output.ptr(), input_shape[0], input_shape[1]); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_hw_wh_int64(input.ptr(), output.ptr(), input_shape[0], input_shape[1]); #endif break; @@ -57,7 +59,8 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector BHSD) fp32 not supported in x86"); + x86::transpose_bshd_bhsd_fp32(input.ptr(), output.ptr(), input_shape[0], input_shape[1], + input_shape[2], input_shape[3]); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_bshd_bhsd_fp32(input.ptr(), output.ptr(), input_shape[0], input_shape[1], input_shape[2], input_shape[3]); @@ -74,7 +77,10 @@ void CPUTransposeOp::forward(const std::vector& inputs, 
std::vector(), output.ptr(), input_shape[0], input_shape[1], + input_shape[2], input_shape[3]); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_bshd_bhsd_int64(input.ptr(), output.ptr(), input_shape[0], input_shape[1], input_shape[2], input_shape[3]); #endif @@ -93,7 +99,8 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector BSDH) fp32 not supported in x86"); + x86::transpose_last_dims_fp32(input.ptr(), output.ptr(), batch, input_shape[0], + input_shape[1]); #elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_last_dims_fp32(input.ptr(), output.ptr(), batch, input_shape[0], input_shape[1]); @@ -110,7 +117,10 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector(), output.ptr(), batch, input_shape[0], + input_shape[1]); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::transpose_last_dims_int64(input.ptr(), output.ptr(), batch, input_shape[0], input_shape[1]); #endif @@ -129,21 +139,29 @@ void CPUTransposeOp::forward(const std::vector& inputs, std::vector(), output.ptr(), input_shape.data(), permute_axis.data(), + permute_axis.size()); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::permute_fp32(input.ptr(), output.ptr(), input_shape.data(), permute_axis.data(), permute_axis.size()); #endif break; } case kFloat16: { -#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + NYI("Transpose op(General permute) fp16 not supported in x86"); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::permute_fp16(input.ptr(), output.ptr(), input_shape.data(), permute_axis.data(), permute_axis.size()); #endif break; } case kInt64: { -#if defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) +#if defined(MLLM_HOST_ARCH_X86_64) || defined(MLLM_HOST_ARCH_X86) + x86::permute_generic(input.ptr(), output.ptr(), input_shape.data(), + permute_axis.data(), permute_axis.size()); +#elif defined(MLLM_HOST_ARCH_ARM64) || defined(MLLM_HOST_ARCH_ARM) arm::permute_generic(input.ptr(), output.ptr(), input_shape.data(), permute_axis.data(), permute_axis.size()); #endif diff --git a/mllm/core/aops/EmbeddingOp.cpp b/mllm/core/aops/EmbeddingOp.cpp index a5a6400dd..a1653ba72 100644 --- a/mllm/core/aops/EmbeddingOp.cpp +++ b/mllm/core/aops/EmbeddingOp.cpp @@ -73,6 +73,7 @@ void EmbeddingOp::reshape(const std::vector& inputs, std::vector // Output dtype should match weight dtype (e.g., uint16 for AsymPerTensor quantization) auto out_dtype = weight_.dtype(); if (weight_.dtype() == kUInt16) { out_dtype = kUInt16PerTensorAsy; } + if (weight_.dtype() == kGGUF_Q4_0 || weight_.dtype() == kGGUF_Q4_K) { out_dtype = kFloat32; } outputs.emplace_back(Tensor::empty(o_shape, out_dtype, i.device())); }
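
Reviewer note (not part of the patch): a minimal usage sketch of the new x86 entry point, assuming mllm_fp32_t aliases float and that the header path below resolves in a standalone translation unit. Per the dispatcher above, with transpose_a == false and transpose_b == false, A is row-major MxK, B is KxN, dst is MxN, and C may be nullptr (the M > 1 GEMM path rejects a non-null C); with M == 1 the same call takes the GEMV path instead of the tiled SGEMM.

// sgemm_smoke_test.cpp -- hypothetical example, for illustration only.
#include <cstdio>
#include <vector>

#include "mllm/backends/cpu/kernels/x86/mllm_blas/mllm_blas_sgemm.hpp"

int main() {
  const int M = 2, K = 3, N = 2;
  // Row-major A (MxK) and B (KxN); no transposes, no bias, single thread.
  std::vector<float> A = {1, 2, 3,
                          4, 5, 6};
  std::vector<float> B = {1, 0,
                          0, 1,
                          1, 1};
  std::vector<float> dst(M * N, 0.0f);

  // Assumes mllm_fp32_t is float, so the vectors' data pointers bind directly.
  mllm::cpu::x86::mllm_blas_matmul_fp32(M, K, N, dst.data(), A.data(), B.data(),
                                        /*C=*/nullptr, /*transpose_a=*/false,
                                        /*transpose_b=*/false, /*thread_count=*/1);

  // Expected row-major result: [[4, 5], [10, 11]].
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) { std::printf("%.1f ", dst[i * N + j]); }
    std::printf("\n");
  }
  return 0;
}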