diff --git a/sw/include/kultest/gemmx/data.h b/sw/include/kultest/gemmx/data.h new file mode 100644 index 0000000..4e3f3e4 --- /dev/null +++ b/sw/include/kultest/gemmx/data.h @@ -0,0 +1,554 @@ +#include + +static int broadcast_C = 1; + +static int channel_en_C = 255; + +static int Nbatch = 1; + +static int H = 1; + +static int W = 8; + +static int Cin = 8; + +static int Cout = 8; + +static int Kh = 1; + +static int Kw = 1; + +static int stride_h = 1; + +static int stride_w = 1; + +static int pad_h = 0; + +static int pad_w = 0; + +static int Batch = 1; + +static int M = 1; + +static int K = 1; + +static int N = 1; + +static int set_addr_remap_index_A = 2; + +static int set_addr_remap_index_B = 2; + +static int set_addr_remap_index_C = 2; + +static int set_addr_remap_index_D32 = 2; + +static int set_addr_remap_index_D8 = 2; + +static int interleaved_address = 0; + +static int delta_physical_a = 0; + +static int delta_physical_b = 64; + +static int delta_physical_d8 = 192; + +static int delta_physical_c = 128; + +static int delta_physical_d32 = 192; + +static int delta_local_a = 0; + +static int delta_local_b = 32768.0; + +static int delta_local_d8 = 98304.0; + +static int delta_local_c = 65536.0; + +static int delta_local_d32 = 98304.0; + +static int Aslstride0 = 1; + +static int Aslstride1 = 8; + +static int Atlbound0 = 1; + +static int Atlstride0 = 8; + +static int Atlbound1 = 1; + +static int Atlstride1 = 64; + +static int Atlbound2 = 1; + +static int Atlstride2 = 64; + +static int Atlbound3 = 1; + +static int Atlstride3 = 0; + +static int Atlbound4 = 1; + +static int Atlstride4 = 64; + +static int Atlbound5 = 1; + +static int Atlstride5 = 64; + +static int Atlbound6 = 1; + +static int Atlstride6 = 64; + +static int Bslstride0 = 1; + +static int Bslstride1 = 8; + +static int Btlbound0 = 1; + +static int Btlstride0 = 64; + +static int Btlbound1 = 1; + +static int Btlstride1 = 64; + +static int Btlbound2 = 1; + +static int Btlstride2 = 0; + +static int Btlbound3 = 1; + +static int Btlstride3 = 0; + +static int Cslstride0 = 8; + +static int Cslstride1 = 64; + +static int Ctlbound0 = 1; + +static int Ctlstride0 = 256; + +static int Ctlbound1 = 1; + +static int Ctlstride1 = 256; + +static int Ctlbound2 = 1; + +static int Ctlstride2 = 256; + +static int Ctlbound3 = 1; + +static int Ctlstride3 = 256; + +static int D32slstride0 = 8; + +static int D32slstride1 = 64; + +static int D32tlbound0 = 1; + +static int D32tlstride0 = 256; + +static int D32tlbound1 = 1; + +static int D32tlstride1 = 256; + +static int D32tlbound2 = 1; + +static int D32tlstride2 = 256; + +static int D32tlbound3 = 1; + +static int D32tlstride3 = 256; + +static int D8slstride0 = 1; + +static int D8slstride1 = 8; + +static int D8tlbound0 = 1; + +static int D8tlstride0 = 64; + +static int D8tlbound1 = 1; + +static int D8tlstride1 = 64; + +static int D8tlbound2 = 1; + +static int D8tlstride2 = 64; + +static int D8tlbound3 = 1; + +static int D8tlstride3 = 64; + +static int8_t subtraction_a = 0; + +static int8_t subtraction_b = 0; + +static int8_t A[64] = { + -4, + 9, + 4, + 0, + -3, + -4, + 8, + 0, + 0, + -7, + -3, + -8, + -9, + 1, + -5, + -9, + -10, + 1, + 1, + 6, + -1, + 5, + 4, + 4, + 8, + 1, + 9, + -8, + -6, + 8, + -4, + -2, + -4, + 7, + -7, + 3, + 7, + -2, + -9, + 9, + 4, + -4, + 1, + -3, + 4, + -8, + 3, + 6, + -7, + 7, + -3, + -7, + -9, + -5, + -1, + -7, + 7, + 1, + -9, + -1, + -7, + 3, + 5, + 4, +}; + +static int8_t B[64] = { + -3, + 3, + -3, + 5, + 2, + 7, + 4, + 2, + -2, + 4, + 2, + -10, + -4, + -2, + -10, + 1, + -3, + 0, + 8, + 6, + -3, + -8, + -8, + -10, + -6, + -1, + -4, + -2, + -4, + -2, + -3, + 1, + -9, + -10, + 5, + -6, + -8, + 1, + -3, + -8, + -10, + -8, + -6, + 4, + 3, + -8, + -10, + -6, + 3, + -4, + -2, + 4, + 4, + -1, + 2, + 8, + -4, + 6, + 9, + -7, + -6, + -4, + 2, + 4, +}; + +static int C[64] = { + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, + 23841962, + -201600484, + -91061213, + 828299788, + -743115617, + -649117882, + 170731194, + -237146894, +}; + +static int transposed_A = 0; + +static int transposed_B = 0; + +static int D32[64] = { + 23841987, + -201600492, + -91061192, + 828299783, + -743115655, + -649117995, + 170731146, + -237146738, + 23841861, + -201600363, + -91061136, + 828299863, + -743115354, + -649117771, + 170731077, + -237146903, + 23842079, + -201600560, + -91061248, + 828299817, + -743115599, + -649117879, + 170731213, + -237146871, + 23841898, + -201600352, + -91061207, + 828299737, + -743115522, + -649118086, + 170731108, + -237146795, + 23842013, + -201600417, + -91061262, + 828299839, + -743115807, + -649117771, + 170731264, + -237146936, + 23841896, + -201600500, + -91061267, + 828299767, + -743115687, + -649117898, + 170731286, + -237146866, + 23841907, + -201600329, + -91061113, + 828299891, + -743115471, + -649117813, + 170731034, + -237146758, + 23842001, + -201600526, + -91061395, + 828299794, + -743115717, + -649118029, + 170731236, + -237146934, +}; + +static int bypassSIMD = 0; + +static int8_t input_zp_i = -43; + +static int8_t output_zp_i = -101; + +static int8_t max_int_i = 127; + +static int8_t min_int_i = -128; + +static int8_t double_round_i = 0; + +static int shared_bitpacked_shift0 = 1026304257; + +static int shared_bitpacked_shift1 = 453325880; + +static int shared_multiplier0 = 1304261659; + +static int shared_multiplier1 = -1209289109; + +static int shared_multiplier2 = -1346171349; + +static int shared_multiplier3 = -358587053; + +static int shared_multiplier4 = 1686028061; + +static int shared_multiplier5 = 1646176189; + +static int shared_multiplier6 = 168973642; + +static int shared_multiplier7 = -754432385; + +static int8_t D8[64] = { + 128, + 127, + 127, + 154, + 137, + 128, + 127, + 128, + 128, + 127, + 127, + 154, + 137, + 128, + 127, + 128, + 128, + 127, + 127, + 154, + 137, + 128, + 127, + 128, + 128, + 127, + 127, + 154, + 137, + 128, + 127, + 128, + 128, + 127, + 127, + 154, + 137, + 128, + 128, + 128, + 127, + 127, + 127, + 154, + 137, + 128, + 128, + 128, + 128, + 127, + 127, + 154, + 137, + 128, + 128, + 128, + 127, + 127, + 127, + 154, + 137, + 128, + 127, + 128, +}; diff --git a/sw/include/kultest/gemmx/snax-gemmx-lib.h b/sw/include/kultest/gemmx/snax-gemmx-lib.h new file mode 100644 index 0000000..acde776 --- /dev/null +++ b/sw/include/kultest/gemmx/snax-gemmx-lib.h @@ -0,0 +1,115 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#include +#include "../snrt/snrt_TO.h" +#include "stdint.h" +#include "streamer_csr_addr_map.h" + +#pragma once + +#define GEMMX_CSR_ADDR_BASE (STREAMER_PERFORMANCE_COUNTER_CSR + 1) +#define T_BOUND_K (GEMMX_CSR_ADDR_BASE) +#define T_BOUND_N (T_BOUND_K + 1) +#define T_BOUND_M (T_BOUND_N + 1) + +#define SUBTRACTIONS (T_BOUND_M + 1) + +#define SIMD_CSR0 (SUBTRACTIONS + 1) +#define SIMD_CSR1 (SIMD_CSR0 + 1) + +#define SIMD_SHARED_BITPACKED_SHIFT0 (SIMD_CSR1 + 1) +#define SIMD_SHARED_BITPACKED_SHIFT1 (SIMD_SHARED_BITPACKED_SHIFT0 + 1) + +#define SIMD_SHARED_MULTIPLIER0 (SIMD_SHARED_BITPACKED_SHIFT1 + 1) +#define SIMD_SHARED_MULTIPLIER1 (SIMD_SHARED_MULTIPLIER0 + 1) +#define SIMD_SHARED_MULTIPLIER2 (SIMD_SHARED_MULTIPLIER1 + 1) +#define SIMD_SHARED_MULTIPLIER3 (SIMD_SHARED_MULTIPLIER2 + 1) +#define SIMD_SHARED_MULTIPLIER4 (SIMD_SHARED_MULTIPLIER3 + 1) +#define SIMD_SHARED_MULTIPLIER5 (SIMD_SHARED_MULTIPLIER4 + 1) +#define SIMD_SHARED_MULTIPLIER6 (SIMD_SHARED_MULTIPLIER5 + 1) +#define SIMD_SHARED_MULTIPLIER7 (SIMD_SHARED_MULTIPLIER6 + 1) + +#define TEMPORAL_LOOP_BOUND (SIMD_SHARED_MULTIPLIER7 + 1) +#define BYPASS_SIMD (TEMPORAL_LOOP_BOUND + 1) + +#define GEMMX_START (BYPASS_SIMD + 1) +#define GEMMX_BUSY (GEMMX_START + 1) +#define GEMMX_PERFORMANCE_COUNTER (GEMMX_BUSY + 1) + +// Pack matrix size setting to one CSR +int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N); + +// Pack two subtraction values to one CSR +int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b); + +// generate the configuration for CSR0 +int32_t gen_csr0_config(uint8_t input_zp_i, uint8_t output_zp_i, + uint8_t max_int_i, uint8_t min_int_i); + +// generate the configuration for CSR1 +int32_t gen_csr1_config(bool double_round_i); + +// Set STREAMER configuration CSR +void set_gemmx_streamer_csr( + int Aslstride0, int Aslstride1, int Atlbound0, int Atlstride0, + int Atlbound1, int Atlstride1, int Atlbound2, int Atlstride2, int Atlbound3, + int Atlstride3, int Atlbound4, int Atlstride4, int Atlbound5, + int Atlstride5, int set_addr_remap_index_A, + + int Bslstride0, int Bslstride1, int Btlbound0, int Btlstride0, + int Btlbound1, int Btlstride1, int Btlbound2, int Btlstride2, + int set_addr_remap_index_B, + + int D8slstride0, int D8slstride1, int D8tlbound0, int D8tlstride0, + int D8tlbound1, int D8tlstride1, int D8tlbound2, int D8tlstride2, + int set_addr_remap_index_D8, + + int Cslstride0, int Cslstride1, int Ctlbound0, int Ctlstride0, + int Ctlbound1, int Ctlstride1, int Ctlbound2, int Ctlstride2, + int set_addr_remap_index_C, + + int D32slstride0, int D32slstride1, int D32tlbound0, int D32tlstride0, + int D32tlbound1, int D32tlstride1, int D32tlbound2, int D32tlstride2, + int set_addr_remap_index_D32, + + int delta_local_a, int delta_local_b, int delta_local_d8, int delta_local_c, + int delta_local_d32, int bypassSIMD, int32_t transpose_A, + int32_t transpose_B, int32_t channel_en_C, int32_t broadcast_C); + +// Set CSR to start STREAMER +inline void set_gemmx_streamer_start() { write_csr(STREAMER_START_CSR, 1); } + +// Set GEMM configuration CSR +void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, + int subtractions, uint32_t csr0, uint32_t csr1, + int shared_bitpacked_shift0, int shared_bitpacked_shift1, + int shared_multiplier0, int shared_multiplier1, + int shared_multiplier2, int shared_multiplier3, + int shared_multiplier4, int shared_multiplier5, + int shared_multiplier6, int shared_multiplier7, + uint32_t temporal_loop_bound, uint32_t bypassSIMD); + +// Set CSR to start GEMM +inline void set_gemmx_start() { write_csr(GEMMX_START, 1); } + +// Poll until Streamer and GEMM accelerator finish +void wait_gemmx_and_streamer(); + +// Read performance counter of the Streamer, a read-only CSR +uint32_t read_gemmx_streamer_perf_counter(); + +// Read performance counter of GEMM, a read-only CSR +uint32_t read_gemmx_perf_counter(); + +// Check the result of the implicit im2col convolution +uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden, + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout); + +uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden, + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout); diff --git a/sw/include/kultest/gemmx/snax-gemmx-params.h b/sw/include/kultest/gemmx/snax-gemmx-params.h new file mode 100644 index 0000000..cea91a4 --- /dev/null +++ b/sw/include/kultest/gemmx/snax-gemmx-params.h @@ -0,0 +1,11 @@ +// Copyright 2023 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#pragma once + +#define meshRow 8 +#define tileSize 8 +#define meshCol 8 diff --git a/sw/include/kultest/gemmx/streamer_csr_addr_map.h b/sw/include/kultest/gemmx/streamer_csr_addr_map.h new file mode 100644 index 0000000..0070163 --- /dev/null +++ b/sw/include/kultest/gemmx/streamer_csr_addr_map.h @@ -0,0 +1,83 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi +// This file is generated by Streamer module in hw/chisel to map the CSR address of Streamer automatically, do not modify it manually +// Generated at 2024-12-05T20:22:29.903791Z + +// CSR Map for READER_0 +#define BASE_PTR_READER_0_LOW 960 +#define BASE_PTR_READER_0_HIGH 961 +#define S_STRIDE_READER_0_0 962 +#define T_BOUND_READER_0_0 963 +#define T_BOUND_READER_0_1 964 +#define T_BOUND_READER_0_2 965 +#define T_BOUND_READER_0_3 966 +#define T_BOUND_READER_0_4 967 +#define T_BOUND_READER_0_5 968 +#define T_STRIDE_READER_0_0 969 +#define T_STRIDE_READER_0_1 970 +#define T_STRIDE_READER_0_2 971 +#define T_STRIDE_READER_0_3 972 +#define T_STRIDE_READER_0_4 973 +#define T_STRIDE_READER_0_5 974 +#define ADDR_REMAP_INDEX_READER_0 975 +// CSR Map for READER_1 +#define BASE_PTR_READER_1_LOW 976 +#define BASE_PTR_READER_1_HIGH 977 +#define S_STRIDE_READER_1_0 978 +#define T_BOUND_READER_1_0 979 +#define T_BOUND_READER_1_1 980 +#define T_BOUND_READER_1_2 981 +#define T_STRIDE_READER_1_0 982 +#define T_STRIDE_READER_1_1 983 +#define T_STRIDE_READER_1_2 984 +#define ADDR_REMAP_INDEX_READER_1 985 +// CSR Map for WRITER_0 +#define BASE_PTR_WRITER_0_LOW 986 +#define BASE_PTR_WRITER_0_HIGH 987 +#define S_STRIDE_WRITER_0_0 988 +#define T_BOUND_WRITER_0_0 989 +#define T_BOUND_WRITER_0_1 990 +#define T_BOUND_WRITER_0_2 991 +#define T_STRIDE_WRITER_0_0 992 +#define T_STRIDE_WRITER_0_1 993 +#define T_STRIDE_WRITER_0_2 994 +#define ADDR_REMAP_INDEX_WRITER_0 995 +// CSR Map for READER_WRITER_0 +#define BASE_PTR_READER_WRITER_0_LOW 996 +#define BASE_PTR_READER_WRITER_0_HIGH 997 +#define S_STRIDE_READER_WRITER_0_0 998 +#define S_STRIDE_READER_WRITER_0_1 999 +#define T_BOUND_READER_WRITER_0_0 1000 +#define T_BOUND_READER_WRITER_0_1 1001 +#define T_BOUND_READER_WRITER_0_2 1002 +#define T_STRIDE_READER_WRITER_0_0 1003 +#define T_STRIDE_READER_WRITER_0_1 1004 +#define T_STRIDE_READER_WRITER_0_2 1005 +#define ADDR_REMAP_INDEX_READER_WRITER_0 1006 +#define ENABLED_CHANNEL_READER_WRITER_0 1007 +// CSR Map for READER_WRITER_1 +#define BASE_PTR_READER_WRITER_1_LOW 1008 +#define BASE_PTR_READER_WRITER_1_HIGH 1009 +#define S_STRIDE_READER_WRITER_1_0 1010 +#define S_STRIDE_READER_WRITER_1_1 1011 +#define T_BOUND_READER_WRITER_1_0 1012 +#define T_BOUND_READER_WRITER_1_1 1013 +#define T_BOUND_READER_WRITER_1_2 1014 +#define T_STRIDE_READER_WRITER_1_0 1015 +#define T_STRIDE_READER_WRITER_1_1 1016 +#define T_STRIDE_READER_WRITER_1_2 1017 +#define ADDR_REMAP_INDEX_READER_WRITER_1 1018 +#define TRANSPOSE_EXTENSION_ENABLE +#define TRANSPOSE_CSR_READER_0 1019 +#define TRANSPOSE_CSR_READER_1 1020 +#define C_BROADCAST_EXTENSION_ENABLE +#define C_BROADCAST_CSR_READER_WRITER_0 1021 +// Other resgiters +// Status register +#define STREAMER_START_CSR 1022 +// Read only CSRs +#define STREAMER_BUSY_CSR 1023 +#define STREAMER_PERFORMANCE_COUNTER_CSR 1024 diff --git a/sw/include/kultest/snax-kul-cluster-gemmx-test.h b/sw/include/kultest/snax-kul-cluster-gemmx-test.h new file mode 100644 index 0000000..ba36b71 --- /dev/null +++ b/sw/include/kultest/snax-kul-cluster-gemmx-test.h @@ -0,0 +1,21 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#pragma once + +#include "snrt/snrt_TO.h" +// #include "snrt/csr.h" + +#include "gemmx/data.h" + +#include "gemmx/snax-gemmx-params.h" +#include "gemmx/snax-gemmx-lib.h" +#include "gemmx/streamer_csr_addr_map.h" + +// This is the test function for the SNAX GEMM for Conv2d +// We use several nested loops to iterate over the input data and weights, +// achieving implicit im2col +int kul_cluster_gemmx_test(); diff --git a/sw/include/kultest/snax-kul-cluster-xdma-test.h b/sw/include/kultest/snax-kul-cluster-xdma-test.h new file mode 100644 index 0000000..0037134 --- /dev/null +++ b/sw/include/kultest/snax-kul-cluster-xdma-test.h @@ -0,0 +1,22 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#pragma once + +#include "snrt/snrt_TO.h" +#include "snrt/csr.h" + +#include "xdma/data.h" + +#include "xdma/snax-xdma-csr-addr.h" +#include "xdma/snax-xdma-lib.h" +// #include "xdma/streamer_csr_addr_map.h" + + +// This is the test function for the SNAX GEMM for Conv2d +// We use several nested loops to iterate over the input data and weights, +// achieving implicit im2col +int kul_cluster_xdma_test(); diff --git a/sw/include/kultest/snrt/csr.h b/sw/include/kultest/snrt/csr.h new file mode 100644 index 0000000..790fb11 --- /dev/null +++ b/sw/include/kultest/snrt/csr.h @@ -0,0 +1,565 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Yunhao Deng + +// This file provides the function to read and write CSR with CSR address in +// register As CSR instruction in RISC-V is immediate-number addressed, this +// workaround function deploys a switch-case to map the CSR address to implement +// pseudo register-mapping mechanism. To avoid the loss of performance, the +// function is defined in header, so that it can be compiled together with the +// main program, and optimized by the compiler. If @csr_address is provided in +// an immediate number, (macros, constant etc.) the compiler won't add it in a +// separate function creating loss in switching cycles. + +#ifndef CSR_H +#define CSR_H +#define CSR_LONG_ADDR_MODE +// Uncomment the above line to enable 64 CSRs addressability, with the down side +// of larger binary size. + +static void write_csr_obs(uint32_t value) { + write_csr(1989, value); + return; +} + +static uint32_t read_csr_obs(void) { return read_csr(1989); } + +static uint32_t csrr_ss(uint32_t csr_address) { + uint32_t value; + switch (csr_address) { + case 960: + return read_csr(960); + case 961: + return read_csr(961); + case 962: + return read_csr(962); + case 963: + return read_csr(963); + case 964: + return read_csr(964); + case 965: + return read_csr(965); + case 966: + return read_csr(966); + case 967: + return read_csr(967); + case 968: + return read_csr(968); + case 969: + return read_csr(969); + case 970: + return read_csr(970); + case 971: + return read_csr(971); + case 972: + return read_csr(972); + case 973: + return read_csr(973); + case 974: + return read_csr(974); + case 975: + return read_csr(975); + case 976: + return read_csr(976); + case 977: + return read_csr(977); + case 978: + return read_csr(978); + case 979: + return read_csr(979); + case 980: + return read_csr(980); + case 981: + return read_csr(981); + case 982: + return read_csr(982); + case 983: + return read_csr(983); + case 984: + return read_csr(984); + case 985: + return read_csr(985); + case 986: + return read_csr(986); + case 987: + return read_csr(987); + case 988: + return read_csr(988); + case 989: + return read_csr(989); + case 990: + return read_csr(990); + case 991: + return read_csr(991); +#ifdef CSR_LONG_ADDR_MODE + case 992: + return read_csr(992); + case 993: + return read_csr(993); + case 994: + return read_csr(994); + case 995: + return read_csr(995); + case 996: + return read_csr(996); + case 997: + return read_csr(997); + case 998: + return read_csr(998); + case 999: + return read_csr(999); + case 1000: + return read_csr(1000); + case 1001: + return read_csr(1001); + case 1002: + return read_csr(1002); + case 1003: + return read_csr(1003); + case 1004: + return read_csr(1004); + case 1005: + return read_csr(1005); + case 1006: + return read_csr(1006); + case 1007: + return read_csr(1007); + case 1008: + return read_csr(1008); + case 1009: + return read_csr(1009); + case 1010: + return read_csr(1010); + case 1011: + return read_csr(1011); + case 1012: + return read_csr(1012); + case 1013: + return read_csr(1013); + case 1014: + return read_csr(1014); + case 1015: + return read_csr(1015); + case 1016: + return read_csr(1016); + case 1017: + return read_csr(1017); + case 1018: + return read_csr(1018); + case 1019: + return read_csr(1019); + case 1020: + return read_csr(1020); + case 1021: + return read_csr(1021); + case 1022: + return read_csr(1022); + case 1023: + return read_csr(1023); + case 1024: + return read_csr(1024); + case 1025: + return read_csr(1025); + case 1026: + return read_csr(1026); + case 1027: + return read_csr(1027); + case 1028: + return read_csr(1028); + case 1029: + return read_csr(1029); + case 1030: + return read_csr(1030); + case 1031: + return read_csr(1031); + case 1032: + return read_csr(1032); + case 1033: + return read_csr(1033); + case 1034: + return read_csr(1034); + case 1035: + return read_csr(1035); + case 1036: + return read_csr(1036); + case 1037: + return read_csr(1037); + case 1038: + return read_csr(1038); + case 1039: + return read_csr(1039); + case 1040: + return read_csr(1040); + case 1041: + return read_csr(1041); + case 1042: + return read_csr(1042); + case 1043: + return read_csr(1043); + case 1044: + return read_csr(1044); + case 1045: + return read_csr(1045); + case 1046: + return read_csr(1046); + case 1047: + return read_csr(1047); + case 1048: + return read_csr(1048); + case 1049: + return read_csr(1049); + case 1050: + return read_csr(1050); + case 1051: + return read_csr(1051); + case 1052: + return read_csr(1052); + case 1053: + return read_csr(1053); + case 1054: + return read_csr(1054); + case 1055: + return read_csr(1055); + case 1056: + return read_csr(1056); + case 1057: + return read_csr(1057); + case 1058: + return read_csr(1058); + case 1059: + return read_csr(1059); + case 1060: + return read_csr(1060); + case 1061: + return read_csr(1061); + case 1062: + return read_csr(1062); + case 1063: + return read_csr(1063); +#endif + } + return 0; +} + +static void csrw_ss(uint32_t csr_address, uint32_t value) { + switch (csr_address) { + case 960: + write_csr(960, value); + break; + case 961: + write_csr(961, value); + break; + case 962: + write_csr(962, value); + break; + case 963: + write_csr(963, value); + break; + case 964: + write_csr(964, value); + break; + case 965: + write_csr(965, value); + break; + case 966: + write_csr(966, value); + break; + case 967: + write_csr(967, value); + break; + case 968: + write_csr(968, value); + break; + case 969: + write_csr(969, value); + break; + case 970: + write_csr(970, value); + break; + case 971: + write_csr(971, value); + break; + case 972: + write_csr(972, value); + break; + case 973: + write_csr(973, value); + break; + case 974: + write_csr(974, value); + break; + case 975: + write_csr(975, value); + break; + case 976: + write_csr(976, value); + break; + case 977: + write_csr(977, value); + break; + case 978: + write_csr(978, value); + break; + case 979: + write_csr(979, value); + break; + case 980: + write_csr(980, value); + break; + case 981: + write_csr(981, value); + break; + case 982: + write_csr(982, value); + break; + case 983: + write_csr(983, value); + break; + case 984: + write_csr(984, value); + break; + case 985: + write_csr(985, value); + break; + case 986: + write_csr(986, value); + break; + case 987: + write_csr(987, value); + break; + case 988: + write_csr(988, value); + break; + case 989: + write_csr(989, value); + break; + case 990: + write_csr(990, value); + break; + case 991: + write_csr(991, value); + break; +#ifdef CSR_LONG_ADDR_MODE + case 992: + write_csr(992, value); + break; + case 993: + write_csr(993, value); + break; + case 994: + write_csr(994, value); + break; + case 995: + write_csr(995, value); + break; + case 996: + write_csr(996, value); + break; + case 997: + write_csr(997, value); + break; + case 998: + write_csr(998, value); + break; + case 999: + write_csr(999, value); + break; + case 1000: + write_csr(1000, value); + break; + case 1001: + write_csr(1001, value); + break; + case 1002: + write_csr(1002, value); + break; + case 1003: + write_csr(1003, value); + break; + case 1004: + write_csr(1004, value); + break; + case 1005: + write_csr(1005, value); + break; + case 1006: + write_csr(1006, value); + break; + case 1007: + write_csr(1007, value); + break; + case 1008: + write_csr(1008, value); + break; + case 1009: + write_csr(1009, value); + break; + case 1010: + write_csr(1010, value); + break; + case 1011: + write_csr(1011, value); + break; + case 1012: + write_csr(1012, value); + break; + case 1013: + write_csr(1013, value); + break; + case 1014: + write_csr(1014, value); + break; + case 1015: + write_csr(1015, value); + break; + case 1016: + write_csr(1016, value); + break; + case 1017: + write_csr(1017, value); + break; + case 1018: + write_csr(1018, value); + break; + case 1019: + write_csr(1019, value); + break; + case 1020: + write_csr(1020, value); + break; + case 1021: + write_csr(1021, value); + break; + case 1022: + write_csr(1022, value); + break; + case 1023: + write_csr(1023, value); + break; + case 1024: + write_csr(1024, value); + break; + case 1025: + write_csr(1025, value); + break; + case 1026: + write_csr(1026, value); + break; + case 1027: + write_csr(1027, value); + break; + case 1028: + write_csr(1028, value); + break; + case 1029: + write_csr(1029, value); + break; + case 1030: + write_csr(1030, value); + break; + case 1031: + write_csr(1031, value); + break; + case 1032: + write_csr(1032, value); + break; + case 1033: + write_csr(1033, value); + break; + case 1034: + write_csr(1034, value); + break; + case 1035: + write_csr(1035, value); + break; + case 1036: + write_csr(1036, value); + break; + case 1037: + write_csr(1037, value); + break; + case 1038: + write_csr(1038, value); + break; + case 1039: + write_csr(1039, value); + break; + case 1040: + write_csr(1040, value); + break; + case 1041: + write_csr(1041, value); + break; + case 1042: + write_csr(1042, value); + break; + case 1043: + write_csr(1043, value); + break; + case 1044: + write_csr(1044, value); + break; + case 1045: + write_csr(1045, value); + break; + case 1046: + write_csr(1046, value); + break; + case 1047: + write_csr(1047, value); + break; + case 1048: + write_csr(1048, value); + break; + case 1049: + write_csr(1049, value); + break; + case 1050: + write_csr(1050, value); + break; + case 1051: + write_csr(1051, value); + break; + case 1052: + write_csr(1052, value); + break; + case 1053: + write_csr(1053, value); + break; + case 1054: + write_csr(1054, value); + break; + case 1055: + write_csr(1055, value); + break; + case 1056: + write_csr(1056, value); + break; + case 1057: + write_csr(1057, value); + break; + case 1058: + write_csr(1058, value); + break; + case 1059: + write_csr(1059, value); + break; + case 1060: + write_csr(1060, value); + break; + case 1061: + write_csr(1061, value); + break; + case 1062: + write_csr(1062, value); + break; + case 1063: + write_csr(1063, value); + break; +#endif + } +} + +#endif // CSR_H diff --git a/sw/include/kultest/snrt/snrt_TO.h b/sw/include/kultest/snrt/snrt_TO.h new file mode 100644 index 0000000..d334073 --- /dev/null +++ b/sw/include/kultest/snrt/snrt_TO.h @@ -0,0 +1,242 @@ +#pragma once + +#include "stdint.h" +#include + +// -------------------------------------------------------------------------------- +// --------------------------- Core IDX functions ---------------------------------- +// -------------------------------------------------------------------------------- + +#define SNRT_CLUSTER_DM_CORE_NUM 1 + +inline uint32_t __attribute__((const)) snrt_cluster_base_addrl() { + uint32_t base_address_l; + asm("csrr %0, 0xbc1" : "=r"(base_address_l)); + return base_address_l; +} + + +inline uint32_t __attribute__((const)) snrt_cluster_base_addrh() { + uint32_t base_address_h; + asm("csrr %0, 0xbc2" : "=r"(base_address_h)); + return base_address_h; +} + +inline uint32_t __attribute__((const)) snrt_cluster_core_idx() { + // return snrt_global_core_idx() % snrt_cluster_core_num(); + uint32_t cluster_core_id; + asm("csrr %0, 0xbc3" : "=r"(cluster_core_id)); + return cluster_core_id & 0xffff; +} + +inline uint32_t __attribute__((const)) snrt_cluster_core_num() { + // return SNRT_CLUSTER_CORE_NUM; + uint32_t cluster_core_id; + asm("csrr %0, 0xbc3" : "=r"(cluster_core_id)); + return cluster_core_id >> 16; +} + +inline uint32_t __attribute__((const)) snrt_cluster_dm_core_num() { + return SNRT_CLUSTER_DM_CORE_NUM; +} + +inline uint32_t __attribute__((const)) snrt_cluster_compute_core_num() { + return snrt_cluster_core_num() - snrt_cluster_dm_core_num(); +} + +inline int __attribute__((const)) snrt_is_compute_core() { + return snrt_cluster_core_idx() < snrt_cluster_compute_core_num(); +} + +inline int __attribute__((const)) snrt_is_dm_core() { + return !snrt_is_compute_core(); +} + + +// -------------------------------------------------------------------------------- +// --------------------------- DMA functions -------------------------------------- +// -------------------------------------------------------------------------------- + +/// A DMA transfer identifier. +typedef uint32_t snrt_dma_txid_t; + +/// Initiate an asynchronous 1D DMA transfer with wide 64-bit pointers. +inline snrt_dma_txid_t snrt_dma_start_1d_wideptr(uint64_t dst, uint64_t src, + size_t size) { + // Current DMA does not allow transfers with size == 0 (blocks) + // TODO(colluca) remove this check once new DMA is integrated + if (size > 0) { + register uint32_t reg_dst_low asm("a0") = dst >> 0; // 10 + register uint32_t reg_dst_high asm("a1") = dst >> 32; // 11 + register uint32_t reg_src_low asm("a2") = src >> 0; // 12 + register uint32_t reg_src_high asm("a3") = src >> 32; // 13 + register uint32_t reg_size asm("a4") = size; // 14 + + // dmsrc a2, a3 + asm volatile( + ".word (0b0000000 << 25) | \ + ( (13) << 20) | \ + ( (12) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" ::"r"(reg_src_high), + "r"(reg_src_low)); + + // dmdst a0, a1 + asm volatile( + ".word (0b0000001 << 25) | \ + ( (11) << 20) | \ + ( (10) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" ::"r"(reg_dst_high), + "r"(reg_dst_low)); + + // dmcpyi a0, a4, 0b00 + register uint32_t reg_txid asm("a0"); // 10 + asm volatile( + ".word (0b0000010 << 25) | \ + ( 0b00000 << 20) | \ + ( (14) << 15) | \ + ( 0b000 << 12) | \ + ( (10) << 7) | \ + (0b0101011 << 0) \n" + : "=r"(reg_txid) + : "r"(reg_size)); + + return reg_txid; + } else { + return -1; + } +} + +/// Initiate an asynchronous 1D DMA transfer. +inline snrt_dma_txid_t snrt_dma_start_1d(void *dst, const void *src, + size_t size) { + return snrt_dma_start_1d_wideptr((size_t)dst, (size_t)src, size); +} + +/// Initiate an asynchronous 2D DMA transfer with wide 64-bit pointers. +inline snrt_dma_txid_t snrt_dma_start_2d_wideptr(uint64_t dst, uint64_t src, + size_t size, size_t dst_stride, + size_t src_stride, + size_t repeat) { + // Current DMA does not allow transfers with size == 0 (blocks) + // TODO(colluca) remove this check once new DMA is integrated + if (size > 0) { + register uint32_t reg_dst_low asm("a0") = dst >> 0; // 10 + register uint32_t reg_dst_high asm("a1") = dst >> 32; // 11 + register uint32_t reg_src_low asm("a2") = src >> 0; // 12 + register uint32_t reg_src_high asm("a3") = src >> 32; // 13 + register uint32_t reg_size asm("a4") = size; // 14 + register uint32_t reg_dst_stride asm("a5") = dst_stride; // 15 + register uint32_t reg_src_stride asm("a6") = src_stride; // 16 + register uint32_t reg_repeat asm("a7") = repeat; // 17 + + // dmsrc a0, a1 + asm volatile( + ".word (0b0000000 << 25) | \ + ( (13) << 20) | \ + ( (12) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" ::"r"(reg_src_high), + "r"(reg_src_low)); + + // dmdst a0, a1 + asm volatile( + ".word (0b0000001 << 25) | \ + ( (11) << 20) | \ + ( (10) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" ::"r"(reg_dst_high), + "r"(reg_dst_low)); + + // dmstr a5, a6 + asm volatile( + ".word (0b0000110 << 25) | \ + ( (15) << 20) | \ + ( (16) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" + : + : "r"(reg_dst_stride), "r"(reg_src_stride)); + + // dmrep a7 + asm volatile( + ".word (0b0000111 << 25) | \ + ( (17) << 15) | \ + ( 0b000 << 12) | \ + (0b0101011 << 0) \n" + : + : "r"(reg_repeat)); + + // dmcpyi a0, a4, 0b10 + register uint32_t reg_txid asm("a0"); // 10 + asm volatile( + ".word (0b0000010 << 25) | \ + ( 0b00010 << 20) | \ + ( (14) << 15) | \ + ( 0b000 << 12) | \ + ( (10) << 7) | \ + (0b0101011 << 0) \n" + : "=r"(reg_txid) + : "r"(reg_size)); + + return reg_txid; + } else { + return -1; + } +} + +/// Initiate an asynchronous 2D DMA transfer. (for local-chip transfers) +inline snrt_dma_txid_t snrt_dma_start_2d(void *dst, const void *src, + size_t size, size_t dst_stride, + size_t src_stride, size_t repeat) { + uint64_t dst_wideptr = (uint64_t)dst; + dst_wideptr += (uint64_t)snrt_cluster_base_addrh() << 32; + uint64_t src_wideptr = (uint64_t)src; + src_wideptr += (uint64_t)snrt_cluster_base_addrh() << 32; + return snrt_dma_start_2d_wideptr(dst_wideptr, src_wideptr, size, dst_stride, + src_stride, repeat); +} + +/// Block until all operation on the DMA ceases. +inline void snrt_dma_wait_all() { + // dmstati t0, 2 # 2=status.busy + asm volatile( + "1: \n" + ".word (0b0000100 << 25) | \ + ( 0b00010 << 20) | \ + ( 0b000 << 12) | \ + ( (5) << 7) | \ + (0b0101011 << 0) \n" + "bne t0, zero, 1b \n" :: + : "t0"); +} + +//================================================================================ +// --------------------------- Barrier functions -------------------------------- +//================================================================================ + +/// Synchronize cores in a cluster with a hardware barrier +inline void snrt_cluster_hw_barrier() { + asm volatile("csrr x0, 0x7C2" ::: "memory"); +} + +// -------------------------------------------------------------------------------- +// --------------------------- CSR Write&Read functions -------------------------- +// -------------------------------------------------------------------------------- + +// #define read_csr(reg) ({ unsigned long __tmp; \ +// asm volatile ("csrr %0, " #reg : "=r"(__tmp)); \ +// __tmp; }) + +// #define write_csr(reg, val) ({ \ +// asm volatile ("csrw " #reg ", %0" :: "rK"(val)); }) + +#define STR(x) #x +#define XSTR(x) STR(x) + +#define read_csr(reg) ({ unsigned long __tmp; \ + asm volatile ("csrr %0, " XSTR(reg) : "=r"(__tmp)); \ + __tmp; }) + +#define write_csr(reg, val) ({ asm volatile ("csrw " XSTR(reg) ", %0" :: "rK"(val)); }) diff --git a/sw/include/kultest/xdma/data.h b/sw/include/kultest/xdma/data.h new file mode 100644 index 0000000..333d67f --- /dev/null +++ b/sw/include/kultest/xdma/data.h @@ -0,0 +1,187 @@ +#include + +#include + +static int input_data_len = 64; + +static int tempLoop0_in = 1; + +static int tempLoop1_in = 1; + +static int tempLoop2_in = 1; + +static int tempLoop3_in = 1; + +static int tempLoop4_in = 1; + +static int delta_local_in = 0; + +static int spatialStride1_in = 8; + +static int tempStride0_in = 8; + +static int tempStride1_in = 64; + +static int tempStride2_in = 64; + +static int tempStride3_in = 64; + +static int tempStride4_in = 64; + +static int tempLoop0_out = 1; + +static int tempLoop1_out = 1; + +static int tempLoop2_out = 1; + +static int output_data_len = 64; + +static int delta_local_out = 64; + +static int spatialStride1_out = 8; + +static int tempStride0_out = 64; + +static int tempStride1_out = 64; + +static int tempStride2_out = 64; + +static int opcode = 2; + +static int TloopLen = 1; + +static int reduceLen = 1; + +static int8_t DataIn[64] = { + -26, + 51, + -36, + -114, + -22, + -57, + 60, + -108, + -26, + -7, + 82, + 86, + -54, + 74, + -41, + -12, + -29, + -25, + 23, + 2, + 21, + -76, + -127, + -41, + 107, + 29, + -91, + 1, + 63, + 59, + -108, + 32, + 75, + -71, + -107, + 124, + 107, + -40, + -80, + 90, + -70, + 126, + 41, + 91, + 59, + 79, + -114, + 61, + 61, + 46, + 61, + -78, + -21, + -74, + 115, + -65, + 120, + 2, + 100, + -78, + 6, + -108, + -56, + 38, +}; + +static int8_t C_golden[64] = { + -26, + 51, + -36, + -114, + -22, + -57, + 60, + -108, + -26, + -7, + 82, + 86, + -54, + 74, + -41, + -12, + -29, + -25, + 23, + 2, + 21, + -76, + -127, + -41, + 107, + 29, + -91, + 1, + 63, + 59, + -108, + 32, + 75, + -71, + -107, + 124, + 107, + -40, + -80, + 90, + -70, + 126, + 41, + 91, + 59, + 79, + -114, + 61, + 61, + 46, + 61, + -78, + -21, + -74, + 115, + -65, + 120, + 2, + 100, + -78, + 6, + -108, + -56, + 38, +}; diff --git a/sw/include/kultest/xdma/snax-xdma-csr-addr.h b/sw/include/kultest/xdma/snax-xdma-csr-addr.h new file mode 100644 index 0000000..c68f3c2 --- /dev/null +++ b/sw/include/kultest/xdma/snax-xdma-csr-addr.h @@ -0,0 +1,45 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Yunhao Deng + +// This file is generated by Chisel in hw/chisel, do not modify it manually + +#define XDMA_BASE_ADDR 960 +#define XDMA_WIDTH 64 +#define XDMA_SPATIAL_CHAN 8 +#define XDMA_SRC_ADDR_PTR_LSB XDMA_BASE_ADDR +#define XDMA_SRC_ADDR_PTR_MSB XDMA_SRC_ADDR_PTR_LSB + 1 +#define XDMA_SRC_SPATIAL_DIM 1 +#define XDMA_SRC_TEMP_DIM 6 +#define XDMA_SRC_SPATIAL_STRIDE_PTR XDMA_SRC_ADDR_PTR_MSB + 1 +#define XDMA_SRC_TEMP_BOUND_PTR XDMA_SRC_SPATIAL_STRIDE_PTR + XDMA_SRC_SPATIAL_DIM +#define XDMA_SRC_TEMP_STRIDE_PTR XDMA_SRC_TEMP_BOUND_PTR + XDMA_SRC_TEMP_DIM +#define XDMA_SRC_ENABLED_CHAN_PTR XDMA_SRC_TEMP_STRIDE_PTR + XDMA_SRC_TEMP_DIM +#define XDMA_SRC_BYPASS_PTR XDMA_SRC_ENABLED_CHAN_PTR + 1 +#define XDMA_SRC_EXT_NUM 0 +#define XDMA_SRC_EXT_CSR_PTR XDMA_SRC_BYPASS_PTR + 0 +#define XDMA_SRC_EXT_CSR_NUM 0 +#define XDMA_SRC_EXT_CUSTOM_CSR_NUM \ + { } + +#define XDMA_DST_ADDR_PTR_LSB XDMA_SRC_EXT_CSR_PTR + XDMA_SRC_EXT_CSR_NUM +#define XDMA_DST_ADDR_PTR_MSB XDMA_DST_ADDR_PTR_LSB + 1 + +#define XDMA_DST_SPATIAL_DIM 1 +#define XDMA_DST_TEMP_DIM 6 +#define XDMA_DST_SPATIAL_STRIDE_PTR XDMA_DST_ADDR_PTR_MSB + 1 +#define XDMA_DST_TEMP_BOUND_PTR XDMA_DST_SPATIAL_STRIDE_PTR + XDMA_DST_SPATIAL_DIM +#define XDMA_DST_TEMP_STRIDE_PTR XDMA_DST_TEMP_BOUND_PTR + XDMA_DST_TEMP_DIM +#define XDMA_DST_ENABLED_CHAN_PTR XDMA_DST_TEMP_STRIDE_PTR + XDMA_DST_TEMP_DIM +#define XDMA_DST_ENABLED_BYTE_PTR XDMA_DST_ENABLED_CHAN_PTR + 1 +#define XDMA_DST_BYPASS_PTR XDMA_DST_ENABLED_BYTE_PTR + 1 +#define XDMA_DST_EXT_NUM 3 +#define XDMA_DST_EXT_CSR_PTR XDMA_DST_BYPASS_PTR + 1 +#define XDMA_DST_EXT_CSR_NUM 2 +#define XDMA_DST_EXT_CUSTOM_CSR_NUM \ + { 1, 1, 0 } +#define XDMA_START_PTR XDMA_DST_EXT_CSR_PTR + XDMA_DST_EXT_CSR_NUM +#define XDMA_COMMIT_TASK_PTR XDMA_START_PTR + 1 +#define XDMA_FINISH_TASK_PTR XDMA_COMMIT_TASK_PTR + 1 diff --git a/sw/include/kultest/xdma/snax-xdma-lib.h b/sw/include/kultest/xdma/snax-xdma-lib.h new file mode 100644 index 0000000..8034f51 --- /dev/null +++ b/sw/include/kultest/xdma/snax-xdma-lib.h @@ -0,0 +1,32 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Yunhao Deng + +#pragma once +#include +#include +// Define the CSR address of xdma, should be generated by scala +#include "snax-xdma-csr-addr.h" + +// Set CSR for xdma +int xdma_memcpy_nd(unsigned char* src, unsigned char* dst, unsigned int* spatial_stride_src, + unsigned int* spatial_stride_dst, unsigned int temp_dim_src, + unsigned int* temp_stride_src, unsigned int* temp_bound_src, + unsigned int temp_dim_dst, unsigned int* temp_stride_dst, + unsigned int* temp_bound_dst, unsigned int enabled_chan_src, + unsigned int enabled_chan_dst, unsigned int enabled_byte_dst); +int xdma_memcpy_1d(unsigned char* src, unsigned char* dst, unsigned int size); +int xdma_enable_src_ext(unsigned char ext, unsigned int* csr_value); +int xdma_disable_src_ext(unsigned char ext); +int xdma_enable_dst_ext(unsigned char ext, unsigned int* csr_value); +int xdma_disable_dst_ext(unsigned char ext); + +// Start xdma +unsigned int xdma_start(); + +// Check if xdma is finished +bool xdma_is_finished(unsigned int task_id); + +void xdma_wait(unsigned int task_id); diff --git a/sw/lib/kultest/snax-kul-cluster-gemmx-test.c b/sw/lib/kultest/snax-kul-cluster-gemmx-test.c new file mode 100644 index 0000000..190eb1f --- /dev/null +++ b/sw/lib/kultest/snax-kul-cluster-gemmx-test.c @@ -0,0 +1,456 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#include "soc_addr_map.h" +#include "kultest/snax-kul-cluster-gemmx-test.h" + + +int32_t gen_size_config(uint8_t Batch, uint8_t M, uint8_t K, uint8_t N) { + return ((int32_t)Batch << 24) | ((int32_t)M << 16) | ((int32_t)K << 8) | + (int32_t)N; +} + +int32_t gen_subtraction_config(int8_t subtraction_a, int8_t subtraction_b) { + return ((uint8_t)subtraction_b << 8) | (uint8_t)subtraction_a; +} + +int32_t gen_csr0_config(uint8_t input_zp_i, uint8_t output_zp_i, + uint8_t max_int_i, uint8_t min_int_i) { + // encode the configuration into a single 32-bit integer + return ((int32_t)min_int_i << 24) | ((int32_t)max_int_i << 16) | + ((int32_t)output_zp_i << 8) | (int32_t)input_zp_i; +} + +int32_t gen_csr1_config(bool double_round_i) { + // encode the configuration into a single 32-bit integer + return (uint32_t)double_round_i; +} + +// Set STREAMER configuration CSR +void set_gemmx_streamer_csr( + int Aslstride0, int Aslstride1, int Atlbound0, int Atlstride0, + int Atlbound1, int Atlstride1, int Atlbound2, int Atlstride2, int Atlbound3, + int Atlstride3, int Atlbound4, int Atlstride4, int Atlbound5, + int Atlstride5, int set_addr_remap_index_A, + + int Bslstride0, int Bslstride1, int Btlbound0, int Btlstride0, + int Btlbound1, int Btlstride1, int Btlbound2, int Btlstride2, + int set_addr_remap_index_B, + + int D8slstride0, int D8slstride1, int D8tlbound0, int D8tlstride0, + int D8tlbound1, int D8tlstride1, int D8tlbound2, int D8tlstride2, + int set_addr_remap_index_D8, + + int Cslstride0, int Cslstride1, int Ctlbound0, int Ctlstride0, + int Ctlbound1, int Ctlstride1, int Ctlbound2, int Ctlstride2, + int set_addr_remap_index_C, + + int D32slstride0, int D32slstride1, int D32tlbound0, int D32tlstride0, + int D32tlbound1, int D32tlstride1, int D32tlbound2, int D32tlstride2, + int set_addr_remap_index_D32, + + int delta_local_a, int delta_local_b, int delta_local_d8, int delta_local_c, + int delta_local_d32, int bypassSIMD, int32_t transpose_A, + int32_t transpose_B, int32_t channel_en_C, int32_t broadcast_C) { + // base ptr for A + write_csr(BASE_PTR_READER_0_LOW, (uint32_t)(delta_local_a + snrt_cluster_base_addrl())); + + // spatial strides for A + write_csr(S_STRIDE_READER_0_0, Aslstride1); + + // loop bounds, from innermost to outermost, for data mover A + write_csr(T_BOUND_READER_0_0, Atlbound0); + write_csr(T_BOUND_READER_0_1, Atlbound1); + write_csr(T_BOUND_READER_0_2, Atlbound2); + write_csr(T_BOUND_READER_0_3, Atlbound3); + write_csr(T_BOUND_READER_0_4, Atlbound4); + write_csr(T_BOUND_READER_0_5, Atlbound5); + + // temporal strides for A + write_csr(T_STRIDE_READER_0_0, Atlstride0); + write_csr(T_STRIDE_READER_0_1, Atlstride1); + write_csr(T_STRIDE_READER_0_2, Atlstride2); + write_csr(T_STRIDE_READER_0_3, Atlstride3); + write_csr(T_STRIDE_READER_0_4, Atlstride4); + write_csr(T_STRIDE_READER_0_5, Atlstride5); + + // set the address remap index for A + write_csr(ADDR_REMAP_INDEX_READER_0, set_addr_remap_index_A); + + // base ptr for B + write_csr(BASE_PTR_READER_1_LOW, (uint32_t)(delta_local_b + snrt_cluster_base_addrl())); + + // spatial strides for B + write_csr(S_STRIDE_READER_1_0, Bslstride1); + + // loop bounds, from innermost to outermost, for data mover B + write_csr(T_BOUND_READER_1_0, Btlbound0); + write_csr(T_BOUND_READER_1_1, Btlbound1); + write_csr(T_BOUND_READER_1_2, Btlbound2); + + // temporal strides for B + write_csr(T_STRIDE_READER_1_0, Btlstride0); + write_csr(T_STRIDE_READER_1_1, Btlstride1); + write_csr(T_STRIDE_READER_1_2, Btlstride2); + + // set the address remap index for B + write_csr(ADDR_REMAP_INDEX_READER_1, set_addr_remap_index_B); + + // base ptr for D8 + write_csr(BASE_PTR_WRITER_0_LOW, (uint32_t)(delta_local_d8 + snrt_cluster_base_addrl())); + + // spatial strides for D8 + write_csr(S_STRIDE_WRITER_0_0, D8slstride1); + + // for D8, from N to M + if (bypassSIMD == 0) { + write_csr(T_BOUND_WRITER_0_0, D8tlbound0); + write_csr(T_BOUND_WRITER_0_1, D8tlbound1); + write_csr(T_BOUND_WRITER_0_2, D8tlbound2); + } else { + write_csr(T_BOUND_WRITER_0_0, 0); + write_csr(T_BOUND_WRITER_0_1, 0); + write_csr(T_BOUND_WRITER_0_2, 0); + } + + // temporal strides for D8 + write_csr(T_STRIDE_WRITER_0_0, D8tlstride0); + write_csr(T_STRIDE_WRITER_0_1, D8tlstride1); + write_csr(T_STRIDE_WRITER_0_2, D8tlstride2); + + // set the address remap index for D8 + write_csr(ADDR_REMAP_INDEX_WRITER_0, set_addr_remap_index_D8); + + // base ptr for C + write_csr(BASE_PTR_READER_WRITER_0_LOW, + (uint32_t)(delta_local_c + snrt_cluster_base_addrl())); + + // spatial strides for C + write_csr(S_STRIDE_READER_WRITER_0_0, Cslstride0); + write_csr(S_STRIDE_READER_WRITER_0_1, Cslstride1); + + // loop bounds, from innermost to outermost, for data mover C + write_csr(T_BOUND_READER_WRITER_0_0, Ctlbound0); + write_csr(T_BOUND_READER_WRITER_0_1, Ctlbound1); + write_csr(T_BOUND_READER_WRITER_0_2, Ctlbound2); + + // temporal strides for C + write_csr(T_STRIDE_READER_WRITER_0_0, Ctlstride0); + write_csr(T_STRIDE_READER_WRITER_0_1, Ctlstride1); + write_csr(T_STRIDE_READER_WRITER_0_2, Ctlstride2); + + // set the address remap index for C + write_csr(ADDR_REMAP_INDEX_READER_WRITER_0, set_addr_remap_index_C); + +#ifdef ENABLED_CHANNEL_READER_WRITER_0 + write_csr(ENABLED_CHANNEL_READER_WRITER_0, channel_en_C); +#endif + +#ifdef C_BROADCAST_EXTENSION_ENABLE + write_csr(C_BROADCAST_CSR_READER_WRITER_0, broadcast_C == 1 ? 0 : 1); +#endif + + // base ptr for D32 + write_csr(BASE_PTR_READER_WRITER_1_LOW, + (uint32_t)(delta_local_d32 + snrt_cluster_base_addrl())); + + // spatial strides for D32 + write_csr(S_STRIDE_READER_WRITER_1_0, D32slstride0); + write_csr(S_STRIDE_READER_WRITER_1_1, D32slstride1); + + // for D32, from N to M + if (bypassSIMD == 0) { + write_csr(T_BOUND_READER_WRITER_1_0, 0); + write_csr(T_BOUND_READER_WRITER_1_1, 0); + write_csr(T_BOUND_READER_WRITER_1_2, 0); + } else { + write_csr(T_BOUND_READER_WRITER_1_0, D32tlbound0); + write_csr(T_BOUND_READER_WRITER_1_1, D32tlbound1); + write_csr(T_BOUND_READER_WRITER_1_2, D32tlbound2); + } + + // temporal strides for D32 + write_csr(T_STRIDE_READER_WRITER_1_0, D32tlstride0); + write_csr(T_STRIDE_READER_WRITER_1_1, D32tlstride1); + write_csr(T_STRIDE_READER_WRITER_1_2, D32tlstride2); + + // set the address remap index for D32 + write_csr(ADDR_REMAP_INDEX_READER_WRITER_1, set_addr_remap_index_D32); + + // set the transpose +#ifdef TRANSPOSE_EXTENSION_ENABLE + write_csr(TRANSPOSE_CSR_READER_0, transpose_A == 0 ? 1 : 0); + write_csr(TRANSPOSE_CSR_READER_1, transpose_B == 0 ? 1 : 0); +#endif +} + +// Set GEMM configuration CSR +void set_gemmx_csr(int tempLoop0, int tempLoop1, int tempLoop2, + int subtractions, uint32_t csr0, uint32_t csr1, + int shared_bitpacked_shift0, int shared_bitpacked_shift1, + int shared_multiplier0, int shared_multiplier1, + int shared_multiplier2, int shared_multiplier3, + int shared_multiplier4, int shared_multiplier5, + int shared_multiplier6, int shared_multiplier7, + uint32_t temporal_loop_bound, uint32_t bypassSIMD) { + // set loop bounds, from innermost to outermost, aka from K to N to M + write_csr(T_BOUND_K, tempLoop0); + write_csr(T_BOUND_N, tempLoop1); + write_csr(T_BOUND_M, tempLoop2); + + // set subtraction a and b + write_csr(SUBTRACTIONS, subtractions); + + // set the constants for the SIMD unit + write_csr(SIMD_CSR0, csr0); + write_csr(SIMD_CSR1, csr1); + + // set the shared bitpacked shift + write_csr(SIMD_SHARED_BITPACKED_SHIFT0, shared_bitpacked_shift0); + write_csr(SIMD_SHARED_BITPACKED_SHIFT1, shared_bitpacked_shift1); + + // set the shared multipliers + write_csr(SIMD_SHARED_MULTIPLIER0, shared_multiplier0); + write_csr(SIMD_SHARED_MULTIPLIER1, shared_multiplier1); + write_csr(SIMD_SHARED_MULTIPLIER2, shared_multiplier2); + write_csr(SIMD_SHARED_MULTIPLIER3, shared_multiplier3); + write_csr(SIMD_SHARED_MULTIPLIER4, shared_multiplier4); + write_csr(SIMD_SHARED_MULTIPLIER5, shared_multiplier5); + write_csr(SIMD_SHARED_MULTIPLIER6, shared_multiplier6); + write_csr(SIMD_SHARED_MULTIPLIER7, shared_multiplier7); + + // set the temporal loop bound + write_csr(TEMPORAL_LOOP_BOUND, temporal_loop_bound); + + write_csr(BYPASS_SIMD, bypassSIMD); +} + +// Stall until Streamer and GEMM accelerator finish +void wait_gemmx_and_streamer() { + write_csr(STREAMER_START_CSR, 0); + write_csr(STREAMER_START_CSR, 0); + while (read_csr(GEMMX_BUSY)) { + } + while (read_csr(STREAMER_BUSY_CSR)) { + } + write_csr(GEMMX_START, 0); +} + +// Read performance counter of the Streamer, a read-only CSR +uint32_t read_gemmx_streamer_perf_counter() { + uint32_t perf_counter = read_csr(STREAMER_PERFORMANCE_COUNTER_CSR); + return perf_counter; +} + +// Read performance counter of GEMM, a read-only CSR +uint32_t read_gemmx_perf_counter() { + uint32_t perf_counter = read_csr(GEMMX_PERFORMANCE_COUNTER); + return perf_counter; +} + +uint32_t check_gemmx_result_D8(int8_t* output, int8_t* output_golden, + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout) { + uint32_t err = 0; + uint32_t size = 0; + size = Batch * M * N * meshRow * meshCol; + + if (banked_data_layout) { + for (int i = 0; i < size / 64; i += 1) { + for (int j = 0; j < 64; j++) { + if (*(output + i * 256 + j) != output_golden[i * 64 + j]) { + err++; + } + } + } + } else { + for (int i = 0; i < size; i++) { + if (output[i] != output_golden[i]) { + err++; + } + } + } + + return err; +} + +uint32_t check_gemmx_result_D32(int32_t* output, int32_t* output_golden, + int32_t Batch, int32_t M, int32_t N, + bool banked_data_layout) { + uint32_t err = 0; + uint32_t size = 0; + size = Batch * M * N * meshRow * meshCol; + + if (banked_data_layout) { + for (int i = 0; i < size / 16; i += 1) { + for (int j = 0; j < 16; j++) { + if (*(output + i * (256 / 4) + j) != + output_golden[i * 16 + j]) { + err++; + } + } + } + } else { + for (int i = 0; i < size; i++) { + if (output[i] != output_golden[i]) { + err++; + } + } + } + + return err; +} + +// This is the test function for the SNAX GEMM for Conv2d +// We use several nested loops to iterate over the input data and weights, +// achieving implicit im2col +extern uint32_t __global_pointer$; + +int kul_cluster_gemmx_test() { + // wake up the dma core, not work... + // if (snrt_cluster_core_idx() == 0) { + // volatile uint32_t *interruptTarget = ((uint32_t *)CLINT_CTRL_BASE) + 6 + 1; + // *interruptTarget = 1; + // } + // !!! set the stack pointer and global pointer !!! + // set it to the end of the KUL cluster TCDM (size = 128KB) address - 4 + uint32_t stack_start = snrt_cluster_base_addrl() + 128 * 1024 - 4; + asm("mv sp, %0" ::"r"((uint32_t)stack_start)); + uint32_t gp_value = (uint32_t)(&__global_pointer$); + asm("mv gp, %0" ::"r"(gp_value)); + + // Set err value for checking + int err = 0; + + // Prepare addresses pointers in TCDM for DMA + int8_t *local_a_dma, *local_b_dma; + int32_t *local_c_dma, *local_d32_dma; + int8_t *local_d8_dma; + + // Allocate space in TCDM for DMA + local_a_dma = (int8_t *)(snrt_cluster_base_addrl() + delta_physical_a); + local_b_dma = (int8_t *)(snrt_cluster_base_addrl() + delta_physical_b); + local_c_dma = (int32_t *)(snrt_cluster_base_addrl() + delta_physical_c); + local_d32_dma = (int32_t *)(snrt_cluster_base_addrl() + delta_physical_d32); + local_d8_dma = (int8_t *)(snrt_cluster_base_addrl() + delta_physical_d8); + + // Prepare addresses pointers in TCDM for streamer + int8_t *local_a, *local_b; + int32_t *local_c, *local_d32; + int8_t *local_d8; + + // Allocate space in TCDM for streamer + local_a = (int8_t *)(snrt_cluster_base_addrl() + delta_local_a); + local_b = (int8_t *)(snrt_cluster_base_addrl() + delta_local_b); + local_c = (int32_t *)(snrt_cluster_base_addrl() + delta_local_c); + local_d32 = (int32_t *)(snrt_cluster_base_addrl() + delta_local_d32); + local_d8 = (int8_t *)(snrt_cluster_base_addrl() + delta_local_d8); + + // Transfer data from L3 to L1 + // Using DMA only + if (snrt_is_dm_core()) { + if (interleaved_address == 1) { + snrt_dma_start_1d(local_a, A, + Nbatch * (H + 2 * pad_h) * (W + 2 * pad_w) * Cin * + sizeof(int8_t)); + snrt_dma_start_1d(local_b, B, + Cout * Kh * Kw * Cin * sizeof(int8_t)); + } else { + snrt_dma_start_2d( + local_a_dma, A, 64 * sizeof(int8_t), 256, 64, + Nbatch * (H + 2 * pad_h) * (W + 2 * pad_w) * Cin / 64); + snrt_dma_start_2d(local_b_dma, B, 64 * sizeof(int8_t), 256, 64, + Cout * Kh * Kw * Cin / 64); + } + snrt_dma_wait_all(); + } + + // Wait for DMA to finish + snrt_cluster_hw_barrier(); + if (snrt_is_dm_core()) { + if (interleaved_address == 1) { + snrt_dma_start_1d(local_c, C, + M * N * meshRow * meshCol * sizeof(int32_t)); + } else { + snrt_dma_start_2d(local_c_dma, C, 16 * sizeof(int32_t), 256, + 16 * sizeof(int32_t), + M * N * meshRow * meshCol / 16); + } + snrt_dma_wait_all(); + } + + snrt_cluster_hw_barrier(); + + if (snrt_cluster_core_idx() == 0) { + // Set Streamer configuration CSR for conv2d + set_gemmx_streamer_csr( + Aslstride0, Aslstride1, Atlbound0, Atlstride0, Atlbound1, + Atlstride1, Atlbound2, Atlstride2, Atlbound3, Atlstride3, Atlbound4, + Atlstride4, Atlbound5, Atlstride5, set_addr_remap_index_A, + + Bslstride0, Bslstride1, Btlbound0, Btlstride0, Btlbound1, + Btlstride1, Btlbound2, Btlstride2, set_addr_remap_index_B, + + D8slstride0, D8slstride1, D8tlbound0, D8tlstride0, D8tlbound1, + D8tlstride1, D8tlbound2, D8tlstride2, set_addr_remap_index_D8, + + Cslstride0, Cslstride1, Ctlbound0, Ctlstride0, Ctlbound1, + Ctlstride1, Ctlbound2, Ctlstride2, set_addr_remap_index_C, + + D32slstride0, D32slstride1, D32tlbound0, D32tlstride0, D32tlbound1, + D32tlstride1, D32tlbound2, D32tlstride2, set_addr_remap_index_D32, + + delta_local_a, delta_local_b, delta_local_d8, delta_local_c, + delta_local_d32, bypassSIMD, transposed_A, transposed_B, + channel_en_C, broadcast_C); + + // Set GEMMX configuration CSR + uint32_t subtraction_setting = + gen_subtraction_config(subtraction_a, subtraction_b); + + uint32_t csr0 = + gen_csr0_config(input_zp_i, output_zp_i, max_int_i, min_int_i); + uint32_t csr1 = gen_csr1_config(double_round_i); + + set_gemmx_csr( + K, N, M, subtraction_setting, csr0, csr1, shared_bitpacked_shift0, + shared_bitpacked_shift1, shared_multiplier0, shared_multiplier1, + shared_multiplier2, shared_multiplier3, shared_multiplier4, + shared_multiplier5, shared_multiplier6, shared_multiplier7, M * N, + bypassSIMD); + + // Set CSR to start Streamer for conv2d + set_gemmx_streamer_start(); + + // Set CSR to start GEMM + set_gemmx_start(); + + // Poll until Streamer and GEMM accelerator finish + wait_gemmx_and_streamer(); + + // check the result of the implicit im2col convolution + if (interleaved_address == 1) { + if (!bypassSIMD) { + err += check_gemmx_result_D8(local_d8, D8, Batch, M, N, false); + } else { + err += + check_gemmx_result_D32(local_d32, D32, Batch, M, N, false); + } + } else { + if (!bypassSIMD) { + err += + check_gemmx_result_D8(local_d8_dma, D8, Batch, M, N, true); + } else { + err += check_gemmx_result_D32(local_d32_dma, D32, Batch, M, N, + true); + } + } + + }; + + return err; + +} diff --git a/sw/lib/kultest/snax-kul-cluster-xdma-test.c b/sw/lib/kultest/snax-kul-cluster-xdma-test.c new file mode 100644 index 0000000..34e62ed --- /dev/null +++ b/sw/lib/kultest/snax-kul-cluster-xdma-test.c @@ -0,0 +1,358 @@ +// Copyright 2024 KU Leuven. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Xiaoling Yi + +#include "soc_addr_map.h" +#include "kultest/snax-kul-cluster-xdma-test.h" + + +// #define XDMA_DEBUG +// #ifdef XDMA_DEBUG +// #define XDMA_DEBUG_PRINT(...) printf(__VA_ARGS__) +// #else +// #define XDMA_DEBUG_PRINT(...) +// #endif + +int xdma_memcpy_nd(unsigned char* src, unsigned char* dst, unsigned int* spatial_stride_src, + unsigned int* spatial_stride_dst, unsigned int temp_dim_src, + unsigned int* temp_stride_src, unsigned int* temp_bound_src, + unsigned int temp_dim_dst, unsigned int* temp_stride_dst, + unsigned int* temp_bound_dst, unsigned int enabled_chan_src, + unsigned int enabled_chan_dst, unsigned int enabled_byte_dst) { + csrw_ss(XDMA_SRC_ADDR_PTR_LSB, (unsigned int)(uint64_t)src); + csrw_ss(XDMA_SRC_ADDR_PTR_MSB, (unsigned int)((uint64_t)src >> 32)); + + csrw_ss(XDMA_DST_ADDR_PTR_LSB, (unsigned int)(uint64_t)dst); + csrw_ss(XDMA_DST_ADDR_PTR_MSB, (unsigned int)((uint64_t)dst >> 32)); + // Rule check + // The enabled spatial bound for input should be equal to the enabled + // Src frame count and dst frame count should be equal + // unsigned int src_size = 1; + // if (temp_dim_src > 0) { + // for (unsigned int i = 0; i < temp_dim_src; i++) { + // src_size *= temp_bound_src[i]; + // } + // } + // unsigned int dst_size = 1; + // if (temp_dim_dst > 0) { + // for (unsigned int i = 0; i < temp_dim_dst; i++) { + // dst_size *= temp_bound_dst[i]; + // } + // } + // if (src_size != dst_size) { + // // XDMA_DEBUG_PRINT("src loop and dst loop is not equal\n"); + // // return -3; + // } + // Spatial Stride 0 to XDMA_SRC_SPATIAL_DIM at src + for (unsigned int i = 0; i < XDMA_SRC_SPATIAL_DIM; i++) { + csrw_ss(XDMA_SRC_SPATIAL_STRIDE_PTR + i, spatial_stride_src[i]); + } + // Spatial Stride 0 to XDMA_DST_SPATIAL_DIM at dst + for (unsigned int i = 0; i < XDMA_DST_SPATIAL_DIM; i++) { + csrw_ss(XDMA_DST_SPATIAL_STRIDE_PTR + i, spatial_stride_dst[i]); + } + // Temporal Dimension 0 to n at src + for (unsigned int i = 0; i < temp_dim_src; i++) { + if (i >= XDMA_SRC_TEMP_DIM) { + // XDMA_DEBUG_PRINT("Source dimension is too high for xdma\n"); + return -4; + } + csrw_ss(XDMA_SRC_TEMP_BOUND_PTR + i, temp_bound_src[i]); + csrw_ss(XDMA_SRC_TEMP_STRIDE_PTR + i, temp_stride_src[i]); + } + // Dimension n to MAX at src + for (unsigned int i = temp_dim_src; i < XDMA_SRC_TEMP_DIM; i++) { + csrw_ss(XDMA_SRC_TEMP_BOUND_PTR + i, 1); + csrw_ss(XDMA_SRC_TEMP_STRIDE_PTR + i, 0); + } + // Temporal Dimension 0 to n at dst + for (unsigned int i = 0; i < temp_dim_dst; i++) { + if (i >= XDMA_DST_TEMP_DIM) { + // XDMA_DEBUG_PRINT("Destination dimension is too high for xdma\n"); + return -4; + } + csrw_ss(XDMA_DST_TEMP_BOUND_PTR + i, temp_bound_dst[i]); + csrw_ss(XDMA_DST_TEMP_STRIDE_PTR + i, temp_stride_dst[i]); + } + // Dimension n to MAX at dst + for (unsigned int i = temp_dim_dst; i < XDMA_DST_TEMP_DIM; i++) { + csrw_ss(XDMA_DST_TEMP_BOUND_PTR + i, 1); + csrw_ss(XDMA_DST_TEMP_STRIDE_PTR + i, 0); + } + // Enabled channel at src + csrw_ss(XDMA_SRC_ENABLED_CHAN_PTR, enabled_chan_src); + // Enabled channel at dst + csrw_ss(XDMA_DST_ENABLED_CHAN_PTR, enabled_chan_dst); + // Enabled byte at dst + csrw_ss(XDMA_DST_ENABLED_BYTE_PTR, enabled_byte_dst); + return 0; +} + +int xdma_memcpy_1d(unsigned char* src, unsigned char* dst, unsigned int size) { + if (size % XDMA_WIDTH != 0) { + // XDMA_DEBUG_PRINT("Size is not multiple of XDMA_WIDTH\n"); + return -1; + } + unsigned int spatial_stride[1] = {XDMA_WIDTH / XDMA_SPATIAL_CHAN}; + unsigned int temporal_stride[1] = {XDMA_WIDTH}; + unsigned int temporal_bound[1] = {size / XDMA_WIDTH}; + unsigned int bound[2] = {XDMA_SPATIAL_CHAN, size / XDMA_WIDTH}; + return xdma_memcpy_nd(src, dst, spatial_stride, spatial_stride, 2, + temporal_stride, temporal_bound, 2, temporal_stride, + temporal_bound, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF); +} + +// xdma extension interface +int xdma_enable_src_ext(unsigned char ext, unsigned int* csr_value) { + if (ext >= XDMA_SRC_EXT_NUM) { + return -1; + } + unsigned char custom_csr_list[XDMA_SRC_EXT_NUM] = XDMA_SRC_EXT_CUSTOM_CSR_NUM; + unsigned int csr_offset = XDMA_SRC_EXT_CSR_PTR; + for (unsigned char i = 0; i < ext; i++) { + csr_offset += custom_csr_list[i]; + } + + // Not bypass the xdma extension -> set the corresponding CSR bit to 0 + csrw_ss(XDMA_SRC_BYPASS_PTR, csrr_ss(XDMA_SRC_BYPASS_PTR) & ~(1 << ext)); + + for (unsigned char i = 0; i < custom_csr_list[ext]; i++) { + csrw_ss(csr_offset + i, csr_value[i]); + } + return 0; +} +int xdma_enable_dst_ext(unsigned char ext, unsigned int* csr_value) { + if (ext >= XDMA_DST_EXT_NUM) { + return -1; + } + unsigned char custom_csr_list[XDMA_DST_EXT_NUM] = XDMA_DST_EXT_CUSTOM_CSR_NUM; + unsigned int csr_offset = XDMA_DST_EXT_CSR_PTR; + for (unsigned char i = 0; i < ext; i++) { + csr_offset += custom_csr_list[i]; + } + + // Not bypass the xdma extension -> set the corresponding CSR bit to 0 + csrw_ss(XDMA_DST_BYPASS_PTR, csrr_ss(XDMA_DST_BYPASS_PTR) & ~(1 << ext)); + for (unsigned char i = 0; i < custom_csr_list[ext]; i++) { + csrw_ss(csr_offset + i, csr_value[i]); + } + return 0; +} + +int xdma_disable_src_ext(unsigned char ext) { + if (ext >= XDMA_SRC_EXT_NUM) { + return 0; + } + // Bypass the xdma extension -> set the corresponding CSR bit to 1 + csrw_ss(XDMA_SRC_BYPASS_PTR, csrr_ss(XDMA_SRC_BYPASS_PTR) | (1 << ext)); + return 0; +} + +int xdma_disable_dst_ext(unsigned char ext) { + if (ext >= XDMA_DST_EXT_NUM) { + return 0; + } + // Bypass the xdma extension -> set the corresponding CSR bit to 1 + csrw_ss(XDMA_DST_BYPASS_PTR, csrr_ss(XDMA_DST_BYPASS_PTR) | (1 << ext)); + return 0; +} + +// Start xdma +unsigned int xdma_start() { + int ret = csrr_ss(XDMA_COMMIT_TASK_PTR); + csrw_ss(XDMA_START_PTR, 1); + while (csrr_ss(XDMA_COMMIT_TASK_PTR) == ret) { + // Wait for xdma to start + } + return csrr_ss(XDMA_COMMIT_TASK_PTR); +} + +// Check if xdma is finished +bool xdma_is_finished(unsigned int task_id) { + return csrr_ss(XDMA_FINISH_TASK_PTR) >= task_id; +} + +void xdma_wait(unsigned int task_id) { + while (!xdma_is_finished(task_id)) { + // Wait for xdma to finish + } +} + +// This is the test function for the SNAX XDMA doing a maxpool operation +extern unsigned int __global_pointer$; + +int kul_cluster_xdma_test() { + // wake up the dma core, not work... + // if (snrt_cluster_core_idx() == 0) { + // volatile uint32_t *interruptTarget = ((uint32_t *)CLINT_CTRL_BASE) + 6 + 1; + // *interruptTarget = 1; + // } + + // !!! set the stack pointer and global pointer !!! + // set it to the end of the KUL cluster TCDM (size = 128KB) address - 4 + unsigned int stack_start = snrt_cluster_base_addrl() + 128 * 1024 - 4; + asm("mv sp, %0" ::"r"((unsigned int)stack_start)); + unsigned int gp_value = (unsigned int)(&__global_pointer$); + asm("mv gp, %0" ::"r"(gp_value)); + + + // Set err value for checking + int err = 0; + // Obtain the start address of the TCDM memory + unsigned int dma_load_input_start; + unsigned int dma_load_input_end; + unsigned int tcdm_baseaddress = snrt_cluster_base_addrl(); + // Put the input at the starting of tcdm + unsigned char *tcdm_in = (unsigned char *)tcdm_baseaddress; + // Put the output at the middle of tcdm + unsigned char *tcdm_out = (unsigned char *)(tcdm_baseaddress + delta_local_out); + + if (snrt_is_dm_core()) { + // --------------------------------// + // -------------source cfg---------// + // --------------------------------// + // source base ptr + write_csr(960, (unsigned int)tcdm_in); + write_csr(961, 0); + + // spatial strides + write_csr(962, 8); + + // temporal bounds + write_csr(963, 1); + write_csr(964, 1); + write_csr(965, 1); + write_csr(966, 1); + write_csr(967, 1); + write_csr(968, 1); + + // temporal strides + write_csr(969, 64); + write_csr(970, 64); + write_csr(971, 64); + write_csr(972, 64); + write_csr(973, 64); + write_csr(974, 64); + + // XDMA_SRC_ENABLED_CHAN_PTR + write_csr(975, 0xFFFFFFFF); + + // --------------------------------// + // -------------dest cfg---------// + // --------------------------------// + // dest base ptr + write_csr(976, (unsigned int)tcdm_out); + write_csr(977, 0); + + // spatial strides + write_csr(978, 8); + + // temporal bounds + write_csr(979, 1); + write_csr(980, 1); + write_csr(981, 1); + write_csr(982, 1); + write_csr(983, 1); + write_csr(984, 1); + + // temporal strides + write_csr(985, 64); + write_csr(986, 64); + write_csr(987, 64); + write_csr(988, 64); + write_csr(989, 64); + write_csr(990, 64); + + // XDMA_DST_ENABLED_CHAN_PTR + write_csr(991, 0xFFFFFFFF); + + // XDMA_DST_ENABLED_BYTE_PTR + write_csr(992, 0xFFFFFFFF); + + // XDMA_DST_BYPASS_PTR + write_csr(993, 0b101); + + // XDMA_DST_EXT_CSR_PTR + // the second extension is enabled and cfg is 1 + // jump 994 + write_csr(995, 0b1); + + // start + // XDMA_START_PTR + write_csr(996, 1); + + // XDMA_COMMIT_TASK_PTR + int task_id = read_csr(997); + // XDMA_FINISH_TASK_PTR + while (read_csr(998) < task_id) { + } + + // // The xdma core is the last compute core in the cluster + // unsigned int sstride_src[1] = {0}; + // unsigned int sstride_dst[1] = {0}; + // unsigned int tstride_src[5] = {0}; + // unsigned int tbound_src[5] = {0}; + // unsigned int tstride_dst[3] = {0}; + // unsigned int tbound_dst[3] = {0}; + + // // Load the CFG from data.h + // sstride_src[0] = spatialStride1_in; + // sstride_dst[0] = spatialStride1_out; + // tstride_src[0] = tempStride0_in; + // tstride_src[1] = tempStride1_in; + // tstride_src[2] = tempStride2_in; + // tstride_src[3] = tempStride3_in; + // tstride_src[4] = tempStride4_in; + // tbound_src[0] = tempLoop0_in; + // tbound_src[1] = tempLoop1_in; + // tbound_src[2] = tempLoop2_in; + // tbound_src[3] = tempLoop3_in; + // tbound_src[4] = tempLoop4_in; + // tstride_dst[0] = tempStride0_out; + // tstride_dst[1] = tempStride1_out; + // tstride_dst[2] = tempStride2_out; + // tbound_dst[0] = tempLoop0_out; + // tbound_dst[1] = tempLoop1_out; + // tbound_dst[2] = tempLoop2_out; + + // // First we need to transfer the input data from L3->TCDM + // snrt_dma_start_1d(tcdm_in, DataIn, input_data_len * sizeof(int8_t)); + // snrt_dma_wait_all(); + + // // --------------------- Configure the Ext --------------------- // + + // if (xdma_disable_dst_ext(0) != 0) { + // err++; + // } else { + // } + + // unsigned int ext_param_maxpool_size[1] = {reduceLen}; + // if (xdma_enable_dst_ext(1, ext_param_maxpool_size) != 0) { + // err++; + // } else { + // } + + // if (xdma_disable_dst_ext(2) != 0) { + // err++; + // } else { + // } + + // // --------------------- Configure the AGU --------------------- // + // xdma_memcpy_nd(tcdm_in, tcdm_out, sstride_src, sstride_dst, 5, + // tstride_src, tbound_src, 3, tstride_dst, tbound_dst, + // 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF); + // int task_id = xdma_start(); + // xdma_wait(task_id); + + // --------------------- Checking the Results --------------------- // + for (int i = 0; i < output_data_len; i++) { + if ((int8_t)tcdm_out[i] != C_golden[i]) { + err++; + } + } + } + + return err; +} diff --git a/sw/lib/offload.c b/sw/lib/offload.c index c7a6040..596d126 100644 --- a/sw/lib/offload.c +++ b/sw/lib/offload.c @@ -54,8 +54,12 @@ void offloadToCluster(void *function, uint8_t clusterId) { hartId += _chimera_numCores[i]; } - volatile uint32_t *interruptTarget = ((uint32_t *)CLINT_CTRL_BASE) + hartId; waitClusterBusy(clusterId); + + volatile uint32_t *interruptTarget = ((uint32_t *)CLINT_CTRL_BASE) + hartId; + + *interruptTarget = 1; + interruptTarget = ((uint32_t *)CLINT_CTRL_BASE) + hartId + 1; *interruptTarget = 1; } diff --git a/sw/tests/testKULClusterOffload.c b/sw/tests/testKULClusterOffload.c new file mode 100644 index 0000000..55a1bf5 --- /dev/null +++ b/sw/tests/testKULClusterOffload.c @@ -0,0 +1,49 @@ +// Copyright 2024 ETH Zurich and University of Bologna. +// Licensed under the Apache License, Version 2.0, see LICENSE for details. +// SPDX-License-Identifier: Apache-2.0 +// +// Moritz Scherer +// Xiaoling Yi + +// Offload `kul_cluster_sw_test` test function to KUL cluster. + +#include "offload.h" +#include "soc_addr_map.h" +#include +#include + +// add snitch runtime for getting TCDM address +#include "kultest/snax-kul-cluster-gemmx-test.h" +#include "kultest/snax-kul-cluster-xdma-test.h" + +static uint32_t *clintPointer = (uint32_t *)CLINT_CTRL_BASE; + +void clusterTrapHandler() { + uint8_t hartId; + asm("csrr %0, mhartid" : "=r"(hartId)::); + + volatile uint32_t *interruptTarget = clintPointer + hartId; + *interruptTarget = 0; + return; +} + +int main() { + uint8_t kul_clusterId = 3; + + setupInterruptHandler(clusterTrapHandler); + + // offload gemm test function to kul cluster + offloadToCluster(kul_cluster_gemmx_test, kul_clusterId); + + // wait for kul cluster to finish + uint32_t retVal_gemmx = waitForCluster(kul_clusterId); + + // offload xdma test function to kul cluster + offloadToCluster(kul_cluster_xdma_test, kul_clusterId); + + // wait for kul cluster to finish + uint32_t retVal_xdma = waitForCluster(kul_clusterId); + + uint32_t retVal = (retVal_gemmx << 16) | retVal_xdma; + return retVal; +}