diff --git a/common_level3.h b/common_level3.h index fc4909e693..08373e9d6a 100644 --- a/common_level3.h +++ b/common_level3.h @@ -59,6 +59,23 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K, float beta, float * R, BLASLONG strideR); +void strmm_direct_LNUN(BLASLONG M, BLASLONG N, BLASLONG K, + float alpha, + float * A, BLASLONG strideA, + float * B, BLASLONG strideB); +void strmm_direct_LNLN(BLASLONG M, BLASLONG N, BLASLONG K, + float alpha, + float * A, BLASLONG strideA, + float * B, BLASLONG strideB); +void strmm_direct_LTUN(BLASLONG M, BLASLONG N, BLASLONG K, + float alpha, + float * A, BLASLONG strideA, + float * B, BLASLONG strideB); +void strmm_direct_LTLN(BLASLONG M, BLASLONG N, BLASLONG K, + float alpha, + float * A, BLASLONG strideA, + float * B, BLASLONG strideB); + int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K); int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float, diff --git a/common_param.h b/common_param.h index 0145f667a1..2e5832f059 100644 --- a/common_param.h +++ b/common_param.h @@ -307,6 +307,12 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *); int (*strsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *); #endif #if (BUILD_SINGLE==1) +#ifdef ARCH_ARM64 + void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG); + void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG); + void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG); + void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG); +#endif int (*strmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); int (*strmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG); diff --git a/common_s.h b/common_s.h index 88b4732f51..df1388b75a 100644 --- a/common_s.h +++ b/common_s.h @@ -62,6 +62,11 @@ #define SGEMM_ITCOPY sgemm_itcopy #endif +#define STRMM_DIRECT_LNUN strmm_direct_LNUN +#define STRMM_DIRECT_LNLN strmm_direct_LNLN +#define STRMM_DIRECT_LTUN strmm_direct_LTUN +#define STRMM_DIRECT_LTLN strmm_direct_LTLN + #define STRMM_OUNUCOPY strmm_ounucopy #define STRMM_OUNNCOPY strmm_ounncopy #define STRMM_OUTUCOPY strmm_outucopy @@ -227,6 +232,13 @@ #define SGEMM_INCOPY gotoblas -> sgemm_incopy #define SGEMM_ITCOPY gotoblas -> sgemm_itcopy +#ifdef ARCH_ARM64 +#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN +#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN +#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN +#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN +#endif + #define STRMM_OUNUCOPY gotoblas -> strmm_ounucopy #define STRMM_OUTUCOPY gotoblas -> strmm_outucopy #define STRMM_OLNUCOPY gotoblas -> strmm_olnucopy diff --git a/interface/trsm.c b/interface/trsm.c index 715c83a1f3..6f2837595b 100644 --- a/interface/trsm.c +++ b/interface/trsm.c @@ -255,6 +255,21 @@ void CNAME(enum CBLAS_ORDER order, #endif PRINT_DEBUG_CNAME; +#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16) +#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH)) +#if defined(DYNAMIC_ARCH) + if (support_sme1()) +#endif + if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft) { + if (Trans == 
CblasNoTrans) { + (Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, m, alpha, a, lda, b, ldb); + } else if (Trans == CblasTrans) { + (Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, m, alpha, a, lda, b, ldb); + } + return; + } +#endif +#endif args.a = (void *)a; args.b = (void *)b; diff --git a/kernel/CMakeLists.txt b/kernel/CMakeLists.txt index a2e349d32d..964632c75a 100644 --- a/kernel/CMakeLists.txt +++ b/kernel/CMakeLists.txt @@ -237,6 +237,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10)) set(USE_TRMM true) endif () + set(USE_DIRECT_STRMM false) + if (ARM64) + set(USE_DIRECT_STRMM true) + endif() set(USE_DIRECT_SGEMM false) if (X86_64 OR ARM64) set(USE_DIRECT_SGEMM true) @@ -442,6 +446,17 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS) set(TRMM_KERNEL "${${float_char}GEMMKERNEL}") endif () + if (USE_DIRECT_STRMM) + set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c) + set (STRMMDIRECTPREKERNEL strmm_direct_arm64_sme1_preprocess.c) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_UN" false "" "" false SINGLE) + GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_LN" false "" "" false SINGLE) + endif () + if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX") # just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. 
CN NC NN) that I don't really want to work into GenerateCombinationObjects diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3 index 79c88d76c7..d02bcda910 100644 --- a/kernel/Makefile.L3 +++ b/kernel/Makefile.L3 @@ -52,6 +52,7 @@ endif ifeq ($(ARCH), arm64) USE_TRMM = 1 USE_DIRECT_SGEMM = 1 +USE_DIRECT_STRMM = 1 endif ifeq ($(ARCH), riscv64) @@ -137,6 +138,10 @@ endif endif endif +ifdef USE_DIRECT_STRMM +STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c +endif + ifeq ($(BUILD_BFLOAT16), 1) ifndef BGEMMKERNEL BGEMM_BETA = ../generic/gemm_beta.c @@ -286,6 +291,15 @@ SBLASOBJS += \ strsm_kernel_RN$(TSUFFIX).$(SUFFIX) strsm_kernel_RT$(TSUFFIX).$(SUFFIX) endif +ifdef USE_DIRECT_STRMM +SBLASOBJS += \ + strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \ + strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) +SBLASOBJS += \ + strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) \ + strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX) +endif + ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" "" DBLASOBJS += \ dgemm_beta$(TSUFFIX).$(SUFFIX) \ @@ -1150,6 +1164,27 @@ else $(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@ endif + +ifdef USE_DIRECT_STRMM +$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@ + +$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@ + +$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@ + +$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL) + $(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@ + +$(KDIR)strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) : + $(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -DUPPER $< -o $@ + +$(KDIR)strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX) : + $(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -UUPPER $< -o $@ +endif + $(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL) ifeq ($(OS), AIX) $(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s diff --git a/kernel/arm64/strmm_direct_arm64_sme1.c b/kernel/arm64/strmm_direct_arm64_sme1.c new file mode 100644 index 0000000000..2707c835e3 --- /dev/null +++ b/kernel/arm64/strmm_direct_arm64_sme1.c @@ -0,0 +1,191 @@ +/* + Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdlib.h>
+#include <stdint.h>
+#include <math.h>
+#if defined(HAVE_SME)
+
+#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
+#include "sme_abi.h"
+#include <arm_sme.h>
+#endif
+
+/* Function prototypes */
+// For Upper and NonUnit Triangular preprocess
+extern void strmm_direct_sme1_preprocess_UN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);
+// For Lower and NonUnit Triangular preprocess
+extern void strmm_direct_sme1_preprocess_LN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);
+
+/* Function Definitions */
+static uint64_t sve_cntw() {
+    uint64_t cnt;
+    asm volatile(
+        "rdsvl %[res], #1\n"
+        "lsr %[res], %[res], #2\n"
+        : [res] "=r" (cnt) ::
+    );
+    return cnt;
+}
+
+#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
+// Outer product kernel.
+// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
+__attribute__((always_inline)) inline void
+kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
+           size_t ldc, size_t block_rows, size_t block_cols, float alpha, uint64_t row_idx)
+    __arm_out("za") __arm_streaming {
+    const uint64_t svl = svcntw();
+    size_t ldb = ldc;
+    // Predicate set-up
+    svbool_t pg = svptrue_b32();
+    svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
+    svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
+
+#if (!defined(TRANSA) && !defined(UPPER)) || (defined(TRANSA) && defined(UPPER))
+#define pg_a_0_full pg_a_0
+#define pg_a_1_full pg_a_1
+#endif
+    svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
+    svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
+
+#define pg_c_0 pg_b_0
+#define pg_c_1 pg_b_1
+
+    svzero_za();
+    svfloat32_t alpha_vec = svdup_f32(alpha);
+    // Iterate through shared dimension (K)
+#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER))
+    for (size_t k = row_idx, valid_index = 1; k < shared_dim; k++,valid_index++) {
+        pg_a_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_rows));
+        pg_a_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_rows));
+#else
+    for (size_t k = 0; k < MIN(row_idx + block_rows, shared_dim); k++) {
+        // If k exceeds row_idx, mask out rows before (k - row_idx)
+        // This ensures only valid rows are included for lower triangular logic.
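+        // e.g. with row_idx = 16 and k = 19, local rows 0..2 (global rows 16..18) lie above the diagonal of the lower-triangular A and are cleared from the predicates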
+ if (k > row_idx) { + pg_a_0 = svnot_b_z(pg_a_0_full, svwhilelt_b32_u64(0, k - row_idx)); + pg_a_1 = svnot_b_z(pg_a_1_full, svwhilelt_b32_u64(svl, k - row_idx)); + } +#endif + +#if !defined(TRANSA) + // Load column of A + svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]); + svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]); +#else + svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * shared_dim]); + svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * shared_dim + svl]); +#endif + col_a_0 = svmul_x(pg_a_0, alpha_vec, col_a_0); + col_a_1 = svmul_x(pg_a_1, alpha_vec, col_a_1); + // Load row of B + svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]); + svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]); + // Perform outer product + svmopa_za32_m(/*tile*/0, pg_a_0, pg, col_a_0, row_b_0); + svmopa_za32_m(/*tile*/1, pg_a_0, pg, col_a_0, row_b_1); + svmopa_za32_m(/*tile*/2, pg_a_1, pg, col_a_1, row_b_0); + svmopa_za32_m(/*tile*/3, pg_a_1, pg, col_a_1, row_b_1); + } + + // Store to C from ZA + for (size_t i = 0; i < MIN(svl, block_rows); i++) { + svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]); + svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]); + } + for (size_t i = svl; i < block_rows; i++) { + svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]); + svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]); + } + +} + +__arm_new("za") __arm_locally_streaming +static inline void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ + const float *ba, float *restrict bb) { + const uint64_t num_rows = m; + const uint64_t num_cols = n; + + const float *restrict a_ptr = ba; + const float *restrict b_ptr = bb; + float *restrict c_ptr = bb; + + const uint64_t svl = svcntw(); + const uint64_t svl_x2 = 2*svl; + const uint64_t ldc = n; + + + uint64_t row_idx = 0; +#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER)) + // 2x2 loop + uint64_t row_batch = svl_x2; + // Block over rows of C (panels of A) + for (; row_idx < num_rows; row_idx += row_batch) { + row_batch = MIN(row_batch, num_rows - row_idx); +#else + // Calculate the remainder of num_rows divided by 2VL to determine tail tile size + uint64_t row_batch = num_rows % svl_x2; + // If there's no remainder, use full tile size (2VL) for initial batch + if (row_batch == 0) row_batch = svl_x2; + // Loop from bottom to top, processing rows in batches + for (uint64_t index = num_rows; index > 0; index -= row_batch, row_batch = svl_x2) { + // Compute the starting row index for the current batch + row_idx = index - row_batch; +#endif + uint64_t col_idx = 0; + uint64_t col_batch = svl_x2; + // Block over column dimension of C + for (; col_idx < num_cols; col_idx += col_batch) { + col_batch = MIN(col_batch, num_cols - col_idx); +#if !defined(TRANSA) + kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx], + &c_ptr[row_idx * ldc + col_idx], k, + ldc, row_batch, col_batch, *alpha, row_idx); +#else + kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx], + &c_ptr[row_idx * ldc + col_idx], k, + ldc, row_batch, col_batch, *alpha, row_idx); +#endif + } + } + + return; +} + +#else +void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\ + const float *ba, float *restrict bb){} +#endif + +void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\ + BLASLONG strideA, float * __restrict B, BLASLONG strideB){ +#if !defined(TRANSA) + uint64_t m_mod, vl_elms; + + vl_elms = sve_cntw(); + + m_mod = ceil((double)M/(double)vl_elms) * 
vl_elms;
+
+    float *A_mod = (float *) malloc(m_mod*K*sizeof(float));
+#if defined(UPPER)
+    strmm_direct_sme1_preprocess_UN(M, K, A, A_mod);
+#else
+    strmm_direct_sme1_preprocess_LN(M, K, A, A_mod);
+#endif
+    /* Calculate B = alpha*A*B */
+    strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B);
+    free(A_mod);
+#else
+    strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A, B);
+#endif
+}
+
+#else
+void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
+            BLASLONG strideA, float * __restrict B, BLASLONG strideB){}
+
+#endif
diff --git a/kernel/arm64/strmm_direct_arm64_sme1_preprocess.c b/kernel/arm64/strmm_direct_arm64_sme1_preprocess.c
new file mode 100644
index 0000000000..0475a0e796
--- /dev/null
+++ b/kernel/arm64/strmm_direct_arm64_sme1_preprocess.c
@@ -0,0 +1,100 @@
+/*
+   Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdlib.h>
+#include <stdint.h>
+#include <math.h>
+#if defined(HAVE_SME)
+
+#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
+#include "sme_abi.h"
+#include <arm_sme.h>
+#endif
+
+#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
+// Transpose 1SVL x N panel of A
+__attribute__((always_inline)) inline static void transpose_panel_lower(const float *restrict a, float *restrict b, uint64_t rows, uint64_t cols,
+                                                                         uint64_t a_step, uint64_t rows_index)
+__arm_out("za") __arm_streaming {
+    // for Lower Triangular Matrix
+    uint64_t svl = svcntw();
+    uint64_t col_batch = svl;
+
+    svzero_za();
+    uint64_t last_rows_index = rows_index + rows - 1;
+    for (uint64_t k = 0; k < cols; k += col_batch) {
+        if (last_rows_index < k) {
+            // Early exit: if all rows are above the diagonal, no valid elements remain
+            break;
+        }
+        // Load to horizontal slices
+        for (uint64_t row = 0; row < rows; row++) {
+            svbool_t pg_row = svwhilelt_b32_u64(k, MIN(rows_index + row + 1, cols));
+            svld1_hor_za32(0, row, pg_row, &a[row * a_step + k]);
+        }
+
+        // Save from vertical slices
+        col_batch = MIN(col_batch, cols - k);
+        for (uint64_t col = 0; col < col_batch; col++) {
+            svst1_ver_za32(0, col, svptrue_b32(), &b[(col + k) * svl]);
+        }
+    }
+}
+
+
+__attribute__((always_inline)) inline static void transpose_panel_upper(const float *restrict a, float *restrict b, uint64_t rows, uint64_t cols, uint64_t a_step, uint64_t rows_index)
+__arm_out("za") __arm_streaming {
+    // for Upper Triangular Matrix
+    uint64_t svl = svcntw();
+    uint64_t col_batch = svl;
+
+    svzero_za();
+    // Start from column k = rows_index to ensure we only process the upper triangle (k >= rows_index)
+    for (uint64_t k = rows_index; k < cols; k += col_batch) {
+        // Load to horizontal slices
+        for (uint64_t row = 0; row < rows; row++) {
+            svbool_t pg_row = svwhilelt_b32_u64(k, cols);
+            svld1_hor_za32(0, row, pg_row, &a[row * a_step + k]);
+        }
+
+        // Save from vertical slices
+        col_batch = MIN(col_batch, cols - k);
+        for (uint64_t col = 0, real_col = k; col < col_batch; col++, real_col++) {
+            // Only the upper triangular part of the matrix is stored.
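+            // pg_col keeps lanes 0 .. (real_col - rows_index), i.e. the entries of column real_col whose row index is <= real_col within this row panel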
+ svbool_t pg_col = svwhilelt_b32_u64(rows_index, real_col + 1); + svst1_ver_za32(0, col, pg_col, &b[(col + k) * svl]); + } + } +} + +__arm_new("za") __arm_locally_streaming +static inline void strmm_direct_pack_lhs_f32_sme1_acle(uint64_t nbr, uint64_t nbc, + const float *restrict a, float *restrict a_mod) { + const uint64_t num_rows = nbr; + uint64_t row_batch = svcntsw(); + for (uint64_t row_idx = 0; row_idx < num_rows; row_idx += row_batch) { + // Transpose 1SVL x N panel of A + row_batch = MIN(row_batch, num_rows - row_idx); +#if !defined(UPPER) + transpose_panel_lower(&a[row_idx * nbc], &a_mod[row_idx * nbc], row_batch, nbc, nbc, row_idx); +#else + transpose_panel_upper(&a[row_idx * nbc], &a_mod[row_idx * nbc], row_batch, nbc, nbc, row_idx); +#endif + } +} + +#else +void strmm_direct_pack_lhs_f32_sme1_acle(uint64_t nbr, uint64_t nbc, const float *restrict a, float *restrict a_mod){} +#endif + +void CNAME (BLASLONG M, BLASLONG K, const float * __restrict A, float * __restrict A_mod){ + strmm_direct_pack_lhs_f32_sme1_acle(M, K, A, A_mod); +} + +#else +void CNAME (BLASLONG M, BLASLONG K, const float * __restrict A, float * __restrict A_mod){} + +#endif diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c index ccfbab8c11..d95b542e76 100644 --- a/kernel/setparam-ref.c +++ b/kernel/setparam-ref.c @@ -248,6 +248,12 @@ gotoblas_t TABLE_NAME = { strsm_olnucopyTS, strsm_olnncopyTS, strsm_oltucopyTS, strsm_oltncopyTS, #endif #if (BUILD_SINGLE==1) +#ifdef ARCH_ARM64 + strmm_direct_LNUNTS, + strmm_direct_LNLNTS, + strmm_direct_LTUNTS, + strmm_direct_LTLNTS, +#endif strmm_kernel_RNTS, strmm_kernel_RTTS, strmm_kernel_LNTS, strmm_kernel_LTTS, #if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N strmm_iunucopyTS, strmm_iunncopyTS, strmm_iutucopyTS, strmm_iutncopyTS, diff --git a/param.h b/param.h index d0ee246e83..a860f4b966 100644 --- a/param.h +++ b/param.h @@ -1671,6 +1671,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GEMM_PREFERED_SIZE 16 #endif #define USE_SGEMM_KERNEL_DIRECT 1 +#define USE_STRMM_KERNEL_DIRECT 1 #ifdef ARCH_X86 @@ -1792,6 +1793,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GEMM_PREFERED_SIZE 16 #endif #define USE_SGEMM_KERNEL_DIRECT 1 +#define USE_STRMM_KERNEL_DIRECT 1 #undef SBGEMM_DEFAULT_UNROLL_N #undef SBGEMM_DEFAULT_UNROLL_M @@ -1925,6 +1927,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define GEMM_PREFERED_SIZE 16 #endif #define USE_SGEMM_KERNEL_DIRECT 1 +#define USE_STRMM_KERNEL_DIRECT 1 #undef SBGEMM_DEFAULT_UNROLL_N #undef SBGEMM_DEFAULT_UNROLL_M @@ -3844,6 +3847,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout #if defined(ARMV9SME) /* ARMv9 SME */ #define USE_SGEMM_KERNEL_DIRECT 1 +#define USE_STRMM_KERNEL_DIRECT 1 #endif /* ARMv9 SME */ #if defined(ARMV5)
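
For review, a scalar reference for the semantics the new entry points are expected to implement can be useful. The sketch below models only the LNUN variant (row-major, Side=Left, Uplo=Upper, TransA=NoTrans, Diag=NonUnit), computing B := alpha*A*B in place; the parameter meaning (K == M, strideA and strideB used as lda and ldb) is inferred from the call site added in interface/trsm.c, and the helper name strmm_LNUN_ref is illustrative, not part of the patch.

#include <stddef.h>

/* Scalar model of the LNUN case: B := alpha * A * B, A upper triangular,
   non-unit diagonal, row-major, B updated in place. */
static void strmm_LNUN_ref(size_t M, size_t N, float alpha,
                           const float *A, size_t lda,  /* M x M, only l >= i is read */
                           float *B, size_t ldb)        /* M x N, overwritten */
{
    /* Rows can be updated top to bottom: row i only reads rows l >= i of the
       original B, matching the top-down block order the SME kernel uses for
       the Upper/NoTrans case. */
    for (size_t i = 0; i < M; i++) {
        for (size_t j = 0; j < N; j++) {
            float acc = 0.0f;
            for (size_t l = i; l < M; l++)           /* upper triangle of A */
                acc += A[i * lda + l] * B[l * ldb + j];
            B[i * ldb + j] = alpha * acc;
        }
    }
}

The other three variants differ only in which triangle of A is read and in whether A is indexed transposed (A[l * lda + i]); comparing the SME kernels against such loops on sizes that are not multiples of 2*SVL exercises the predicate tails in kernel_2x2 and in the preprocessing routines.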