1 change: 1 addition & 0 deletions include/infiniop.h
@@ -23,4 +23,5 @@
#include "infiniop/ops/swiglu.h"
#include "infiniop/tensor_descriptor.h"

#include "infiniop/ops/spmv.h"
#endif // __INFINIOP_API_H__
28 changes: 28 additions & 0 deletions include/infiniop/ops/spmv.h
@@ -0,0 +1,28 @@
#ifndef __INFINIOP_SPMV_API_H__
#define __INFINIOP_SPMV_API_H__

#include "../operator_descriptor.h"
#include <stddef.h> // size_t, usable from both C and C++

typedef struct InfiniopDescriptor *infiniopSpMVDescriptor_t;

__C __export infiniStatus_t infiniopCreateSpMVDescriptor(
    infiniopHandle_t handle,
    infiniopSpMVDescriptor_t *desc_ptr,
    size_t num_cols,      // number of matrix columns
    size_t num_rows,      // number of matrix rows (row_ptr holds num_rows + 1 offsets)
    size_t nnz,           // number of non-zero elements
    infiniDtype_t dtype); // data type (currently only F32 is supported)

__C __export infiniStatus_t infiniopSpMV(
    infiniopSpMVDescriptor_t desc,
    void *y,                 // output vector
    const void *x,           // input vector
    const void *values,      // array of non-zero values
    const void *row_ptr,     // row-offset array
    const void *col_indices, // column-index array
    void *stream);           // compute stream

__C __export infiniStatus_t infiniopDestroySpMVDescriptor(infiniopSpMVDescriptor_t desc);

#endif // __INFINIOP_SPMV_API_H__
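Reviewer note: a minimal end-to-end usage sketch of this API, not part of the diff. It assumes an already-created infiniopHandle_t named handle, device-resident CSR buffers d_values, d_row_ptr, d_col_indices, vectors d_x, d_y, and reuses the codebase's CHECK_STATUS macro.

// Hypothetical caller sketch: create the descriptor, run y = A * x, destroy.
infiniopSpMVDescriptor_t desc;
CHECK_STATUS(infiniopCreateSpMVDescriptor(handle, &desc,
                                          /*num_cols=*/1024, /*num_rows=*/1024,
                                          /*nnz=*/4096, INFINI_DTYPE_F32));
// y = A * x for the CSR matrix described by (values, row_ptr, col_indices)
CHECK_STATUS(infiniopSpMV(desc, d_y, d_x, d_values, d_row_ptr, d_col_indices,
                          /*stream=*/nullptr));
CHECK_STATUS(infiniopDestroySpMVDescriptor(desc));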
2 changes: 2 additions & 0 deletions src/infiniop/devices/bang/common_bang.h
@@ -8,6 +8,8 @@
#include "cnrt.h"
#include <functional>

struct InfiniopTensorDescriptor;

#define CHECK_BANG(API) CHECK_INTERNAL(API, CNNL_STATUS_SUCCESS)

namespace device::bang {
21 changes: 14 additions & 7 deletions src/infiniop/devices/cuda/cuda_common.cu
@@ -3,12 +3,9 @@
namespace device::cuda {

 Handle::Handle(infiniDevice_t device, int device_id)
-    : InfiniopHandle{device, device_id},
-      _internal(std::make_shared<Handle::Internal>(device_id)) {}
+    : InfiniopHandle{device, device_id}, _internal(std::make_shared<Handle::Internal>(device_id)) {}

-auto Handle::internal() const -> const std::shared_ptr<Internal> & {
-    return _internal;
-}
+auto Handle::internal() const -> const std::shared_ptr<Internal> & { return _internal; }

Handle::Internal::Internal(int device_id) {
cudaDeviceProp prop;
@@ -45,6 +42,17 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const {
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Handle::Internal::useCusparse(cudaStream_t stream, const Fn<cusparseHandle_t> &f) const {
auto handle = sparse_handles.pop();
if (!handle) {
CHECK_CUSPARSE(cusparseCreate(&(*handle)));
}
CHECK_CUSPARSE(cusparseSetStream(*handle, stream));
CHECK_STATUS(f(*handle));
sparse_handles.push(std::move(*handle));
return INFINI_STATUS_SUCCESS;
}

int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
@@ -79,8 +87,7 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {

namespace nvidia {

-Handle::Handle(int device_id)
-    : cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {}
+Handle::Handle(int device_id) : cuda::Handle(INFINI_DEVICE_NVIDIA, device_id) {}

infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id);
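Reviewer note: useCusparse mirrors the pooled-handle pattern of useCublas and useCudnn (pop a cached handle or lazily create one, bind it to the stream, invoke the callback, return the handle to the pool). Below is a hedged sketch of how a cuSPARSE-backed CSR SpMV could consume it; this PR does not ship such a backend, so the function name and the d_* buffers are hypothetical descriptor state.

// Hypothetical sketch: CSR SpMV through the pooled cuSPARSE handle.
infiniStatus_t cusparse_spmv_sketch(
    const std::shared_ptr<device::cuda::Handle::Internal> &internal,
    cudaStream_t stream, int64_t num_rows, int64_t num_cols, int64_t nnz,
    int *d_row_ptr, int *d_col_indices, float *d_values, float *d_x, float *d_y) {
    return internal->useCusparse(stream, [&](cusparseHandle_t h) -> infiniStatus_t {
        cusparseSpMatDescr_t matA;
        cusparseDnVecDescr_t vecX, vecY;
        CHECK_CUSPARSE(cusparseCreateCsr(&matA, num_rows, num_cols, nnz,
                                         d_row_ptr, d_col_indices, d_values,
                                         CUSPARSE_INDEX_32I, CUSPARSE_INDEX_32I,
                                         CUSPARSE_INDEX_BASE_ZERO, CUDA_R_32F));
        CHECK_CUSPARSE(cusparseCreateDnVec(&vecX, num_cols, d_x, CUDA_R_32F));
        CHECK_CUSPARSE(cusparseCreateDnVec(&vecY, num_rows, d_y, CUDA_R_32F));

        const float alpha = 1.0f, beta = 0.0f;
        size_t buffer_size = 0;
        void *buffer = nullptr;
        CHECK_CUSPARSE(cusparseSpMV_bufferSize(h, CUSPARSE_OPERATION_NON_TRANSPOSE,
                                               &alpha, matA, vecX, &beta, vecY,
                                               CUDA_R_32F, CUSPARSE_SPMV_ALG_DEFAULT,
                                               &buffer_size));
        if (cudaMalloc(&buffer, buffer_size) != cudaSuccess) {
            return INFINI_STATUS_INTERNAL_ERROR;
        }
        // y = alpha * A * x + beta * y
        CHECK_CUSPARSE(cusparseSpMV(h, CUSPARSE_OPERATION_NON_TRANSPOSE,
                                    &alpha, matA, vecX, &beta, vecY,
                                    CUDA_R_32F, CUSPARSE_SPMV_ALG_DEFAULT, buffer));
        cudaFree(buffer);
        CHECK_CUSPARSE(cusparseDestroySpMat(matA));
        CHECK_CUSPARSE(cusparseDestroyDnVec(vecX));
        CHECK_CUSPARSE(cusparseDestroyDnVec(vecY));
        return INFINI_STATUS_SUCCESS;
    });
}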
1 change: 1 addition & 0 deletions src/infiniop/devices/cuda/cuda_common.cuh
@@ -3,6 +3,7 @@

#include "cuda_handle.cuh"
#include "infinicore.h"
#include <cuda_runtime.h>

namespace device::cuda {

16 changes: 10 additions & 6 deletions src/infiniop/devices/cuda/cuda_handle.cuh
@@ -6,30 +6,34 @@
#include "cuda_handle.h"
#include <cublas_v2.h>
#include <cudnn.h>
#include <cusparse.h>
#include <functional>

#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
#define CHECK_CUSPARSE(API) CHECK_INTERNAL(API, CUSPARSE_STATUS_SUCCESS)

namespace device::cuda {

class Handle::Internal {
Pool<cublasHandle_t> blas_handles;
Pool<cudnnHandle_t> dnn_handles;
Pool<cusparseHandle_t> sparse_handles;

-    int _warp_size,
-        _max_threads_per_block,
-        _block_size[3],
-        _grid_size[3];
+    int _warp_size, _max_threads_per_block, _block_size[3], _grid_size[3];

template <typename T>
using Fn = std::function<infiniStatus_t(T)>;

public:
Internal(int);

-    infiniStatus_t useCublas(cudaStream_t stream, const Fn<cublasHandle_t> &f) const;
-    infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const;
+    infiniStatus_t useCublas(cudaStream_t stream,
+                             const Fn<cublasHandle_t> &f) const;
+    infiniStatus_t useCudnn(cudaStream_t stream,
+                            const Fn<cudnnHandle_t> &f) const;
+    infiniStatus_t useCusparse(cudaStream_t stream,
+                               const Fn<cusparseHandle_t> &f) const;

int warpSize() const;
int maxThreadsPerBlock() const;
8 changes: 8 additions & 0 deletions src/infiniop/ops/spmv/bang/spmv_bang.h
@@ -0,0 +1,8 @@
#ifndef __SPMV_BANG_H__
#define __SPMV_BANG_H__

#include "../spmv.h"

DESCRIPTOR(bang)

#endif // __SPMV_BANG_H__
129 changes: 129 additions & 0 deletions src/infiniop/ops/spmv/bang/spmv_bang.mlu
@@ -0,0 +1,129 @@
#include "../../../devices/bang/bang_handle.h"
#include "../../../devices/bang/common_bang.h"
#include "../info.h"
#include "bang.h"
#include "mlu.h"
#include "spmv_bang.h"
#include <cstddef>

#define NRAMSIZE (1024 * 512) // 512KB NRAM size

__mlu_entry__ void spmv_csr(int num_rows, int num_cols, int nnz,
                            int *row_ptr, int *col_indices,
                            float *values, float *x, float *y) {
    // compute the range of rows handled by each task;
    // if tasks were instead partitioned by nnz, some extra information
    // might need to be recorded in the descriptor
    int rows_per_task = (num_rows + taskDim - 1) / taskDim;
    int start_row = taskId * rows_per_task;
    int end_row = num_rows < (taskId + 1) * rows_per_task ? num_rows : (taskId + 1) * rows_per_task;
// int low = 0;
// int high = num_rows;
// int nnz_per_task = nnz / taskDim;
// while (low < high) {
// int mid = (low + high) / 2;
// if (row_ptr[mid] < taskId * nnz_per_task) {
// low = mid + 1;
// } else {
// high = mid;
// }
// }
// int start_row = low;
// // printf("taskId: %d, start_row: %d\n", taskId, start_row);
// int end_row = num_rows;
// if (taskId != taskDim - 1) {
// high = num_rows;
// while (low < high) {
// int mid = (low + high) / 2;
// if (row_ptr[mid] < (taskId + 1) * nnz_per_task) {
// low = mid + 1;
// } else {
// high = mid;
// }
// }
// end_row = low;
// }
// printf("taskId: %d, end_row: %d\n", taskId, end_row);
    // process the rows assigned to the current task,
    // staging the values data from GDRAM into NRAM chunk by chunk
    const int float_capacity = (NRAMSIZE / sizeof(float)) * 3 / 4; // use 75% of NRAM
    __nram__ float nram_values[(NRAMSIZE / sizeof(float)) * 3 / 4];
    __nram__ float sum = 0.0f;
    __nram__ int current_num = 0;
    __nram__ int current_begin = 0;
    for (int row = start_row; row < end_row; row++) {
        const int row_nnz = row_ptr[row + 1] - row_ptr[row];
        for (int k = 0; k < row_nnz; k += float_capacity) {
            current_num = row_nnz - k < float_capacity ? row_nnz - k : float_capacity; // min(float_capacity, remaining)
            // move values data from GDRAM to NRAM
            current_begin = row_ptr[row] + k;
            __memcpy(nram_values, values + current_begin, current_num * sizeof(float), GDRAM2NRAM);
            for (int i = 0; i < current_num; i++) {
                sum += nram_values[i] * x[col_indices[current_begin + i]];
            }
        }
        y[row] = sum;
        sum = 0.0f; // reset sum for the next row
    }
return;
}
namespace op::spmv::bang {
struct Descriptor::Opaque {
std::shared_ptr<device::bang::Handle::Internal> internal;
};

Descriptor::~Descriptor() { delete _opaque; }

infiniStatus_t Descriptor::create(infiniopHandle_t handle_,
Descriptor **desc_ptr, size_t num_cols,
size_t num_rows, size_t nnz,
infiniDtype_t dtype) {
auto handle = reinterpret_cast<device::bang::cambricon::Handle *>(handle_);

// currently only float32 supported
if (dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

auto result = SpMVInfo::create(num_cols, num_rows, nnz);
CHECK_RESULT(result);

*desc_ptr = new Descriptor(dtype, result.take(), new Opaque{handle->internal()},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(void *y, const void *x, const void *values,
const void *row_ptr,
const void *col_indices,
void *stream) const {
// do basic validation
auto validation_result = validateSpMVCSR(y, x, values, row_ptr, col_indices, _dtype);
CHECK_OR_RETURN(validation_result == INFINI_STATUS_SUCCESS, validation_result);

    CNRT_CHECK(cnrtSetDevice(device_id));
    cnrtQueue_t queue;
    bool owns_queue = false;
    if (stream == nullptr) {
        CNRT_CHECK(cnrtQueueCreate(&queue));
        owns_queue = true;
    } else {
        queue = (cnrtQueue_t)stream;
    }
    cnrtDim3_t dim = {256, 1, 1};
    cnrtFunctionType_t ktype = CNRT_FUNC_TYPE_BLOCK;

    int num_rows = static_cast<int>(_info.num_rows);
    int num_cols = static_cast<int>(_info.num_cols);
    int nnz = static_cast<int>(_info.nnz);

    int *d_row_ptr = (int *)row_ptr;
    int *d_col_indices = (int *)col_indices;
    float *d_values = (float *)values;
    float *d_x = (float *)x;
    float *d_y = (float *)y;

    spmv_csr<<<dim, ktype, queue>>>(num_rows, num_cols, nnz, d_row_ptr, d_col_indices, d_values, d_x, d_y);

    CNRT_CHECK(cnrtQueueSync(queue));
    // destroy the queue only if it was created here
    if (owns_queue) {
        CNRT_CHECK(cnrtQueueDestroy(queue));
    }

return INFINI_STATUS_SUCCESS;
}

} // namespace op::spmv::bang
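Reviewer note: the active kernel splits rows evenly across tasks (ceil(num_rows / taskDim) rows each), which can load-imbalance when non-zeros are skewed toward a few rows; the commented-out block sketches an nnz-balanced split instead. For reference, a host-side rendering of that search, as a hypothetical helper (not part of the PR, assuming row_ptr is readable on the host):

#include <utility>

// Each task gets the row range whose non-zeros fall roughly in
// [task_id * nnz_per_task, (task_id + 1) * nnz_per_task).
static std::pair<int, int> nnz_balanced_rows(const int *row_ptr, int num_rows,
                                             int nnz, int task_id, int task_dim) {
    int nnz_per_task = nnz / task_dim;
    // first row whose starting offset reaches target (lower_bound on row_ptr)
    auto first_row_at = [&](int target) {
        int low = 0, high = num_rows;
        while (low < high) {
            int mid = (low + high) / 2;
            if (row_ptr[mid] < target) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    };
    int start_row = first_row_at(task_id * nnz_per_task);
    int end_row = (task_id == task_dim - 1)
                      ? num_rows
                      : first_row_at((task_id + 1) * nnz_per_task);
    return {start_row, end_row};
}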
82 changes: 82 additions & 0 deletions src/infiniop/ops/spmv/cpu/spmv_cpu.cc
@@ -0,0 +1,82 @@
#include "spmv_cpu.h"
#include "../../../devices/cpu/common_cpu.h"
#include "../info.h"
#include <cstring>

namespace op::spmv::cpu {

struct Descriptor::Opaque {
// CPU doesn't need special hardware context
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
size_t num_cols,
size_t num_rows,
size_t nnz,
infiniDtype_t dtype) {

auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);

    // currently only F32 (single precision) is supported
if (dtype != INFINI_DTYPE_F32) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

auto result = SpMVInfo::create(num_cols, num_rows, nnz);
CHECK_RESULT(result);

*desc_ptr = new Descriptor(
dtype, result.take(),
new Opaque{},
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}

// CSR implementation of SpMV: y = A * x
static void spmv_csr_impl(
float *y, const float *x, const float *values,
const int32_t *row_ptr, const int32_t *col_idx,
size_t num_rows) {

#ifdef ENABLE_OMP
#pragma omp parallel for
#endif
for (int i = 0; i < static_cast<int>(num_rows); ++i) {
float sum = 0;
for (int32_t j = row_ptr[i]; j < row_ptr[i + 1]; ++j) {
sum += values[j] * x[col_idx[j]];
}
y[i] = sum;
}
}

infiniStatus_t Descriptor::calculate(
void *y,
const void *x,
const void *values,
const void *row_ptr,
const void *col_indices,
void *stream) const {

auto validation_result = validateSpMVCSR(
y, x, values, row_ptr, col_indices, _dtype);
CHECK_OR_RETURN(validation_result == INFINI_STATUS_SUCCESS, validation_result);

spmv_csr_impl(
static_cast<float *>(y),
static_cast<const float *>(x),
static_cast<const float *>(values),
static_cast<const int32_t *>(row_ptr),
static_cast<const int32_t *>(col_indices),
_info.num_rows);

return INFINI_STATUS_SUCCESS;
}

} // namespace op::spmv::cpu
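Reviewer note: for readers unfamiliar with CSR, a small worked example of the layout this kernel consumes (illustrative values, not part of the PR):

// A 3x4 matrix and its CSR encoding:
//
//     | 10  0 20  0 |
// A = |  0 30  0 40 |
//     |  0  0 50  0 |
//
// values      = {10, 20, 30, 40, 50}  // nnz = 5, non-zeros in row-major order
// col_indices = { 0,  2,  1,  3,  2}  // column of each non-zero
// row_ptr     = { 0,  2,  4,  5}      // num_rows + 1 offsets into values
//
// Row i owns values[row_ptr[i] .. row_ptr[i+1]); with x = {1, 1, 1, 1},
// spmv_csr_impl produces y = {10 + 20, 30 + 40, 50} = {30, 70, 50}.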
8 changes: 8 additions & 0 deletions src/infiniop/ops/spmv/cpu/spmv_cpu.h
@@ -0,0 +1,8 @@
#ifndef __SPMV_CPU_H__
#define __SPMV_CPU_H__

#include "../spmv.h"

DESCRIPTOR(cpu)

#endif // __SPMV_CPU_H__