Skip to content

Commit 5d80985

Browse files
committed
Support for SME1 based strmm_direct kernel for cblas_strmm level 3 API
1 parent e2f9f57 commit 5d80985

File tree

10 files changed

+399
-0
lines changed

10 files changed

+399
-0
lines changed

common_level3.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,23 @@ void sgemm_direct_alpha_beta(BLASLONG M, BLASLONG N, BLASLONG K,
5959
float beta,
6060
float * R, BLASLONG strideR);
6161

62+
void strmm_direct_LNUN(BLASLONG M, BLASLONG N, BLASLONG K,
63+
float alpha,
64+
float * A, BLASLONG strideA,
65+
float * B, BLASLONG strideB);
66+
void strmm_direct_LNLN(BLASLONG M, BLASLONG N, BLASLONG K,
67+
float alpha,
68+
float * A, BLASLONG strideA,
69+
float * B, BLASLONG strideB);
70+
void strmm_direct_LTUN(BLASLONG M, BLASLONG N, BLASLONG K,
71+
float alpha,
72+
float * A, BLASLONG strideA,
73+
float * B, BLASLONG strideB);
74+
void strmm_direct_LTLN(BLASLONG M, BLASLONG N, BLASLONG K,
75+
float alpha,
76+
float * A, BLASLONG strideA,
77+
float * B, BLASLONG strideB);
78+
6279
int sgemm_direct_performant(BLASLONG M, BLASLONG N, BLASLONG K);
6380

6481
int shgemm_beta(BLASLONG, BLASLONG, BLASLONG, float,

common_param.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,12 @@ int (*shgemm_otcopy )(BLASLONG, BLASLONG, hfloat16 *, BLASLONG, hfloat16 *);
307307
int (*strsm_oltncopy)(BLASLONG, BLASLONG, float *, BLASLONG, BLASLONG, float *);
308308
#endif
309309
#if (BUILD_SINGLE==1)
310+
#ifdef ARCH_ARM64
311+
void (*strmm_direct_LNUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
312+
void (*strmm_direct_LNLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
313+
void (*strmm_direct_LTUN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
314+
void (*strmm_direct_LTLN) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG);
315+
#endif
310316
int (*strmm_kernel_RN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
311317
int (*strmm_kernel_RT)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);
312318
int (*strmm_kernel_LN)(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG, BLASLONG);

common_s.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@
6262
#define SGEMM_ITCOPY sgemm_itcopy
6363
#endif
6464

65+
#define STRMM_DIRECT_LNUN strmm_direct_LNUN
66+
#define STRMM_DIRECT_LNLN strmm_direct_LNLN
67+
#define STRMM_DIRECT_LTUN strmm_direct_LTUN
68+
#define STRMM_DIRECT_LTLN strmm_direct_LTLN
69+
6570
#define STRMM_OUNUCOPY strmm_ounucopy
6671
#define STRMM_OUNNCOPY strmm_ounncopy
6772
#define STRMM_OUTUCOPY strmm_outucopy
@@ -227,6 +232,13 @@
227232
#define SGEMM_INCOPY gotoblas -> sgemm_incopy
228233
#define SGEMM_ITCOPY gotoblas -> sgemm_itcopy
229234

235+
#ifdef ARCH_ARM64
236+
#define STRMM_DIRECT_LNUN gotoblas -> strmm_direct_LNUN
237+
#define STRMM_DIRECT_LNLN gotoblas -> strmm_direct_LNLN
238+
#define STRMM_DIRECT_LTUN gotoblas -> strmm_direct_LTUN
239+
#define STRMM_DIRECT_LTLN gotoblas -> strmm_direct_LTLN
240+
#endif
241+
230242
#define STRMM_OUNUCOPY gotoblas -> strmm_ounucopy
231243
#define STRMM_OUTUCOPY gotoblas -> strmm_outucopy
232244
#define STRMM_OLNUCOPY gotoblas -> strmm_olnucopy

interface/trsm.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,21 @@ void CNAME(enum CBLAS_ORDER order,
255255
#endif
256256

257257
PRINT_DEBUG_CNAME;
258+
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && !defined(HFLOAT16)
259+
#if defined(ARCH_ARM64) && (defined(USE_STRMM_KERNEL_DIRECT)||defined(DYNAMIC_ARCH))
260+
#if defined(DYNAMIC_ARCH)
261+
if (support_sme1())
262+
#endif
263+
if (order == CblasRowMajor && Diag == CblasNonUnit && Side == CblasLeft) {
264+
if (Trans == CblasNoTrans) {
265+
(Uplo == CblasUpper ? STRMM_DIRECT_LNUN : STRMM_DIRECT_LNLN)(m, n, m, alpha, a, lda, b, ldb);
266+
} else if (Trans == CblasTrans) {
267+
(Uplo == CblasUpper ? STRMM_DIRECT_LTUN : STRMM_DIRECT_LTLN)(m, n, m, alpha, a, lda, b, ldb);
268+
}
269+
return;
270+
}
271+
#endif
272+
#endif
258273

259274
args.a = (void *)a;
260275
args.b = (void *)b;

kernel/CMakeLists.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,10 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
237237
if (ZARCH OR (UC_TARGET_CORE MATCHES POWER8) OR (UC_TARGET_CORE MATCHES POWER9) OR (UC_TARGET_CORE MATCHES POWER10))
238238
set(USE_TRMM true)
239239
endif ()
240+
set(USE_DIRECT_STRMM false)
241+
if (ARM64)
242+
set(USE_DIRECT_STRMM true)
243+
endif()
240244
set(USE_DIRECT_SGEMM false)
241245
if (X86_64 OR ARM64)
242246
set(USE_DIRECT_SGEMM true)
@@ -442,6 +446,17 @@ function (build_core TARGET_CORE KDIR TSUFFIX KERNEL_DEFINITIONS)
442446
set(TRMM_KERNEL "${${float_char}GEMMKERNEL}")
443447
endif ()
444448

449+
if (USE_DIRECT_STRMM)
450+
set (STRMMDIRECTKERNEL strmm_direct_arm64_sme1.c)
451+
set (STRMMDIRECTPREKERNEL strmm_direct_arm64_sme1_preprocess.c)
452+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNUN" false "" "" false SINGLE)
453+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LNLN" false "" "" false SINGLE)
454+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTUN" false "" "" false SINGLE)
455+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTKERNEL}" "" "trmm_direct_LTLN" false "" "" false SINGLE)
456+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_UN" false "" "" false SINGLE)
457+
GenerateNamedObjects("${KERNELDIR}/${STRMMDIRECTPREKERNEL}" "" "trmm_direct_sme1_preprocess_LN" false "" "" false SINGLE)
458+
endif ()
459+
445460
if (${float_type} STREQUAL "COMPLEX" OR ${float_type} STREQUAL "ZCOMPLEX")
446461

447462
# just enumerate all these. there is an extra define for these indicating which side is a conjugate (e.g. CN NC NN) that I don't really want to work into GenerateCombinationObjects

kernel/Makefile.L3

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ endif
5252
ifeq ($(ARCH), arm64)
5353
USE_TRMM = 1
5454
USE_DIRECT_SGEMM = 1
55+
USE_DIRECT_STRMM = 1
5556
endif
5657

5758
ifeq ($(ARCH), riscv64)
@@ -137,6 +138,10 @@ endif
137138
endif
138139
endif
139140

141+
ifdef USE_DIRECT_STRMM
142+
STRMMDIRECTKERNEL = strmm_direct_arm64_sme1.c
143+
endif
144+
140145
ifeq ($(BUILD_BFLOAT16), 1)
141146
ifndef BGEMMKERNEL
142147
BGEMM_BETA = ../generic/gemm_beta.c
@@ -286,6 +291,15 @@ SBLASOBJS += \
286291
strsm_kernel_RN$(TSUFFIX).$(SUFFIX) strsm_kernel_RT$(TSUFFIX).$(SUFFIX)
287292
endif
288293

294+
ifdef USE_DIRECT_STRMM
295+
SBLASOBJS += \
296+
strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) \
297+
strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) strmm_direct_LTLN$(TSUFFIX).$(SUFFIX)
298+
SBLASOBJS += \
299+
strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) \
300+
strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX)
301+
endif
302+
289303
ifneq "$(or $(BUILD_DOUBLE),$(BUILD_COMPLEX16))" ""
290304
DBLASOBJS += \
291305
dgemm_beta$(TSUFFIX).$(SUFFIX) \
@@ -1150,6 +1164,27 @@ else
11501164
$(CC) $(CFLAGS) -c -DTRMMKERNEL -UDOUBLE -UCOMPLEX -ULEFT -DTRANSA $< -o $@
11511165
endif
11521166

1167+
1168+
ifdef USE_DIRECT_STRMM
1169+
$(KDIR)strmm_direct_LNUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1170+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -DUPPER $< -o $@
1171+
1172+
$(KDIR)strmm_direct_LNLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1173+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -UTRANSA -UUPPER $< -o $@
1174+
1175+
$(KDIR)strmm_direct_LTUN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1176+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -DUPPER $< -o $@
1177+
1178+
$(KDIR)strmm_direct_LTLN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(STRMMDIRECTKERNEL)
1179+
$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX -DTRANSA -UUPPER $< -o $@
1180+
1181+
$(KDIR)strmm_direct_sme1_preprocess_UN$(TSUFFIX).$(SUFFIX) :
1182+
$(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -DUPPER $< -o $@
1183+
1184+
$(KDIR)strmm_direct_sme1_preprocess_LN$(TSUFFIX).$(SUFFIX) :
1185+
$(CC) $(CFLAGS) -c $(KERNELDIR)/strmm_direct_arm64_sme1_preprocess.c -UDOUBLE -UCOMPLEX -UUPPER $< -o $@
1186+
endif
1187+
11531188
$(KDIR)dtrmm_kernel_LN$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(DTRMMKERNEL)
11541189
ifeq ($(OS), AIX)
11551190
$(CC) $(CFLAGS) -S -DTRMMKERNEL -DDOUBLE -UCOMPLEX -DLEFT -UTRANSA $< -o - > dtrmm_kernel_ln.s
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/*
2+
Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
3+
SPDX-License-Identifier: BSD-3-Clause-Clear
4+
*/
5+
6+
#include "common.h"
7+
#include <stdlib.h>
8+
#include <inttypes.h>
9+
#include <math.h>
10+
#include "sme_abi.h"
11+
#if defined(HAVE_SME)
12+
13+
#if defined(__ARM_FEATURE_SME) && defined(__clang__) && __clang_major__ >= 16
14+
#include <arm_sme.h>
15+
#endif
16+
17+
/* Function prototypes */
18+
// For Upper and NonUnit Triangular preprocess
19+
extern void strmm_direct_sme1_preprocess_UN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);
20+
// For Lower and NonUnit Triangular preprocess
21+
extern void strmm_direct_sme1_preprocess_LN(uint64_t nbr, uint64_t nbc, const float * restrict a, float * a_mod);
22+
23+
/* Function Definitions */
24+
static uint64_t sve_cntw() {
25+
uint64_t cnt;
26+
asm volatile(
27+
"rdsvl %[res], #1\n"
28+
"lsr %[res], %[res], #2\n"
29+
: [res] "=r" (cnt) ::
30+
);
31+
return cnt;
32+
}
33+
34+
#if defined(__ARM_FEATURE_SME) && defined(__ARM_FEATURE_LOCALLY_STREAMING) && defined(__clang__) && __clang_major__ >= 16
35+
// Outer product kernel.
36+
// Computes a 2SVL x 2SVL block of C, utilizing all four FP32 tiles of ZA.
37+
__attribute__((always_inline)) inline void
38+
kernel_2x2(const float *A, const float *B, float *C, size_t shared_dim,
39+
size_t ldc, size_t block_rows, size_t block_cols, float alpha, uint64_t row_idx)
40+
__arm_out("za") __arm_streaming {
41+
const uint64_t svl = svcntw();
42+
size_t ldb = ldc;
43+
// Predicate set-up
44+
svbool_t pg = svptrue_b32();
45+
svbool_t pg_a_0 = svwhilelt_b32_u64(0, block_rows);
46+
svbool_t pg_a_1 = svwhilelt_b32_u64(svl, block_rows);
47+
48+
#if (!defined(TRANSA) && !defined(UPPER)) || (defined(TRANSA) && defined(UPPER))
49+
#define pg_a_0_full pg_a_0
50+
#define pg_a_1_full pg_a_1
51+
#endif
52+
svbool_t pg_b_0 = svwhilelt_b32_u64(0, block_cols);
53+
svbool_t pg_b_1 = svwhilelt_b32_u64(svl, block_cols);
54+
55+
#define pg_c_0 pg_b_0
56+
#define pg_c_1 pg_b_1
57+
58+
svzero_za();
59+
svfloat32_t alpha_vec = svdup_f32(alpha);
60+
// Iterate through shared dimension (K)
61+
#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER))
62+
for (size_t k = row_idx, valid_index = 1; k < shared_dim; k++,valid_index++) {
63+
pg_a_0 = svwhilelt_b32_u64(0, MIN(valid_index, block_rows));
64+
pg_a_1 = svwhilelt_b32_u64(svl, MIN(valid_index, block_rows));
65+
#else
66+
for (size_t k = 0; k < MIN(row_idx + block_rows, shared_dim); k++) {
67+
// If k exceeds row_idx, mask out rows before (k - row_idx)
68+
// This ensures only valid rows are included for lower triangular logic.
69+
if (k > row_idx) {
70+
pg_a_0 = svnot_b_z(pg_a_0_full, svwhilelt_b32_u64(0, k - row_idx));
71+
pg_a_1 = svnot_b_z(pg_a_1_full, svwhilelt_b32_u64(svl, k - row_idx));
72+
}
73+
#endif
74+
75+
#if !defined(TRANSA)
76+
// Load column of A
77+
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * svl]);
78+
svfloat32_t col_a_1 = svld1(pg_a_1, &A[(k + shared_dim) * svl]);
79+
#else
80+
svfloat32_t col_a_0 = svld1(pg_a_0, &A[k * shared_dim]);
81+
svfloat32_t col_a_1 = svld1(pg_a_1, &A[k * shared_dim + svl]);
82+
#endif
83+
col_a_0 = svmul_x(pg_a_0, alpha_vec, col_a_0);
84+
col_a_1 = svmul_x(pg_a_1, alpha_vec, col_a_1);
85+
// Load row of B
86+
svfloat32_t row_b_0 = svld1(pg_b_0, &B[k * ldb]);
87+
svfloat32_t row_b_1 = svld1(pg_b_1, &B[k * ldb + svl]);
88+
// Perform outer product
89+
svmopa_za32_m(/*tile*/0, pg_a_0, pg, col_a_0, row_b_0);
90+
svmopa_za32_m(/*tile*/1, pg_a_0, pg, col_a_0, row_b_1);
91+
svmopa_za32_m(/*tile*/2, pg_a_1, pg, col_a_1, row_b_0);
92+
svmopa_za32_m(/*tile*/3, pg_a_1, pg, col_a_1, row_b_1);
93+
}
94+
95+
// Store to C from ZA
96+
for (size_t i = 0; i < MIN(svl, block_rows); i++) {
97+
svst1_hor_za32(/*tile*/0, /*slice*/i, pg_c_0, &C[i * ldc]);
98+
svst1_hor_za32(/*tile*/1, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
99+
}
100+
for (size_t i = svl; i < block_rows; i++) {
101+
svst1_hor_za32(/*tile*/2, /*slice*/i, pg_c_0, &C[i * ldc]);
102+
svst1_hor_za32(/*tile*/3, /*slice*/i, pg_c_1, &C[i * ldc + svl]);
103+
}
104+
105+
}
106+
107+
__arm_new("za") __arm_locally_streaming
108+
static inline void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
109+
const float *ba, float *restrict bb) {
110+
const uint64_t num_rows = m;
111+
const uint64_t num_cols = n;
112+
113+
const float *restrict a_ptr = ba;
114+
const float *restrict b_ptr = bb;
115+
float *restrict c_ptr = bb;
116+
117+
const uint64_t svl = svcntw();
118+
const uint64_t svl_x2 = 2*svl;
119+
const uint64_t ldc = n;
120+
121+
122+
uint64_t row_idx = 0;
123+
#if (!defined(TRANSA) && defined(UPPER)) || (defined(TRANSA) && !defined(UPPER))
124+
// 2x2 loop
125+
uint64_t row_batch = svl_x2;
126+
// Block over rows of C (panels of A)
127+
for (; row_idx < num_rows; row_idx += row_batch) {
128+
row_batch = MIN(row_batch, num_rows - row_idx);
129+
#else
130+
// Calculate the remainder of num_rows divided by 2VL to determine tail tile size
131+
uint64_t row_batch = num_rows % svl_x2;
132+
// If there's no remainder, use full tile size (2VL) for initial batch
133+
if (row_batch == 0) row_batch = svl_x2;
134+
// Loop from bottom to top, processing rows in batches
135+
for (uint64_t index = num_rows; index > 0; index -= row_batch, row_batch = svl_x2) {
136+
// Compute the starting row index for the current batch
137+
row_idx = index - row_batch;
138+
#endif
139+
uint64_t col_idx = 0;
140+
uint64_t col_batch = svl_x2;
141+
// Block over column dimension of C
142+
for (; col_idx < num_cols; col_idx += col_batch) {
143+
col_batch = MIN(col_batch, num_cols - col_idx);
144+
#if !defined(TRANSA)
145+
kernel_2x2(&a_ptr[row_idx * k], &b_ptr[col_idx],
146+
&c_ptr[row_idx * ldc + col_idx], k,
147+
ldc, row_batch, col_batch, *alpha, row_idx);
148+
#else
149+
kernel_2x2(&a_ptr[row_idx], &b_ptr[col_idx],
150+
&c_ptr[row_idx * ldc + col_idx], k,
151+
ldc, row_batch, col_batch, *alpha, row_idx);
152+
#endif
153+
}
154+
}
155+
156+
return;
157+
}
158+
159+
#else
160+
void strmm_direct_alpha_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n, const float* alpha,\
161+
const float *ba, float *restrict bb){}
162+
#endif
163+
164+
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
165+
BLASLONG strideA, float * __restrict B, BLASLONG strideB){
166+
#if !defined(TRANSA)
167+
uint64_t m_mod, vl_elms;
168+
169+
vl_elms = sve_cntw();
170+
171+
m_mod = ceil((double)M/(double)vl_elms) * vl_elms;
172+
173+
float *A_mod = (float *) malloc(m_mod*K*sizeof(float));
174+
#if defined(UPPER)
175+
strmm_direct_sme1_preprocess_UN(M, K, A, A_mod);
176+
#else
177+
strmm_direct_sme1_preprocess_LN(M, K, A, A_mod);
178+
#endif
179+
/* Calculate B = alpha*A*B*/
180+
strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A_mod, B);
181+
free(A_mod);
182+
#else
183+
strmm_direct_alpha_sme1_2VLx2VL(M, K, N, &alpha, A, B);
184+
#endif
185+
}
186+
187+
#else
188+
void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float alpha, float * __restrict A,\
189+
BLASLONG strideA, float * __restrict B, BLASLONG strideB){}
190+
191+
#endif

0 commit comments

Comments
 (0)