diff --git a/include/infiniop.h b/include/infiniop.h
index ffce99ff4..29ee7ba92 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -23,4 +23,5 @@
 #include "infiniop/ops/swiglu.h"
 #include "infiniop/tensor_descriptor.h"
+#include "infiniop/ops/spmv.h"
 #endif // __INFINIOP_API_H__
diff --git a/include/infiniop/ops/spmv.h b/include/infiniop/ops/spmv.h
new file mode 100644
index 000000000..9fe8d32cb
--- /dev/null
+++ b/include/infiniop/ops/spmv.h
@@ -0,0 +1,28 @@
+#ifndef __INFINIOP_SPMV_API_H__
+#define __INFINIOP_SPMV_API_H__
+
+#include "../operator_descriptor.h"
+#include <stddef.h>
+
+typedef struct InfiniopDescriptor *infiniopSpMVDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateSpMVDescriptor(
+    infiniopHandle_t handle,
+    infiniopSpMVDescriptor_t *desc_ptr,
+    size_t num_cols,      // number of matrix columns
+    size_t num_rows,      // number of matrix rows (row_ptr holds num_rows + 1 offsets)
+    size_t nnz,           // number of non-zero elements
+    infiniDtype_t dtype); // data type (currently only F32 is supported)
+
+__C __export infiniStatus_t infiniopSpMV(
+    infiniopSpMVDescriptor_t desc,
+    void *y,                 // output vector
+    const void *x,           // input vector
+    const void *values,      // non-zero values array
+    const void *row_ptr,     // row offsets array
+    const void *col_indices, // column indices array
+    void *stream);           // compute stream
+
+__C __export infiniStatus_t infiniopDestroySpMVDescriptor(infiniopSpMVDescriptor_t desc);
+
+#endif
diff --git a/src/infiniop/devices/bang/common_bang.h b/src/infiniop/devices/bang/common_bang.h
index 4e95430ff..9e9fb4ec9 100644
--- a/src/infiniop/devices/bang/common_bang.h
+++ b/src/infiniop/devices/bang/common_bang.h
@@ -8,6 +8,8 @@
 #include "cnrt.h"
 #include <functional>
 
+struct InfiniopTensorDescriptor;
+
 #define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)
 
 namespace device::bang {
diff --git a/src/infiniop/devices/cuda/cuda_common.cu b/src/infiniop/devices/cuda/cuda_common.cu
index 118f172a2..f5aafb81f 100644
--- a/src/infiniop/devices/cuda/cuda_common.cu
+++ b/src/infiniop/devices/cuda/cuda_common.cu
@@ -3,12 +3,9 @@
 namespace device::cuda {
 
 Handle::Handle(infiniDevice_t device, int device_id)
-    : InfiniopHandle{device, device_id},
-      _internal(std::make_shared<Internal>(device_id)) {}
+    : InfiniopHandle{device, device_id}, _internal(std::make_shared<Internal>(device_id)) {}
 
-auto Handle::internal() const -> const std::shared_ptr<Internal> & {
-    return _internal;
-}
+auto Handle::internal() const -> const std::shared_ptr<Internal> & { return _internal; }
 
 Handle::Internal::Internal(int device_id) {
     cudaDeviceProp prop;
@@ -45,6 +42,17 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const {
+infiniStatus_t Handle::Internal::useCusparse(cudaStream_t stream,
+                                             const Fn<cusparseHandle_t> &f) const {
+    auto handle = sparse_handles.pop();
+    if (!handle) {
+        CHECK_CUSPARSE(cusparseCreate(&(*handle)));
+    }
+    CHECK_CUSPARSE(cusparseSetStream(*handle, stream));
+    CHECK_STATUS(f(*handle));
+    sparse_handles.push(std::move(*handle));
+    return INFINI_STATUS_SUCCESS;
+}
+
 int Handle::Internal::warpSize() const { return _warp_size; }
 int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
 int Handle::Internal::blockSizeX() const { return _block_size[0]; }
@@ -79,8 +87,7 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
 
 namespace nvidia {
 
-Handle::Handle(int device_id)
-    : cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {}
+Handle::Handle(int device_id) : cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {}
 
 infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
     *handle_ptr = new Handle(device_id);
diff --git a/src/infiniop/devices/cuda/cuda_common.cuh b/src/infiniop/devices/cuda/cuda_common.cuh
index 206410aa5..b69ec22c7 100644
--- a/src/infiniop/devices/cuda/cuda_common.cuh
+++ b/src/infiniop/devices/cuda/cuda_common.cuh
@@ -3,6 +3,7 @@
#include "cuda_handle.cuh" #include "infinicore.h" +#include namespace device::cuda { diff --git a/src/infiniop/devices/cuda/cuda_handle.cuh b/src/infiniop/devices/cuda/cuda_handle.cuh index 5db14817d..51b50ca5e 100644 --- a/src/infiniop/devices/cuda/cuda_handle.cuh +++ b/src/infiniop/devices/cuda/cuda_handle.cuh @@ -6,21 +6,21 @@ #include "cuda_handle.h" #include #include +#include #include #define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS) #define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS) +#define CHECK_CUSPARSE(API) CHECK_INTERNAL(API, CUSPARSE_STATUS_SUCCESS) namespace device::cuda { class Handle::Internal { Pool blas_handles; Pool dnn_handles; + Pool sparse_handles; - int _warp_size, - _max_threads_per_block, - _block_size[3], - _grid_size[3]; + int _warp_size, _max_threads_per_block, _block_size[3], _grid_size[3]; template using Fn = std::function; @@ -28,8 +28,12 @@ class Handle::Internal { public: Internal(int); - infiniStatus_t useCublas(cudaStream_t stream, const Fn &f) const; - infiniStatus_t useCudnn(cudaStream_t stream, const Fn &f) const; + infiniStatus_t useCublas(cudaStream_t stream, + const Fn &f) const; + infiniStatus_t useCudnn(cudaStream_t stream, + const Fn &f) const; + infiniStatus_t useCusparse(cudaStream_t stream, + const Fn &f) const; int warpSize() const; int maxThreadsPerBlock() const; diff --git a/src/infiniop/ops/spmv/bang/spmv_bang.h b/src/infiniop/ops/spmv/bang/spmv_bang.h new file mode 100644 index 000000000..4684368c2 --- /dev/null +++ b/src/infiniop/ops/spmv/bang/spmv_bang.h @@ -0,0 +1,8 @@ +#ifndef __SPMV_BANG_H__ +#define __SPMV_BANG_H__ + +#include "../spmv.h" + +DESCRIPTOR(bang) + +#endif // __SPMV_CPU_H__ diff --git a/src/infiniop/ops/spmv/bang/spmv_bang.mlu b/src/infiniop/ops/spmv/bang/spmv_bang.mlu new file mode 100644 index 000000000..97cd0bea1 --- /dev/null +++ b/src/infiniop/ops/spmv/bang/spmv_bang.mlu @@ -0,0 +1,129 @@ +#include "../../../devices/bang/bang_handle.h" +#include "../../../devices/bang/common_bang.h" +#include "../info.h" +#include "bang.h" +#include "mlu.h" +#include "spmv_bang.h" +#include + +#define NRAMSIZE 1024 * 512 // 512KB NRAM size + +__mlu_entry__ void spmv_csr(int num_rows, int num_cols, int nnz, + int *row_ptr, int *col_indices, + float *values, float *x, float *y) { + // 计算每个 task 处理的行范围 + // 如果按照nnz划分任务,或许要在descriptor中记录一些信息; + int rows_num_pertask = (num_rows + taskDim - 1) / taskDim; + int start_row = taskId * rows_num_pertask; + int end_row = num_rows < (taskId + 1) * rows_num_pertask ? 
+    int end_row = num_rows < (taskId + 1) * rows_num_pertask
+                      ? num_rows
+                      : (taskId + 1) * rows_num_pertask;
+    // Alternative nnz-balanced partitioning, kept for reference:
+    // int low = 0;
+    // int high = num_rows;
+    // int nnz_per_task = nnz / taskDim;
+    // while (low < high) {
+    //     int mid = (low + high) / 2;
+    //     if (row_ptr[mid] < taskId * nnz_per_task) {
+    //         low = mid + 1;
+    //     } else {
+    //         high = mid;
+    //     }
+    // }
+    // int start_row = low;
+    // int end_row = num_rows;
+    // if (taskId != taskDim - 1) {
+    //     high = num_rows;
+    //     while (low < high) {
+    //         int mid = (low + high) / 2;
+    //         if (row_ptr[mid] < (taskId + 1) * nnz_per_task) {
+    //             low = mid + 1;
+    //         } else {
+    //             high = mid;
+    //         }
+    //     }
+    //     end_row = low;
+    // }
+
+    // Process the rows assigned to the current task, staging each row's values
+    // from GDRAM to NRAM in chunks of at most float_capacity floats
+    // (75% of NRAM is reserved for the values buffer).
+    const int float_capacity = (NRAMSIZE / sizeof(float)) * 3 / 4;
+    __nram__ float nram_values[(NRAMSIZE / sizeof(float)) * 3 / 4];
+    __nram__ float sum = 0.0f;
+    __nram__ int current_num = 0;
+    __nram__ int current_begin = 0;
+    for (int row = start_row; row < end_row; row++) {
+        for (int k = 0; k < row_ptr[row + 1] - row_ptr[row]; k += float_capacity) {
+            current_num = std::min(float_capacity, row_ptr[row + 1] - row_ptr[row] - k);
+            // move values data from GDRAM to NRAM
+            current_begin = row_ptr[row] + k;
+            __memcpy(nram_values, values + current_begin, current_num * sizeof(float), GDRAM2NRAM);
+            // x and col_indices are read directly from GDRAM
+            for (int i = 0; i < current_num; i++) {
+                sum += nram_values[i] * x[col_indices[current_begin + i]];
+            }
+        }
+        y[row] = sum;
+        sum = 0.0f; // reset sum for the next row
+    }
+    return;
+}
+
+namespace op::spmv::bang {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::bang::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() { delete _opaque; }
+
+infiniStatus_t Descriptor::create(infiniopHandle_t handle_,
+                                  Descriptor **desc_ptr, size_t num_cols,
+                                  size_t num_rows, size_t nnz,
+                                  infiniDtype_t dtype) {
+    auto handle = reinterpret_cast<device::bang::Handle *>(handle_);
+
+    // currently only float32 supported
+    if (dtype != INFINI_DTYPE_F32) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto result = SpMVInfo::create(num_cols, num_rows, nnz);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(dtype, result.take(), new Opaque{handle->internal()},
+                               handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(void *y, const void *x, const void *values,
+                                     const void *row_ptr,
+                                     const void *col_indices,
+                                     void *stream) const {
+    // basic validation
+    auto validation_result = validateSpMVCSR(y, x, values, row_ptr, col_indices, _dtype);
+    CHECK_OR_RETURN(validation_result == INFINI_STATUS_SUCCESS, validation_result);
+
+    CNRT_CHECK(cnrtSetDevice(device_id));
+    cnrtQueue_t queue;
+    bool owns_queue = false;
+    if (stream == nullptr) {
+        CNRT_CHECK(cnrtQueueCreate(&queue));
+        owns_queue = true;
+    } else {
+        queue = (cnrtQueue_t)stream;
+    }
+    cnrtDim3_t dim = {256, 1, 1};
+    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_BLOCK;
+
+    int num_rows = static_cast<int>(_info.num_rows);
+    int num_cols = static_cast<int>(_info.num_cols);
+    int nnz = static_cast<int>(_info.nnz);
+
+    int *d_row_ptr = (int *)row_ptr;
+    int *d_col_indices = (int *)col_indices;
+    float *d_values = (float *)values;
+    float *d_x = (float *)x;
+    float *d_y = (float *)y;
+
+    spmv_csr<<<dim, ktype, queue>>>(num_rows, num_cols, nnz, d_row_ptr,
+                                    d_col_indices, d_values, d_x, d_y);
+
+    CNRT_CHECK(cnrtQueueSync(queue));
+    // destroy the queue if we created it ourselves
+    if (owns_queue) {
+        CNRT_CHECK(cnrtQueueDestroy(queue));
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::spmv::bang
diff --git a/src/infiniop/ops/spmv/cpu/spmv_cpu.cc b/src/infiniop/ops/spmv/cpu/spmv_cpu.cc
new file mode 100644
index 000000000..ff6d69116
--- /dev/null
+++ b/src/infiniop/ops/spmv/cpu/spmv_cpu.cc
@@ -0,0 +1,82 @@
+#include "spmv_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+#include "../info.h"
+#include <cstdint>
+
+namespace op::spmv::cpu {
+
+struct Descriptor::Opaque {
+    // CPU doesn't need special hardware context
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    size_t num_cols,
+    size_t num_rows,
+    size_t nnz,
+    infiniDtype_t dtype) {
+
+    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
+
+    // currently only float32 is supported
+    if (dtype != INFINI_DTYPE_F32) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto result = SpMVInfo::create(num_cols, num_rows, nnz);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(
+        dtype, result.take(),
+        new Opaque{},
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+// CSR implementation of SpMV
+static void spmv_csr_impl(
+    float *y, const float *x, const float *values,
+    const int32_t *row_ptr, const int32_t *col_idx,
+    size_t num_rows) {
+
+#ifdef ENABLE_OMP
+#pragma omp parallel for
+#endif
+    for (int i = 0; i < static_cast<int>(num_rows); ++i) {
+        float sum = 0;
+        for (int32_t j = row_ptr[i]; j < row_ptr[i + 1]; ++j) {
+            sum += values[j] * x[col_idx[j]];
+        }
+        y[i] = sum;
+    }
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *y,
+    const void *x,
+    const void *values,
+    const void *row_ptr,
+    const void *col_indices,
+    void *stream) const {
+
+    auto validation_result = validateSpMVCSR(
+        y, x, values, row_ptr, col_indices, _dtype);
+    CHECK_OR_RETURN(validation_result == INFINI_STATUS_SUCCESS, validation_result);
+
+    spmv_csr_impl(
+        static_cast<float *>(y),
+        static_cast<const float *>(x),
+        static_cast<const float *>(values),
+        static_cast<const int32_t *>(row_ptr),
+        static_cast<const int32_t *>(col_indices),
+        _info.num_rows);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::spmv::cpu
diff --git a/src/infiniop/ops/spmv/cpu/spmv_cpu.h b/src/infiniop/ops/spmv/cpu/spmv_cpu.h
new file mode 100644
index 000000000..90ee1e5e8
--- /dev/null
+++ b/src/infiniop/ops/spmv/cpu/spmv_cpu.h
@@ -0,0 +1,8 @@
+#ifndef __SPMV_CPU_H__
+#define __SPMV_CPU_H__
+
+#include "../spmv.h"
+
+DESCRIPTOR(cpu)
+
+#endif // __SPMV_CPU_H__
diff --git a/src/infiniop/ops/spmv/cuda/spmv_cuda.cu b/src/infiniop/ops/spmv/cuda/spmv_cuda.cu
new file mode 100644
index 000000000..3cc80483e
--- /dev/null
+++ b/src/infiniop/ops/spmv/cuda/spmv_cuda.cu
@@ -0,0 +1,101 @@
+#include "../../../devices/cuda/cuda_common.cuh"
+#include "../../../devices/cuda/cuda_handle.cuh"
+#include "../../../devices/cuda/cuda_kernel_common.cuh"
+#include "../info.h"
+#include "spmv_cuda.cuh"
+#include <cuda_runtime.h>
+#include <cusparse.h>
+
+namespace op::spmv::cuda {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::cuda::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() { delete _opaque; }
+
+infiniStatus_t Descriptor::create(infiniopHandle_t handle_,
+                                  Descriptor **desc_ptr, size_t num_cols,
+                                  size_t num_rows, size_t nnz,
+                                  infiniDtype_t dtype) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
+
+    // currently only float32 supported
+    if (dtype != INFINI_DTYPE_F32) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    auto result = SpMVInfo::create(num_cols, num_rows, nnz);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(dtype, result.take(), new Opaque{handle->internal()},
+                               handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(void *y, const void *x, const void *values,
+                                     const void *row_ptr,
+                                     const void *col_indices,
+                                     void *stream) const {
+
+    // basic validation
+    auto validation_result = validateSpMVCSR(y, x, values, row_ptr, col_indices, _dtype);
+    CHECK_OR_RETURN(validation_result == INFINI_STATUS_SUCCESS,
+                    validation_result);
+
+    // set up data types and constants
+    cudaDataType cuda_dtype = CUDA_R_32F;
+    const float alpha = 1.0f, beta = 0.0f;
+
+    CHECK_STATUS(_opaque->internal->useCusparse(
+        (cudaStream_t)stream, [&](cusparseHandle_t cusparse_handle) {
+            // create sparse matrix descriptor
+            cusparseSpMatDescr_t mat_descr;
+            CHECK_CUSPARSE(cusparseCreateCsr(
+                &mat_descr, _info.num_rows, _info.num_cols, _info.nnz,
+                const_cast<void *>(row_ptr), const_cast<void *>(col_indices),
+                const_cast<void *>(values), CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
+                CUSPARSE_INDEX_BASE_ZERO, cuda_dtype));
+
+            // create dense vector descriptors
+            cusparseDnVecDescr_t vec_x, vec_y;
+            CHECK_CUSPARSE(cusparseCreateDnVec(&vec_x, _info.num_cols,
+                                               const_cast<void *>(x), cuda_dtype));
+            CHECK_CUSPARSE(
+                cusparseCreateDnVec(&vec_y, _info.num_rows, y, cuda_dtype));
+
+            // compute buffer size
+            size_t buffer_size = 0;
+            CHECK_CUSPARSE(cusparseSpMV_bufferSize(
+                cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE, &alpha,
+                mat_descr, vec_x, &beta, vec_y, cuda_dtype,
+                CUSPARSE_SPMV_ALG_DEFAULT, &buffer_size));
+
+            // allocate external buffer if needed (allocated per call;
+            // caching it in the descriptor is a possible later optimization)
+            void *external_buffer = nullptr;
+            if (buffer_size > 0) {
+                CHECK_CUDA(cudaMalloc(&external_buffer, buffer_size));
+            }
+
+            // perform the sparse matrix-vector multiplication
+            auto result = cusparseSpMV(cusparse_handle, CUSPARSE_OPERATION_NON_TRANSPOSE,
+                                       &alpha, mat_descr, vec_x, &beta, vec_y, cuda_dtype,
+                                       CUSPARSE_SPMV_ALG_DEFAULT, external_buffer);
+
+            // clean up resources before checking the result so nothing leaks
+            if (external_buffer) {
+                cudaFree(external_buffer);
+            }
+            cusparseDestroyDnVec(vec_x);
+            cusparseDestroyDnVec(vec_y);
+            cusparseDestroySpMat(mat_descr);
+
+            CHECK_CUSPARSE(result);
+            return INFINI_STATUS_SUCCESS;
+        }));
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::spmv::cuda
diff --git a/src/infiniop/ops/spmv/cuda/spmv_cuda.cuh b/src/infiniop/ops/spmv/cuda/spmv_cuda.cuh
new file mode 100644
index 000000000..a0f0fafc2
--- /dev/null
+++ b/src/infiniop/ops/spmv/cuda/spmv_cuda.cuh
@@ -0,0 +1,8 @@
+#ifndef __SPMV_CUDA_CUH__
+#define __SPMV_CUDA_CUH__
+
+#include "../spmv.h"
+
+DESCRIPTOR(cuda)
+
+#endif // __SPMV_CUDA_CUH__
diff --git a/src/infiniop/ops/spmv/info.h b/src/infiniop/ops/spmv/info.h
new file mode 100644
index 000000000..5c6bc9d92
--- /dev/null
+++ b/src/infiniop/ops/spmv/info.h
@@ -0,0 +1,66 @@
+#ifndef __SPMV_INFO_H__
+#define __SPMV_INFO_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+
+namespace op::spmv {
+
+// SpMV operation information
+class SpMVInfo {
+    SpMVInfo() = default;
+
+public:
+    size_t num_rows;
+    size_t num_cols;
+    size_t nnz;
+
+    static utils::Result<SpMVInfo> create(
+        size_t num_cols,
+        size_t num_rows,
+        size_t nnz) {
+
+        CHECK_OR_RETURN(num_cols > 0 && num_rows > 0 && nnz > 0,
+                        INFINI_STATUS_BAD_PARAM);
+
+        SpMVInfo info;
+        info.num_rows = num_rows;
+        info.num_cols = num_cols;
+        info.nnz = nnz;
+
+        return utils::Result<SpMVInfo>(info);
+    }
+
+    // Same as create() but with the older (num_rows, num_cols) argument order.
+    static utils::Result<SpMVInfo> createLegacy(
+        size_t num_rows,
+        size_t num_cols,
+        size_t nnz) {
+
+        CHECK_OR_RETURN(num_rows > 0 && num_cols > 0 && nnz > 0,
+                        INFINI_STATUS_BAD_PARAM);
+
+        SpMVInfo info;
+        info.num_rows = num_rows;
+        info.num_cols = num_cols;
+        info.nnz = nnz;
+
+        return utils::Result<SpMVInfo>(info);
+    }
+};
+
+// validate SpMV CSR operation parameters
+inline infiniStatus_t validateSpMVCSR(
+    const void *y, const void *x, const void *values,
+    const void *row_ptr, const void *col_indices,
+    infiniDtype_t dtype) {
+
+    CHECK_OR_RETURN(y && x && values && row_ptr && col_indices,
+                    INFINI_STATUS_BAD_PARAM);
+    CHECK_OR_RETURN(dtype == INFINI_DTYPE_F32, INFINI_STATUS_BAD_TENSOR_DTYPE);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::spmv
+
+#endif // __SPMV_INFO_H__
diff --git a/src/infiniop/ops/spmv/operator.cc b/src/infiniop/ops/spmv/operator.cc
new file mode 100644
index 000000000..98c889ecc
--- /dev/null
+++ b/src/infiniop/ops/spmv/operator.cc
@@ -0,0 +1,94 @@
+#include "../../handle.h"
+#include "infiniop/ops/spmv.h"
+#include "spmv.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/spmv_cpu.h"
+#endif
+
+#ifdef ENABLE_CUDA_API
+#include "cuda/spmv_cuda.cuh"
+#endif
+
+#ifdef ENABLE_CAMBRICON_API
+#include "bang/spmv_bang.h"
+#endif
+
+__C infiniStatus_t infiniopCreateSpMVDescriptor(
+    infiniopHandle_t handle, infiniopSpMVDescriptor_t *desc_ptr,
+    size_t num_cols, size_t num_rows, size_t nnz, infiniDtype_t dtype) {
+#define CREATE(CASE, NAMESPACE)                                             \
+    case CASE:                                                              \
+        return op::spmv::NAMESPACE::Descriptor::create(                    \
+            handle,                                                         \
+            reinterpret_cast<op::spmv::NAMESPACE::Descriptor **>(desc_ptr), \
+            num_cols, num_rows, nnz, dtype)
+
+    switch (handle->device) {
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CREATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CREATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopSpMV(infiniopSpMVDescriptor_t desc, void *y,
+                                const void *x, const void *values,
+                                const void *row_ptr, const void *col_indices,
+                                void *stream) {
+#define CALCULATE(CASE, NAMESPACE)                                             \
+    case CASE:                                                                 \
+        return reinterpret_cast<const op::spmv::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(y, x, values, row_ptr, col_indices, stream)
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        CALCULATE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t
+infiniopDestroySpMVDescriptor(infiniopSpMVDescriptor_t desc) {
+#define DELETE(CASE, NAMESPACE)                                            \
+    case CASE:                                                             \
+        delete reinterpret_cast<op::spmv::NAMESPACE::Descriptor *>(desc);  \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        DELETE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+#ifdef ENABLE_CAMBRICON_API
+        DELETE(INFINI_DEVICE_CAMBRICON, bang);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
diff --git a/src/infiniop/ops/spmv/spmv.h b/src/infiniop/ops/spmv/spmv.h
new file mode 100644
index 000000000..24b6534eb
--- /dev/null
+++ b/src/infiniop/ops/spmv/spmv.h
@@ -0,0 +1,48 @@
+#ifndef __SPMV_H__
+#define __SPMV_H__
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                          \
+                                                                       \
+    namespace op::spmv::NAMESPACE {                                    \
+    class Descriptor final : public InfiniopDescriptor {               \
+        struct Opaque;                                                 \
+        Opaque *_opaque;                                               \
+        infiniDtype_t _dtype;                                          \
+        SpMVInfo _info;                                                \
+                                                                       \
+        Descriptor(                                                    \
+            infiniDtype_t dtype,                                       \
+            SpMVInfo info,                                             \
+            Opaque *opaque,                                            \
+            infiniDevice_t device_type,                                \
+            int device_id)                                             \
+            : InfiniopDescriptor{device_type, device_id},              \
+              _opaque(opaque),                                         \
+              _dtype(dtype),                                           \
+              _info(info) {}                                           \
+                                                                       \
+    public:                                                            \
+        ~Descriptor();                                                 \
+                                                                       \
+        static infiniStatus_t create(                                  \
+            infiniopHandle_t handle,                                   \
+            Descriptor **desc_ptr,                                     \
+            size_t num_cols,                                           \
+            size_t num_rows,                                           \
+            size_t nnz,                                                \
+            infiniDtype_t dtype);                                      \
+                                                                       \
+        infiniStatus_t calculate(                                      \
+            void *y,                                                   \
+            const void *x,                                             \
+            const void *values,                                        \
+            const void *row_ptr,                                       \
+            const void *col_indices,                                   \
+            void *stream) const;                                       \
+    };                                                                 \
+    }
+
+#endif // __SPMV_H__
diff --git a/test/infiniop/spmv.py b/test/infiniop/spmv.py
new file mode 100644
index 000000000..55a9aaea7
--- /dev/null
+++ b/test/infiniop/spmv.py
@@ -0,0 +1,348 @@
+import torch
+import ctypes
+from ctypes import POINTER, Structure, c_int32, c_size_t, c_void_p, c_float
+from libinfiniop import (
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    open_lib,
+    to_tensor,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    InfiniDtype,
+)
+
+# ==============================================================================
+#  Configuration (Internal Use Only)
+# ==============================================================================
+# These are not meant to be imported from other modules
+_TEST_CASES = [
+    (100, 200, 0.1),  # Dense small matrix
+    (5000, 3600, 0.01),  # Medium size
+    (10000, 100000, 0.0004),  # Large sparse matrix
+    # (1000000, 1000000, 0.00001),  # Very large sparse matrix
+]
+
+# Data types used for testing (currently only float32 is supported)
+_TENSOR_DTYPES = [torch.float32]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    torch.float32: {"atol": 1e-5, "rtol": 1e-4},
+}
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 100
+
+
+# ==============================================================================
+#  Definitions
+# ==============================================================================
+class SpMVDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopSpMVDescriptor_t = POINTER(SpMVDescriptor)
+
+
+def generate_unique_indices_batch(nnz, total_elements, device, batch_size=100):
+    """
+    Generate unique random linear indices in [0, total_elements - 1] with minimal
+    extra space, using a batch approach to avoid excessive memory usage.
+    """
+    generated = set()
+    result = torch.empty(nnz, dtype=torch.long, device=device)
+    count = 0
+
+    while count < nnz:
+        remaining = nnz - count
+        batch_size = min(batch_size, remaining)
+        candidates = torch.randint(0, total_elements, (batch_size,), device=device)
+        for candidate in candidates.cpu().tolist():
+            if candidate not in generated:
+                generated.add(candidate)
+                result[count] = candidate
+                count += 1
+                if count == nnz:
+                    break
+
+    return result
+
+
+def create_random_csr_matrix(num_rows, num_cols, density, dtype, device):
+    """
+    Create a random CSR sparse matrix with the given density.
+    Returns: values, row_ptr, col_indices, nnz
+    """
+    # Generate random sparse matrix
+    total_elements = num_rows * num_cols
+    nnz = int(total_elements * density)
+
+    # Generate linear indices for non-zero elements
+    linear_indices = generate_unique_indices_batch(nnz, total_elements, device)
+
+    rows = linear_indices // num_cols
+    cols = linear_indices % num_cols
+
+    # Sort by row for CSR format
+    sorted_indices = torch.argsort(rows * num_cols + cols)
+    rows = rows[sorted_indices]
+    cols = cols[sorted_indices]
+
+    # Create values
+    values = torch.ones(nnz, dtype=dtype, device=device)
+
+    # Create row pointers (CSR format): count entries per row, then prefix-sum
+    row_ptr = torch.zeros(num_rows + 1, dtype=torch.int32, device=device)
+    row_ptr[1:] = torch.bincount(rows, minlength=num_rows)
+    row_ptr = torch.cumsum(row_ptr, dim=0, dtype=torch.int32)
+
+    # Column indices
+    col_indices = cols.to(torch.int32).to(device)
+
+    if DEBUG:
+        print("=== CSR Matrix Memory Layout Debug ===")
+        print(f"row_ptr: shape={row_ptr.shape}, dtype={row_ptr.dtype}")
+        print(f"row_ptr.is_contiguous(): {row_ptr.is_contiguous()}")
+        print(f"row_ptr.stride(): {row_ptr.stride()}")
+        print(f"row_ptr.storage_offset(): {row_ptr.storage_offset()}")
+        print(f"row_ptr values: {row_ptr[:10]}")
+
+        print(f"col_indices: shape={col_indices.shape}, dtype={col_indices.dtype}")
+        print(f"col_indices.is_contiguous(): {col_indices.is_contiguous()}")
+        print(f"col_indices.stride(): {col_indices.stride()}")
+        print(f"col_indices.storage_offset(): {col_indices.storage_offset()}")
+        print(f"col_indices values: {col_indices[:10]}")
+
+    return values, row_ptr, col_indices, nnz
+
+
+def spmv_reference(values, row_ptr, col_indices, x):
+    """
+    Reference SpMV implementation using PyTorch.
+    """
+    num_rows = len(row_ptr) - 1
+    y = torch.zeros(num_rows, dtype=values.dtype, device=values.device)
+
+    for i in range(num_rows):
+        start = row_ptr[i].item()
+        end = row_ptr[i + 1].item()
+        for j in range(start, end):
+            y[i] += values[j] * x[col_indices[j]]
+
+    return y
+
+
+def spmv_pytorch_reference(values, row_ptr, col_indices, x, num_rows, num_cols):
+    """
+    Alternative reference using PyTorch sparse tensors for verification.
+ """ + # Convert CSR to COO format for PyTorch sparse tensor + row_indices = [] + for i in range(num_rows): + start = row_ptr[i].item() + end = row_ptr[i + 1].item() + row_indices.extend([i] * (end - start)) + + row_indices = torch.tensor(row_indices, dtype=torch.long, device=values.device) + col_indices_long = col_indices.long() + + # Create sparse tensor + indices = torch.stack([row_indices, col_indices_long]) + sparse_matrix = torch.sparse_coo_tensor( + indices, values, (num_rows, num_cols), device=values.device + ).coalesce() + + # Perform SpMV + return torch.sparse.mm(sparse_matrix, x.unsqueeze(1)).squeeze(1) + + +# The argument list should be (lib, handle, torch_device, , dtype) +def test( + lib, + handle, + torch_device, + num_rows, + num_cols, + density, + dtype=torch.float32, + sync=None, +): + print( + f"Testing SpMV on {torch_device} with num_rows:{num_rows}, num_cols:{num_cols}, " + f"density:{density}, dtype:{dtype}" + ) + + # Create random CSR sparse matrix + values, row_ptr, col_indices, nnz = create_random_csr_matrix( + num_rows, num_cols, density, dtype, torch_device + ) + + # Create input vector + x = torch.ones(num_cols, dtype=dtype, device=torch_device) + + # Create output vector + y = torch.zeros(num_rows, dtype=dtype, device=torch_device) + + # Compute reference results + y_torch_ref = spmv_reference(values, row_ptr, col_indices, x) + if torch_device == "cuda": + y_torch_sparse_ref = spmv_pytorch_reference( + values, row_ptr, col_indices, x, num_rows, num_cols + ) + assert torch.allclose( + y_torch_ref, y_torch_sparse_ref, atol=1e-6, rtol=1e-5 + ), "PyTorch sparse reference doesn't match common reference!" + + # Create tensors for infiniop + y_tensor = to_tensor(y, lib) + x_tensor = to_tensor(x, lib) + values_tensor = to_tensor(values, lib) + row_ptr_tensor = to_tensor(row_ptr, lib) + col_indices_tensor = to_tensor(col_indices, lib) + + if sync is not None: + sync() + + # Create descriptor + descriptor = infiniopSpMVDescriptor_t() + check_error( + lib.infiniopCreateSpMVDescriptor( + handle, + ctypes.byref(descriptor), + num_cols, + num_rows, + nnz, + InfiniDtype.F32, # Only support float32 now. + ) + ) + + # Invalidate the descriptors to prevent them from being directly used by the kernel + for tensor in [ + y_tensor, + x_tensor, + values_tensor, + row_ptr_tensor, + col_indices_tensor, + ]: + tensor.destroyDesc(lib) + + # Execute infiniop SpMV operator + def lib_spmv(): + check_error( + lib.infiniopSpMV( + descriptor, + y_tensor.data, + x_tensor.data, + values_tensor.data, + row_ptr_tensor.data, + col_indices_tensor.data, + None, # stream + ) + ) + + # print parameters for debugging + if DEBUG: + print("--------------SpMV parameters: ------------------") + print("x_tensor:", x_tensor.torch_tensor_[:10]) + print("y_tensor:", y_tensor.torch_tensor_[:10]) + print("values_tensor:", values_tensor.torch_tensor_[:10]) + print("row_ptr_tensor:", row_ptr_tensor.torch_tensor_[:10]) + print("col_indices_tensor:", col_indices_tensor.torch_tensor_[:10]) + + lib_spmv() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(y, y_torch_ref, atol=atol, rtol=rtol) + + # Check against our reference + assert torch.allclose( + y, y_torch_ref, atol=atol, rtol=rtol + ), f"Results don't match reference! Max diff: {(y - y_torch_ref).abs().max().item()}" + + # Also check against PyTorch sparse reference + if torch_device == "cuda": + assert torch.allclose( + y, y_torch_sparse_ref, atol=atol, rtol=rtol + ), f"Results don't match PyTorch reference! 
+        ), f"Results don't match PyTorch reference! Max diff: {(y - y_torch_sparse_ref).abs().max().item()}"
+
+    # Profiling workflow
+    if PROFILE:
+        profile_operation(
+            "Torch Reference",
+            lambda: spmv_reference(values, row_ptr, col_indices, x),
+            torch_device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        if torch_device == "cuda":
+            profile_operation(
+                "Torch Sparse Reference",
+                lambda: spmv_pytorch_reference(
+                    values, row_ptr, col_indices, x, num_rows, num_cols
+                ),
+                torch_device,
+                NUM_PRERUN,
+                NUM_ITERATIONS,
+            )
+        profile_operation(
+            "    lib", lambda: lib_spmv(), torch_device, NUM_PRERUN, NUM_ITERATIONS
+        )
+
+    check_error(lib.infiniopDestroySpMVDescriptor(descriptor))
+
+
+# ==============================================================================
+#  Main Execution
+# ==============================================================================
+if __name__ == "__main__":
+    args = get_args()
+    lib = open_lib()
+
+    # Register API functions
+    lib.infiniopCreateSpMVDescriptor.restype = c_int32
+    lib.infiniopCreateSpMVDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopSpMVDescriptor_t),
+        c_size_t,  # num_cols
+        c_size_t,  # num_rows
+        c_size_t,  # nnz
+        c_int32,  # dtype
+    ]
+
+    lib.infiniopSpMV.restype = c_int32
+    lib.infiniopSpMV.argtypes = [
+        infiniopSpMVDescriptor_t,
+        c_void_p,  # y
+        c_void_p,  # x
+        c_void_p,  # values
+        c_void_p,  # row_ptr
+        c_void_p,  # col_indices
+        c_void_p,  # stream
+    ]
+
+    lib.infiniopDestroySpMVDescriptor.restype = c_int32
+    lib.infiniopDestroySpMVDescriptor.argtypes = [
+        infiniopSpMVDescriptor_t,
+    ]
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
diff --git a/xmake/cuda.lua b/xmake/cuda.lua
index fb3dbd400..a5599ea67 100644
--- a/xmake/cuda.lua
+++ b/xmake/cuda.lua
@@ -15,7 +15,7 @@ target("infiniop-cuda")
     set_policy("build.cuda.devlink", true)
    set_toolchains("cuda")
 
-    add_links("cublas", "cudnn")
+    add_links("cublas", "cudnn", "cusparse")
     add_cugencodes("native")
 
    if is_plat("windows") then
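
Usage sketch (reviewer note, not part of the patch): a minimal host-side caller of the new C API, CPU path, with error checking elided. The handle-creation entry point `infiniopCreateHandle` is an assumption based on the surrounding infiniop conventions; on GPU/MLU devices the five array arguments must point to device memory, and `stream` can carry a cudaStream_t or cnrtQueue_t instead of NULL.

    #include <infiniop.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // 3x4 CSR matrix with 4 non-zeros: row offsets, column ids, values
        int32_t row_ptr[] = {0, 2, 3, 4};
        int32_t col_indices[] = {0, 2, 1, 3};
        float values[] = {1.0f, 2.0f, 3.0f, 4.0f};
        float x[] = {1.0f, 1.0f, 1.0f, 1.0f};
        float y[3];

        infiniopHandle_t handle;
        infiniopCreateHandle(&handle); // assumed handle-creation entry point

        infiniopSpMVDescriptor_t desc;
        infiniopCreateSpMVDescriptor(handle, &desc,
                                     /*num_cols=*/4, /*num_rows=*/3, /*nnz=*/4,
                                     INFINI_DTYPE_F32);
        infiniopSpMV(desc, y, x, values, row_ptr, col_indices, /*stream=*/NULL);
        infiniopDestroySpMVDescriptor(desc);

        printf("%f %f %f\n", y[0], y[1], y[2]); // expect 3 3 4
        return 0;
    }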