diff --git a/interface/CMakeLists.txt b/interface/CMakeLists.txt index ef9f81be82..25b127d9b2 100644 --- a/interface/CMakeLists.txt +++ b/interface/CMakeLists.txt @@ -124,10 +124,8 @@ foreach (CBLAS_FLAG ${CBLAS_FLAGS}) #sdsdot, dsdot if (BUILD_SINGLE OR BUILD_DOUBLE) GenerateNamedObjects("sdsdot.c" "" "sdsdot" ${CBLAS_FLAG} "" "" true "SINGLE") - if(CBLAS_FLAG EQUAL 1) - GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false) -endif () -endif () + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" ${CBLAS_FLAG} "" "" false) + endif () if (BUILD_DOUBLE) GenerateNamedObjects("dsdot.c" "" "dsdot" ${CBLAS_FLAG} "" "" true "SINGLE") endif () @@ -162,10 +160,8 @@ if (BUILD_BFLOAT16) GenerateNamedObjects("tobf16.c" "DOUBLE_PREC" "sbdtobf16" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "SINGLE_PREC" "sbf16tos" ${CBLAS_FLAG} "" "" true "BFLOAT16") GenerateNamedObjects("bf16to.c" "DOUBLE_PREC" "dbf16tod" ${CBLAS_FLAG} "" "" true "BFLOAT16") - if(CBLAS_FLAG EQUAL 1) GenerateNamedObjects("gemm_batch.c" "" "sbgemm_batch" ${CBLAS_FLAG} "" "" true "BFLOAT16") endif () -endif () if (BUILD_HFLOAT16) GenerateNamedObjects("gemm.c" "" "shgemm" ${CBLAS_FLAG} "" "" true "HFLOAT16") endif () @@ -197,9 +193,7 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("max.c" "USE_ABS" "scamax" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("asum.c" "" "scasum" ${CBLAS_FLAG} "" "" true "COMPLEX") GenerateNamedObjects("sum.c" "" "scsum" ${CBLAS_FLAG} "" "" true "COMPLEX") - if(CBLAS_FLAG EQUAL 1) - GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX") - endif () + GenerateNamedObjects("gemm_batch.c" "" "cgemm_batch" ${CBLAS_FLAG} "" "" true "COMPLEX") endif () if (${float_type} STREQUAL "ZCOMPLEX") GenerateNamedObjects("zscal.c" "SSCAL" "dscal" ${CBLAS_FLAG} "" "" false "ZCOMPLEX") @@ -209,9 +203,7 @@ foreach (float_type ${FLOAT_TYPES}) GenerateNamedObjects("max.c" "USE_ABS" "dzamax" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("asum.c" "" "dzasum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") GenerateNamedObjects("sum.c" "" "dzsum" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") - if(CBLAS_FLAG EQUAL 1) - GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") - endif () + GenerateNamedObjects("gemm_batch.c" "" "zgemm_batch" ${CBLAS_FLAG} "" "" true "ZCOMPLEX") endif () endforeach () @@ -262,7 +254,7 @@ if ( BUILD_COMPLEX AND NOT BUILD_SINGLE) GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "SINGLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "SINGLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "SINGLE") - GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "SINGLE") + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "SINGLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "SINGLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "SINGLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "SINGLE") @@ -276,7 +268,7 @@ if ( BUILD_COMPLEX16 AND NOT BUILD_DOUBLE) GenerateNamedObjects("nrm2.c" "" "nrm2" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemv.c" "" "gemv" 0 "" "" false "DOUBLE") GenerateNamedObjects("gemm.c" "" "gemm" 0 "" "" false "DOUBLE") - GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 1 "" "" false "DOUBLE") + GenerateNamedObjects("gemm_batch.c" "" "gemm_batch" 0 "" "" false "DOUBLE") GenerateNamedObjects("asum.c" "" "asum" 0 "" "" false "DOUBLE") GenerateNamedObjects("swap.c" "" "swap" 0 "" "" false "DOUBLE") GenerateNamedObjects("axpy.c" "" "axpy" 0 "" "" false "DOUBLE") diff --git a/interface/Makefile b/interface/Makefile index e9b9d2e631..9212abe694 100644 --- a/interface/Makefile +++ b/interface/Makefile @@ -72,7 +72,8 @@ SBLAS3OBJS = \ sgemm.$(SUFFIX) ssymm.$(SUFFIX) strmm.$(SUFFIX) \ strsm.$(SUFFIX) ssyrk.$(SUFFIX) ssyr2k.$(SUFFIX) \ somatcopy.$(SUFFIX) simatcopy.$(SUFFIX)\ - sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) + sgeadd.$(SUFFIX) sgemmt.$(SUFFIX) sgemmtr.$(SUFFIX) \ + sgemm_batch.$(SUFFIX) ifeq ($(BUILD_BFLOAT16),1) BBLAS3OBJS = bgemm.$(SUFFIX) @@ -80,7 +81,7 @@ BBLAS2OBJS = bgemv.$(SUFFIX) BBLAS1OBJS = bscal.$(SUFFIX) SBBLAS1OBJS = sbdot.$(SUFFIX) SBBLAS2OBJS = sbgemv.$(SUFFIX) -SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) +SBBLAS3OBJS = sbgemm.$(SUFFIX) sbgemmt.$(SUFFIX) sbgemmtr.$(SUFFIX) sbgemm_batch.$(SUFFIX) SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX) endif @@ -111,7 +112,8 @@ DBLAS3OBJS = \ dgemm.$(SUFFIX) dsymm.$(SUFFIX) dtrmm.$(SUFFIX) \ dtrsm.$(SUFFIX) dsyrk.$(SUFFIX) dsyr2k.$(SUFFIX) \ domatcopy.$(SUFFIX) dimatcopy.$(SUFFIX)\ - dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX) + dgeadd.$(SUFFIX) dgemmt.$(SUFFIX) dgemmtr.$(SUFFIX) \ + dgemm_batch.$(SUFFIX) CBLAS1OBJS = \ caxpy.$(SUFFIX) caxpyc.$(SUFFIX) cswap.$(SUFFIX) \ @@ -140,7 +142,8 @@ CBLAS3OBJS = \ ctrsm.$(SUFFIX) csyrk.$(SUFFIX) csyr2k.$(SUFFIX) \ chemm.$(SUFFIX) cherk.$(SUFFIX) cher2k.$(SUFFIX) \ comatcopy.$(SUFFIX) cimatcopy.$(SUFFIX)\ - cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX) + cgeadd.$(SUFFIX) cgemmt.$(SUFFIX) cgemmtr.$(SUFFIX) \ + cgemm_batch.$(SUFFIX) ZBLAS1OBJS = \ zaxpy.$(SUFFIX) zaxpyc.$(SUFFIX) zswap.$(SUFFIX) \ @@ -169,7 +172,8 @@ ZBLAS3OBJS = \ ztrsm.$(SUFFIX) zsyrk.$(SUFFIX) zsyr2k.$(SUFFIX) \ zhemm.$(SUFFIX) zherk.$(SUFFIX) zher2k.$(SUFFIX) \ zomatcopy.$(SUFFIX) zimatcopy.$(SUFFIX)\ - zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX) + zgeadd.$(SUFFIX) zgemmt.$(SUFFIX) zgemmtr.$(SUFFIX) \ + zgemm_batch.$(SUFFIX) ifeq ($(SUPPORT_GEMM3M), 1) @@ -2539,3 +2543,19 @@ cblas_cgemm_batch.$(SUFFIX) cblas_cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param cblas_zgemm_batch.$(SUFFIX) cblas_zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h $(CC) -c $(CFLAGS) -DCBLAS $< -o $(@F) + +sbgemm_batch.$(SUFFIX) sbgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +sgemm_batch.$(SUFFIX) sgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +dgemm_batch.$(SUFFIX) dgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +cgemm_batch.$(SUFFIX) cgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + +zgemm_batch.$(SUFFIX) zgemm_batch.$(PSUFFIX) : gemm_batch.c ../param.h + $(CC) -c $(CFLAGS) -UCBLAS $< -o $(@F) + diff --git a/interface/gemm_batch.c b/interface/gemm_batch.c index 56ccc12ce4..b0d5b29cc2 100644 --- a/interface/gemm_batch.c +++ b/interface/gemm_batch.c @@ -114,6 +114,17 @@ static size_t zgemm_small_kernel_b0[] = { #endif #endif +#ifndef CBLAS +void NAME(char *transa_array, char *transb_array, + blasint * m_array, blasint * n_array, blasint * k_array, + FLOAT * alpha_array, + IFLOAT ** a_array, blasint * lda_array, + IFLOAT ** b_array, blasint * ldb_array, + FLOAT * beta_array, + FLOAT ** c_array, blasint * ldc_array, blasint group_count, blasint * group_size) { + +#else + void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CBLAS_TRANSPOSE * transb_array, blasint * m_array, blasint * n_array, blasint * k_array, #ifndef COMPLEX @@ -134,8 +145,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB FLOAT ** a_array=(FLOAT**)va_array; FLOAT ** b_array=(FLOAT**)vb_array; FLOAT ** c_array=(FLOAT**)vc_array; - #endif +#endif + BLASLONG group_m, group_n, group_k; + BLASLONG group_lda, group_ldb, group_ldc; + blas_arg_t * args_array=NULL; int mode=0, group_mode=0; @@ -148,8 +162,6 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB blasint info; void * group_alpha, * group_beta; - BLASLONG group_m, group_n, group_k; - BLASLONG group_lda, group_ldb, group_ldc; void * group_routine=NULL; #ifdef SMALL_MATRIX_OPT void * group_small_matrix_opt_routine=NULL; @@ -201,7 +213,8 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB group_transa = -1; group_transb = -1; info = 0; - + +#if defined(CBLAS) if (order == CblasColMajor) { group_m = m_array[i]; group_n = n_array[i]; @@ -254,7 +267,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB group_lda = ldb_array[i]; group_ldb = lda_array[i]; group_ldc = ldc_array[i]; - + if (transb_array[i] == CblasNoTrans) group_transa = 0; if (transb_array[i] == CblasTrans) group_transa = 1; #ifndef COMPLEX @@ -273,6 +286,32 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB if (transa_array[i] == CblasConjNoTrans) group_transb = 2; if (transa_array[i] == CblasConjTrans) group_transb = 3; #endif + +#else + group_m = m_array[i]; + group_n = n_array[i]; + group_k = k_array[i]; + + group_lda = lda_array[i]; + group_ldb = ldb_array[i]; + group_ldc = ldc_array[i]; + + if (transb_array[i] == 'N') group_transa = 0; + if (transb_array[i] == 'T') group_transa = 1; +#ifndef COMPLEX + if (transb_array[i] == 'C') group_transa = 1; +#else + if (transb_array[i] == 'C') group_transa = 3; +#endif + if (transa_array[i] == 'N') group_transb = 0; + if (transa_array[i] == 'T') group_transb = 1; +#ifndef COMPLEX + if (transa_array[i] == 'C') group_transb = 1; +#else + if (transa_array[i] == 'C') group_transb = 3; +#endif +#endif + group_nrowa = group_m; if (group_transa & 1) group_nrowa = group_k; group_nrowb = group_k; @@ -288,7 +327,9 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB if (group_m < 0) info = 3; if (group_transb < 0) info = 2; if (group_transa < 0) info = 1; +#if defined(CBLAS) } +#endif if (info >= 0) { BLASFUNC(xerbla)(ERROR_NAME, &info, sizeof(ERROR_NAME)); @@ -344,13 +385,17 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE * transa_array, enum CB args_array[count].alpha=group_alpha; args_array[count].beta=group_beta; +#if defined(CBLAS) if (order == CblasColMajor) { args_array[count].a=(a_array[matrix_idx+j]); args_array[count].b=(b_array[matrix_idx+j]); }else if(order == CblasRowMajor){ +#endif args_array[count].a=(b_array[matrix_idx+j]); args_array[count].b=(a_array[matrix_idx+j]); +#if defined(CBLAS) } +#endif args_array[count].c=(c_array[matrix_idx+j]);