diff --git a/include/infiniop.h b/include/infiniop.h
index d51b8d92e..ab4c8c231 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -14,6 +14,7 @@
 #include "infiniop/ops/relu.h"
 #include "infiniop/ops/rms_norm.h"
 #include "infiniop/ops/rope.h"
+#include "infiniop/ops/softmax.h"
 #include "infiniop/ops/sub.h"
 #include "infiniop/ops/swiglu.h"
 #include "infiniop/tensor_descriptor.h"
diff --git a/include/infiniop/ops/softmax.h b/include/infiniop/ops/softmax.h
new file mode 100644
index 000000000..d06f61345
--- /dev/null
+++ b/include/infiniop/ops/softmax.h
@@ -0,0 +1,20 @@
+#ifndef __INFINIOP_SOFTMAX_API_H__
+#define __INFINIOP_SOFTMAX_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopSoftmaxDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle,
+                                                            infiniopSoftmaxDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t y_desc,
+                                                            infiniopTensorDescriptor_t x_desc,
+                                                            int axis);
+
+__C __export infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *y, const void *x, void *stream);
+
+__C __export infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc);
+
+#endif
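
A minimal host-side sketch of the new API's lifecycle, for reviewers who want to try it. Only the four Softmax entry points come from this diff; infiniopCreateHandle and infiniopCreateTensorDescriptor are assumed to follow the existing InfiniOp conventions, so treat their exact signatures as assumptions rather than part of this change:

    // Hedged sketch: softmax over axis 1 of a contiguous [8, 32] F32 tensor on CPU.
    infiniopHandle_t handle;
    infiniopCreateHandle(&handle); // assumed existing InfiniOp API

    size_t shape[2] = {8, 32};
    ptrdiff_t strides[2] = {32, 1}; // contiguous layout
    infiniopTensorDescriptor_t x_desc, y_desc;
    infiniopCreateTensorDescriptor(&x_desc, 2, shape, strides, INFINI_DTYPE_F32); // assumed signature
    infiniopCreateTensorDescriptor(&y_desc, 2, shape, strides, INFINI_DTYPE_F32);

    infiniopSoftmaxDescriptor_t desc;
    infiniopCreateSoftmaxDescriptor(handle, &desc, y_desc, x_desc, /*axis=*/1);

    size_t workspace_size = 0;
    infiniopGetSoftmaxWorkspaceSize(desc, &workspace_size); // 0 in this implementation

    float x[8 * 32] = {/* inputs */}, y[8 * 32];
    infiniopSoftmax(desc, /*workspace=*/nullptr, workspace_size, y, x, /*stream=*/nullptr);
    infiniopDestroySoftmaxDescriptor(desc);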
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
new file mode 100644
index 000000000..ab18315a3
--- /dev/null
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
@@ -0,0 +1,94 @@
+#include "softmax_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+
+namespace op::softmax::cpu {
+
+Descriptor::~Descriptor() = default;
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t x,
+    int axis) {
+
+    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
+    auto dtype = y->dtype();
+
+    const auto &x_shape = x->shape();
+    const auto &y_shape = y->shape();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+    CHECK_SAME_SHAPE(y_shape, x_shape);
+
+    auto result = SoftmaxInfo::create(y, x, axis);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(
+        dtype,
+        result.take(),
+        0,
+        nullptr,
+        handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template <typename T>
+void softmax_cpu(const SoftmaxInfo &info,
+                 const void *x, void *y) {
+    int dim_size = info.dim_size;
+    int stride = info.stride;
+    int other_size = info.other_size;
+    auto input = reinterpret_cast<const T *>(x);
+    auto output = reinterpret_cast<T *>(y);
+
+    auto compute_softmax = [&](int i) {
+        // Linear offset of the first element of the i-th reduction group:
+        // i % stride is the position within the trailing dimensions, and
+        // (i - i % stride) * dim_size skips over the leading groups.
+        int tid = i % stride + (i - i % stride) * dim_size;
+
+        float max_data = -INFINITY;
+        for (int j = 0; j < dim_size; j++) {
+            int index = tid + j * stride;
+            max_data = std::fmax(max_data, utils::cast<float>(input[index]));
+        }
+
+        float sum_data = 0.0f;
+        for (int j = 0; j < dim_size; j++) {
+            int index = tid + j * stride;
+            sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
+        }
+
+        for (int j = 0; j < dim_size; j++) {
+            int index = tid + j * stride;
+            float result = std::exp(utils::cast<float>(input[index]) - max_data) / sum_data;
+            output[index] = utils::cast<T>(result);
+        }
+    };
+#pragma omp parallel for
+    for (int i = 0; i < other_size; i++) {
+        compute_softmax(i);
+    }
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream_) const {
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        softmax_cpu<fp16_t>(_info, x, y);
+        return INFINI_STATUS_SUCCESS;
+    case INFINI_DTYPE_F32:
+        softmax_cpu<float>(_info, x, y);
+        return INFINI_STATUS_SUCCESS;
+    case INFINI_DTYPE_BF16:
+        softmax_cpu<bf16_t>(_info, x, y);
+        return INFINI_STATUS_SUCCESS;
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+} // namespace op::softmax::cpu
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.h b/src/infiniop/ops/softmax/cpu/softmax_cpu.h
new file mode 100644
index 000000000..49a9d9bb9
--- /dev/null
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.h
@@ -0,0 +1,8 @@
+#ifndef __SOFTMAX_CPU_H__
+#define __SOFTMAX_CPU_H__
+
+#include "../softmax.h"
+
+DESCRIPTOR(cpu)
+
+#endif // __SOFTMAX_CPU_H__
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
new file mode 100644
index 000000000..23f919595
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
@@ -0,0 +1,54 @@
+#include "../../../devices/cuda/cuda_common.cuh"
+#include "softmax_cuda.cuh"
+#include "softmax_kernel.cuh"
+
+namespace op::softmax::cuda {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::cuda::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t x,
+    int axis) {
+    auto dtype = y->dtype();
+    auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
+    auto result = SoftmaxInfo::create(y, x, axis);
+    CHECK_RESULT(result);
+    CHECK_SAME_SHAPE(y->shape(), x->shape());
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
+    if (x->dtype() != dtype) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+    *desc_ptr = new Descriptor(
+        dtype,
+        result.take(),
+        0,
+        new Opaque{handle->internal()},
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream_) const {
+
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        return softmax_dispatch<half>(_info, y, x, stream_);
+    case INFINI_DTYPE_F32:
+        return softmax_dispatch<float>(_info, y, x, stream_);
+    case INFINI_DTYPE_BF16:
+        return softmax_dispatch<__nv_bfloat16>(_info, y, x, stream_);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+} // namespace op::softmax::cuda
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
new file mode 100644
index 000000000..6031e113b
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
@@ -0,0 +1,8 @@
+#ifndef __SOFTMAX_CUDA_CUH__
+#define __SOFTMAX_CUDA_CUH__
+
+#include "../softmax.h"
+
+DESCRIPTOR(cuda)
+
+#endif // __SOFTMAX_CUDA_CUH__
diff --git a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
new file mode 100644
index 000000000..54ab99af0
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
@@ -0,0 +1,262 @@
+#ifndef __SOFTMAX_CUDA_KERNEL_H__
+#define __SOFTMAX_CUDA_KERNEL_H__
+
+#include "../../../devices/cuda/cuda_kernel_common.cuh"
+#include "softmax_cuda.cuh"
+#include <cub/block/block_reduce.cuh>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+// Running (max, sum) pair for the online softmax reduction.
+struct __align__(8) MD {
+    float max;
+    float sum;
+};
+
+// Merge two partial (max, sum) states: rescale the smaller-max state's sum
+// by exp(smaller.max - bigger.max) so both sums are relative to the same max.
+__device__ __forceinline__ MD reduce_for_md(MD a, MD b) {
+    bool is_a_bigger = a.max > b.max;
+    MD bigger = is_a_bigger ? a : b;
+    MD smaller = is_a_bigger ? b : a;
+    bigger.sum = bigger.sum + __expf(smaller.max - bigger.max) * smaller.sum;
+    return bigger;
+}
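+
+// Quick numeric check of the merge (illustrative): merging the singleton
+// states of elements {2, 5}, i.e. a = {max 2, sum 1} and b = {max 5, sum 1},
+// yields max = 5 and sum = 1 + expf(2 - 5) * 1 ≈ 1.0498, which is exactly
+// exp(2 - 5) + exp(5 - 5), i.e. sum_i exp(x_i - max) over both elements.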
+
+template <typename T>
+struct SumOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return a + b;
+    }
+};
+
+template <>
+struct SumOp<half> {
+    __device__ __forceinline__ half operator()(const half &a, const half &b) const {
+        return __hadd(a, b);
+    }
+};
+
+template <>
+struct SumOp<__nv_bfloat16> {
+    __device__ __forceinline__ __nv_bfloat16 operator()(const __nv_bfloat16 &a, const __nv_bfloat16 &b) const {
+        return __hadd(a, b);
+    }
+};
+
+template <typename T>
+struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
+    }
+};
+
+template <>
+struct MaxOp<half> {
+    __device__ __forceinline__ half operator()(const half &a, const half &b) const {
+        return __hmax(a, b);
+    }
+};
+
+template <>
+struct MaxOp<__nv_bfloat16> {
+    __device__ __forceinline__ __nv_bfloat16 operator()(const __nv_bfloat16 &a, const __nv_bfloat16 &b) const {
+        return __hmax(a, b);
+    }
+};
+
+// Butterfly (xor-shuffle) reduction across a thread group; every lane in the
+// group ends up holding the reduced value.
+template <typename T, template <typename> class ReduceOp, int THREAD_GROUP_WIDTH = 32>
+__device__ __forceinline__ T warpReduce(T value) {
+    for (int mask = THREAD_GROUP_WIDTH / 2; mask > 0; mask /= 2) {
+        value = ReduceOp<T>()(value, __shfl_xor_sync(0xffffffff, value, mask));
+    }
+    return value;
+}
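+
+// Example with THREAD_GROUP_WIDTH = 8 (illustrative): masks 4, 2, 1 first pair
+// lanes (0,4)(1,5)(2,6)(3,7), then (0,2)(1,3)(4,6)(5,7), then (0,1)(2,3)(4,5)(6,7);
+// after log2(8) = 3 exchanges every lane holds the reduction of all 8 values,
+// which is why callers below can read the result from any lane.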
+
+// Higher-dimensional softmax. For example, with axis = 1 and an input of
+// shape [a1, a2, a3, a4], there are a1 * a3 * a4 reduction groups, each a
+// tensor of shape [a2] over which we take the max and the sum. Descriptor
+// creation therefore precomputes the stride between consecutive elements
+// along the reduction axis and the total number of groups of length a2.
+// Groups whose reduction axis has at most 1024 elements are handled by a
+// single warp each (matching the dim_size <= 1024 dispatch below).
+// 1024 / 32 = 32
+// blockDim.x = 32 -> warp_size
+// blockDim.y = 32, i.e. 32 warps per block
+// dim3 block(32, 32);
+// threadIdx.y is the warp id within the block
+// BLOCK_DIM_Y is the number of warps per block
+/*
+Element 0 of a group sits at [i, 0, j], element 1 at [i, 1, j], ...,
+element 19 at [i, 19, j].
+(tid + idx * BLOCK_DIM_X) * stride is the linear offset contributed by the
+axis index; we still need i and j:
+i = (blockIdx.x * blockDim.y + threadIdx.y) / stride
+j = (blockIdx.x * blockDim.y + threadIdx.y) % stride
+Linearized, i contributes i * stride * dim_size, and j is added directly.
+*/
+template <typename T, int BLOCK_DIM_X, int ELEM_PER_THREAD>
+__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    float dataPerThread[ELEM_PER_THREAD];
+    int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
+    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dim_size;
+    int tid = threadIdx.x;
+    if (global_warp_id >= other_size) {
+        return;
+    }
+    // One slot per warp; this assumes blockDim.y == BLOCK_DIM_X == 32.
+    __shared__ float group_max[BLOCK_DIM_X];
+    __shared__ float group_sum[BLOCK_DIM_X];
+    float thread_max = -INFINITY;
+    float thread_sum = 0.0f;
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
+        dataPerThread[i] = static_cast<float>(x[(tid + i * BLOCK_DIM_X) * stride + group_offset]);
+        thread_max = max(thread_max, dataPerThread[i]);
+    }
+
+    thread_max = warpReduce<float, MaxOp>(thread_max);
+    if (tid == 0) {
+        group_max[threadIdx.y] = thread_max;
+    }
+    __syncwarp(); // make lane 0's write visible to the whole warp
+
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
+        dataPerThread[i] = __expf(dataPerThread[i] - group_max[threadIdx.y]);
+        thread_sum += dataPerThread[i];
+    }
+
+    thread_sum = warpReduce<float, SumOp>(thread_sum);
+    if (tid == 0) {
+        group_sum[threadIdx.y] = thread_sum;
+    }
+    __syncwarp();
+
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
+        y[(tid + i * BLOCK_DIM_X) * stride + group_offset] = static_cast<T>(dataPerThread[i] * __fdividef(1.0f, group_sum[threadIdx.y]));
+    }
+}
+
+template <typename T, int BLOCK_DIM, int ELEM_PER_THREAD>
+__launch_bounds__(BLOCK_DIM)
+    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    // remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD
+    int tid = threadIdx.x;
+    int block_offset = (blockIdx.x - blockIdx.x % stride) * dim_size + blockIdx.x % stride;
+    int remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD;
+
+    MD md_partial;
+    md_partial.max = -INFINITY;
+    md_partial.sum = 0.0f;
+    MD input;
+    // tid ranges over [0, BLOCK_DIM - 1]; the last thread handles the remainder.
+    // The dispatcher rounds ELEM_PER_THREAD up to a power of two, so
+    // (BLOCK_DIM - 1) * ELEM_PER_THREAD can overshoot dim_size; the bound
+    // checks below keep every access in range (remain may be negative, in
+    // which case the last thread simply does nothing).
+    if (tid < BLOCK_DIM - 1) {
+#pragma unroll
+        for (int i = 0; i < ELEM_PER_THREAD && tid * ELEM_PER_THREAD + i < dim_size; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
+            input.max = static_cast<float>(x[index]);
+            input.sum = 1.0f;
+            md_partial = reduce_for_md(md_partial, input);
+        }
+    } else {
+        for (int i = 0; i < remain; i++) {
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
+            input.max = static_cast<float>(x[index]);
+            input.sum = 1.0f;
+            md_partial = reduce_for_md(md_partial, input);
+        }
+    }
+    typedef cub::BlockReduce<MD, BLOCK_DIM> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    __shared__ MD md_total;
+    MD md_block = BlockReduce(temp_storage).Reduce(md_partial, reduce_for_md);
+    if (threadIdx.x == 0) {
+        md_total = md_block;
+    }
+    __syncthreads();
+    if (tid < BLOCK_DIM - 1) {
+        for (int i = 0; i < ELEM_PER_THREAD && tid * ELEM_PER_THREAD + i < dim_size; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
+            y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
+        }
+    } else {
+        for (int i = 0; i < remain; i++) {
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
+            y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
+        }
+    }
+}
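+
+// Worked example of the partitioning above (illustrative numbers):
+// BLOCK_DIM = 1024 and dim_size = 2500 give ceil(2500 / 1024) = 3, rounded
+// up by the dispatcher to ELEM_PER_THREAD = 4. Threads 0..1022 would then
+// nominally cover 1023 * 4 = 4092 slots, overshooting the 2500 that exist,
+// and remain = 2500 - 4092 is negative; hence the per-iteration bound checks
+// and the "remain may be negative" note above.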
+
+template <typename T>
+void dispatchSoftmaxKernel(
+    const void *x, void *y,
+    int stride, int dim_size, int other_size,
+    void *stream, bool use_warp_impl) {
+
+    int elemPerThread;
+    dim3 grid, block;
+
+    if (use_warp_impl) {
+        block = dim3(32, 32);
+        grid = dim3((other_size + block.y - 1) / block.y, 1, 1);
+        elemPerThread = min((dim_size + 31) / 32, 32);
+
+#define LAUNCH_WARP_KERNEL(ELEM_PER_THREAD)                               \
+    Softmax_warp_impl<T, 32, ELEM_PER_THREAD>                             \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(     \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y),     \
+            stride, dim_size, other_size)
+
+        if (elemPerThread <= 1) {
+            LAUNCH_WARP_KERNEL(1);
+        } else if (elemPerThread <= 2) {
+            LAUNCH_WARP_KERNEL(2);
+        } else if (elemPerThread <= 4) {
+            LAUNCH_WARP_KERNEL(4);
+        } else if (elemPerThread <= 8) {
+            LAUNCH_WARP_KERNEL(8);
+        } else if (elemPerThread <= 16) {
+            LAUNCH_WARP_KERNEL(16);
+        } else {
+            LAUNCH_WARP_KERNEL(32);
+        }
+
+#undef LAUNCH_WARP_KERNEL
+
+    } else {
+        // Block implementation for dim_size > 1024
+        constexpr int BLOCK_SIZE = 1024;
+        block = dim3(BLOCK_SIZE);
+        grid = dim3(other_size);
+        elemPerThread = min((dim_size + BLOCK_SIZE - 1) / BLOCK_SIZE, 32);
+
+#define LAUNCH_BLOCK_KERNEL(ELEM_PER_THREAD)                              \
+    Softmax_block_impl<T, BLOCK_SIZE, ELEM_PER_THREAD>                    \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(     \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y),     \
+            stride, dim_size, other_size)
+
+        if (elemPerThread <= 1) {
+            LAUNCH_BLOCK_KERNEL(1);
+        } else if (elemPerThread <= 2) {
+            LAUNCH_BLOCK_KERNEL(2);
+        } else if (elemPerThread <= 4) {
+            LAUNCH_BLOCK_KERNEL(4);
+        } else if (elemPerThread <= 8) {
+            LAUNCH_BLOCK_KERNEL(8);
+        } else if (elemPerThread <= 16) {
+            LAUNCH_BLOCK_KERNEL(16);
+        } else {
+            LAUNCH_BLOCK_KERNEL(32);
+        }
+
+#undef LAUNCH_BLOCK_KERNEL
+    }
+}
+
+template <typename T>
+infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
+    int dim_size = info.dim_size;
+    int stride = info.stride;
+    int other_size = info.other_size;
+    // One warp per group for short reduction axes, one block per group otherwise.
+    if (dim_size <= 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, true);
+    } else {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, false);
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+
+#endif // __SOFTMAX_CUDA_KERNEL_H__
diff --git a/src/infiniop/ops/softmax/info.h b/src/infiniop/ops/softmax/info.h
new file mode 100644
index 000000000..67be5e14d
--- /dev/null
+++ b/src/infiniop/ops/softmax/info.h
@@ -0,0 +1,47 @@
+#ifndef __SOFTMAX_INFO_H__
+#define __SOFTMAX_INFO_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+#include "../../tensor.h"
+#include <cmath>
+
+namespace op::softmax {
+class SoftmaxInfo {
+public:
+    int axis;       // reduction axis
+    int other_size; // number of reduction groups
+    int stride;     // element stride along the axis (contiguous layout assumed)
+    int size;       // total number of elements
+    int dim_size;   // extent of the reduction axis
+
+    static utils::Result<SoftmaxInfo> create(
+        infiniopTensorDescriptor_t y_desc,
+        infiniopTensorDescriptor_t x_desc,
+        int axis) {
+
+        if (y_desc->ndim() != x_desc->ndim()) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        int ndim = static_cast<int>(x_desc->ndim());
+        if (axis < 0 || axis >= ndim) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
+
+        SoftmaxInfo info;
+        info.axis = axis;
+        info.size = 1;
+        info.other_size = 1;
+        info.dim_size = static_cast<int>(x_desc->dim(axis));
+        for (int i = ndim - 1; i >= 0; i--) {
+            info.size *= static_cast<int>(y_desc->dim(i));
+        }
+        info.stride = 1;
+        for (int i = axis + 1; i < ndim; i++) {
+            info.stride *= static_cast<int>(x_desc->dim(i));
+        }
+        info.other_size = info.size / info.dim_size;
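+        // Example (illustrative): shape [8, 32, 64, 64] with axis = 1 gives
+        // dim_size = 32, stride = 64 * 64 = 4096 and other_size = 8 * 4096
+        // = 32768 reduction groups; group g starts at linear offset
+        // g % stride + (g - g % stride) * dim_size, and consecutive
+        // elements along the axis are stride apart.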
+        return utils::Result<SoftmaxInfo>(info);
+    }
+};
+} // namespace op::softmax
+
+#endif // __SOFTMAX_INFO_H__
diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
new file mode 100644
index 000000000..dcdcf1d3f
--- /dev/null
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -0,0 +1,110 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/softmax.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/softmax_cpu.h"
+#endif
+#ifdef ENABLE_CUDA_API
+#include "cuda/softmax_cuda.cuh"
+#endif
+
+__C __export infiniStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle,
+                                                            infiniopSoftmaxDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t y_desc,
+                                                            infiniopTensorDescriptor_t x_desc,
+                                                            int axis) {
+#define CREATE(CASE, NAMESPACE)                                                 \
+    case CASE:                                                                  \
+        return op::softmax::NAMESPACE::Descriptor::create(                      \
+            handle,                                                             \
+            reinterpret_cast<op::softmax::NAMESPACE::Descriptor **>(desc_ptr),  \
+            y_desc,                                                             \
+            x_desc,                                                             \
+            axis)
+
+    switch (handle->device) {
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CREATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef CREATE
+}
+
+__C __export infiniStatus_t
+infiniopGetSoftmaxWorkspaceSize(
+    infiniopSoftmaxDescriptor_t desc,
+    size_t *size) {
+
+#define GET(CASE, NAMESPACE)                                                                    \
+    case CASE:                                                                                  \
+        *size = reinterpret_cast<op::softmax::NAMESPACE::Descriptor *>(desc)->workspaceSize();  \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        GET(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET
+}
+
+__C __export infiniStatus_t infiniopSoftmax(
+    infiniopSoftmaxDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream) {
+#define CALCULATE(CASE, NAMESPACE)                                                  \
+    case CASE:                                                                      \
+        return reinterpret_cast<const op::softmax::NAMESPACE::Descriptor *>(desc)   \
+            ->calculate(workspace, workspace_size,                                  \
+                        y,                                                          \
+                        x,                                                          \
+                        stream)
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef CALCULATE
+}
+
+__C __export infiniStatus_t
+infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
+#define DELETE(CASE, NAMESPACE)                                               \
+    case CASE:                                                                \
+        delete reinterpret_cast<op::softmax::NAMESPACE::Descriptor *>(desc);  \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        DELETE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef DELETE
+}
diff --git a/src/infiniop/ops/softmax/softmax.h b/src/infiniop/ops/softmax/softmax.h
new file mode 100644
index 000000000..dc0520f77
--- /dev/null
+++ b/src/infiniop/ops/softmax/softmax.h
@@ -0,0 +1,50 @@
+#ifndef __SOFTMAX_H__
+#define __SOFTMAX_H__
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                     \
+                                                                  \
+    namespace op::softmax::NAMESPACE {                            \
+    class Descriptor final : public InfiniopDescriptor {          \
+        struct Opaque;                                            \
+        Opaque *_opaque;                                          \
+        infiniDtype_t _dtype;                                     \
+        SoftmaxInfo _info;                                        \
+        size_t _workspace_size;                                   \
+                                                                  \
+        Descriptor(                                               \
+            infiniDtype_t dtype,                                  \
+            SoftmaxInfo info,                                     \
+            size_t workspace_size_,                               \
+            Opaque *opaque,                                       \
+            infiniDevice_t device_type,                           \
+            int device_id)                                        \
+            : InfiniopDescriptor{device_type, device_id},         \
+              _opaque(opaque),                                    \
+              _dtype(dtype),                                      \
+              _info(info),                                        \
+              _workspace_size(workspace_size_) {}                 \
+                                                                  \
+    public:                                                       \
+        ~Descriptor();                                            \
+                                                                  \
+        size_t workspaceSize() const { return _workspace_size; }  \
+                                                                  \
+        static infiniStatus_t create(                             \
+            infiniopHandle_t handle,                              \
+            Descriptor **desc_ptr,                                \
+            infiniopTensorDescriptor_t y,                         \
+            infiniopTensorDescriptor_t x,                         \
+            int axis);                                            \
+                                                                  \
+        infiniStatus_t calculate(                                 \
+            void *workspace, size_t workspace_size,               \
+            void *y,                                              \
+            const void *x,                                        \
+            void *stream) const;                                  \
+    };                                                            \
+    }
+
+#endif // __SOFTMAX_H__
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index ff583d9c0..0ba71d5e0 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -431,3 +431,35 @@ def swiglu_(lib):
     lib.infiniopDestroySwiGLUDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
     ]
+
+@OpRegister.operator
+def softmax_(lib):
+    lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
+    lib.infiniopCreateSoftmaxDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_int32,
+    ]
+
+    lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
+    lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopSoftmax.restype = c_int32
+    lib.infiniopSoftmax.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
+    lib.infiniopDestroySoftmaxDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
diff --git a/test/infiniop/softmax.py b/test/infiniop/softmax.py
new file mode 100644
index 000000000..0900cc8bd
--- /dev/null
+++ b/test/infiniop/softmax.py
@@ -0,0 +1,183 @@
+import torch
+import ctypes
+from ctypes import c_uint64, c_int32
+from libinfiniop import (
+    LIBINFINIOP,
+    TestTensor,
+    get_test_devices,
+    check_error,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    TestWorkspace,
+    InfiniDtype,
+    InfiniDtypeNames,
+    InfiniDeviceNames,
+    infiniopOperatorDescriptor_t,
+)
+
+# ==============================================================================
+# Configuration (Internal Use Only)
+# ==============================================================================
+# These are not meant to be imported from other modules
+_TEST_CASES_ = [
+    # shape, reduce_axis, stride
+    ((32, 20, 512), 0, (20 * 512, 512, 1)),
+    ((32, 20, 512), 1, (20 * 512, 512, 1)),
+    ((32, 20, 512), 2, (20 * 512, 512, 1)),
+
+    # 2D tensor tests
+    ((128, 256), 0, (256, 1)),
+    ((128, 256), 1, (256, 1)),
+    ((1024, 1024), 0, (1024, 1)),
+    ((1024, 1024), 1, (1024, 1)),
+
+    # 4D tensor tests
+    ((8, 32, 64, 64), 0, (32 * 64 * 64, 64 * 64, 64, 1)),
+    ((8, 32, 64, 64), 1, (32 * 64 * 64, 64 * 64, 64, 1)),
+    ((8, 32, 64, 64), 2, (32 * 64 * 64, 64 * 64, 64, 1)),
+    ((8, 32, 64, 64), 3, (32 * 64 * 64, 64 * 64, 64, 1)),
+
+    # 5D tensor tests
+    ((4, 16, 8, 32, 32), 0, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)),
+    ((4, 16, 8, 32, 32), 1, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)),
+    ((4, 16, 8, 32, 32), 2, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)),
+    ((4, 16, 8, 32, 32), 3, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)),
+    ((4, 16, 8, 32, 32), 4, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)),
+
+    # Small-size tests
+    ((2, 3), 0, (3, 1)),
+    ((2, 3), 1, (3, 1)),
+    ((1, 10), 0, (10, 1)),
+    ((1, 10), 1, (10, 1)),
+    ((10, 1), 0, (1, 1)),
+    ((10, 1), 1, (1, 1)),
+
+    ((1000,), 0, (1,)),
+
+    ((7, 333, 777), 0, (333 * 777, 777, 1)),
+    ((7, 333, 777), 1, (333 * 777, 777, 1)),
+    ((7, 333, 777), 2, (333 * 777, 777, 1)),
+    ((13, 509, 251), 0, (509 * 251, 251, 1)),
+    ((13, 509, 251), 1, (509 * 251, 251, 1)),
+    ((13, 509, 251), 2, (509 * 251, 251, 1)),
+
+    ((64, 1024, 768), 0, (1024 * 768, 768, 1)),
+    ((64, 1024, 768), 1, (1024 * 768, 768, 1)),
+    ((64, 1024, 768), 2, (1024 * 768, 768, 1)),
+    ((32, 2048, 512), 0, (2048 * 512, 512, 1)),
+    ((32, 2048, 512), 1, (2048 * 512, 512, 1)),
+    ((32, 2048, 512), 2, (2048 * 512, 512, 1)),
+
+    ((1024, 1), 0, (1, 1)),
+    ((1024, 1), 1, (1, 1)),
+    ((1, 1024), 0, (1024, 1)),
+    ((1, 1024), 1, (1024, 1)),
+
+    ((32, 8, 512, 64), 1, (8 * 512 * 64, 512 * 64, 64, 1)),
+    ((32, 8, 512, 64), 2, (8 * 512 * 64, 512 * 64, 64, 1)),
+    ((32, 8, 512, 64), 3, (8 * 512 * 64, 512 * 64, 64, 1)),
+]
+
+# Data types used for testing
+_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3},
+    InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-6},
+    InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-2},
+}
+
+
+def softmax(x, axis, y):
+    # torch.softmax has no out= parameter, so write the result into y in place
+    y.copy_(torch.softmax(x, dim=axis))
+
+
+def test(
+    handle,
+    device,
+    shape,
+    axis,
+    stride,
+    dtype=InfiniDtype.F16,
+    sync=None,
+):
+    x = TestTensor(shape, stride, dtype, device)
+    y = TestTensor(shape, stride, dtype, device, mode="zeros")
+
+    print(
+        f"Testing softmax on {InfiniDeviceNames[device]} with shape:{shape} stride:{stride} axis:{axis} "
+        f"dtype:{InfiniDtypeNames[dtype]}"
+    )
+    softmax(x.torch_tensor(), axis, y.torch_tensor())
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateSoftmaxDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            y.descriptor,
+            x.descriptor,
+            c_int32(axis),
+        )
+    )
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetSoftmaxWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+    )
+    workspace = TestWorkspace(workspace_size.value, x.device)
+
+    def lib_softmax():
+        check_error(
+            LIBINFINIOP.infiniopSoftmax(
+                descriptor,
+                workspace.data() if workspace is not None else None,
+                workspace_size.value,
+                y.data(),
+                x.data(),
+                None,
+            )
+        )
+
+    lib_softmax()
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    for tensor in [x, y]:
+        tensor.destroy_desc()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
+    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: softmax(x.torch_tensor(), axis, y.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("   lib", lambda: lib_softmax(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(LIBINFINIOP.infiniopDestroySoftmaxDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
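
As a standalone cross-check of the index arithmetic shared by the CPU and CUDA paths (group g starts at g % stride + (g - g % stride) * dim_size, and its elements sit stride apart), a hypothetical script, not part of this PR, can replay it in pure Python against torch.softmax:

    # Hypothetical sanity check (not part of this diff): reproduce the kernels'
    # group/offset arithmetic and compare against torch.softmax.
    import torch

    def softmax_by_groups(x: torch.Tensor, axis: int) -> torch.Tensor:
        dim_size = x.shape[axis]
        stride = 1
        for d in x.shape[axis + 1:]:
            stride *= d
        flat, out = x.flatten(), torch.empty(x.numel())
        for g in range(x.numel() // dim_size):
            base = g % stride + (g - g % stride) * dim_size
            idx = [base + j * stride for j in range(dim_size)]
            out[idx] = torch.softmax(flat[idx], dim=0)
        return out.reshape(x.shape)

    x = torch.randn(4, 5, 6)
    for axis in range(x.ndim):
        assert torch.allclose(softmax_by_groups(x, axis), torch.softmax(x, dim=axis), atol=1e-6)
    print("group/offset arithmetic matches torch.softmax")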