From ab89fa8b3aa5a1d733e01ee58897262b02c09f61 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Mon, 16 Jun 2025 15:28:53 +0800
Subject: [PATCH 1/8] Issue/259 add softmax operator

---
 include/infiniop/ops/softmax.h                |  19 ++
 src/infiniop/ops/softmax/cpu/softmax_cpu.cc   | 105 ++++++++
 src/infiniop/ops/softmax/cpu/softmax_cpu.h    |   8 +
 src/infiniop/ops/softmax/cuda/softmax_cuda.cu |  52 ++++
 .../ops/softmax/cuda/softmax_cuda.cuh         |   8 +
 .../ops/softmax/cuda/softmax_kernel.cuh       | 239 ++++++++++++++++++
 src/infiniop/ops/softmax/info.h               |  47 ++++
 src/infiniop/ops/softmax/operator.cc          | 109 ++++++++
 src/infiniop/ops/softmax/softmax.h            |  49 ++++
 test/infiniop/softmax.py                      | 227 +++++++++++++++++
 10 files changed, 863 insertions(+)
 create mode 100644 include/infiniop/ops/softmax.h
 create mode 100644 src/infiniop/ops/softmax/cpu/softmax_cpu.cc
 create mode 100644 src/infiniop/ops/softmax/cpu/softmax_cpu.h
 create mode 100644 src/infiniop/ops/softmax/cuda/softmax_cuda.cu
 create mode 100644 src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
 create mode 100644 src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
 create mode 100644 src/infiniop/ops/softmax/info.h
 create mode 100644 src/infiniop/ops/softmax/operator.cc
 create mode 100644 src/infiniop/ops/softmax/softmax.h
 create mode 100644 test/infiniop/softmax.py

diff --git a/include/infiniop/ops/softmax.h b/include/infiniop/ops/softmax.h
new file mode 100644
index 000000000..e1f3fc4d4
--- /dev/null
+++ b/include/infiniop/ops/softmax.h
@@ -0,0 +1,19 @@
+#ifndef __INFINIOP_MLP_API_H__
+#define __INFINIOP_MLP_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopSoftmaxDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle,
+                                                            infiniopSoftmaxDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t y_desc,
+                                                            infiniopTensorDescriptor_t x_desc,
+                                                            int axis);
+
+__C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t desc, size_t *size);
+
+__C infiniStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *y, const void *x, void *stream);
+
+__C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc);
+#endif
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
new file mode 100644
index 000000000..47113cfe8
--- /dev/null
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
@@ -0,0 +1,105 @@
+#include "softmax_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+
+namespace op::softmax::cpu {
+
+Descriptor::~Descriptor() = default;
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t x,
+    int axis) {
+
+    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
+    auto dtype = y->dtype();
+
+    const auto &x_shape = x->shape();
+    const auto &y_shape = y->shape();
+
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
+
+    CHECK_SAME_SHAPE(y_shape, x_shape);
+
+    auto result = SoftmaxInfo::create(y, x, axis);
+    CHECK_RESULT(result);
+
+    *desc_ptr = new Descriptor(
+        dtype,
+        result.take(),
+        0,
+        nullptr,
+        handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+template <typename T>
+void softmax_cpu(const SoftmaxInfo &info,
+                 const void *x, void *y, int axis) {
+    int dimsize = info.dimsize;
+    int stride = info.stride;
+    int othersize = info.otherdim_size;
+    if constexpr (std::is_same_v<T, fp16_t>) {
+        auto input = reinterpret_cast<const fp16_t *>(x);
+        auto output = reinterpret_cast<fp16_t *>(y);
+        for (int i = 0; i < othersize; i++) {
+            int tid = i % stride + (i - i % stride) * dimsize;
+            float max_data = -INFINITY;
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                max_data = fmax(max_data, utils::cast<float>(input[index]));
+            }
+            float sum_data = 0.0f;
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
+            }
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                output[index] = utils::cast<fp16_t>(std::exp(utils::cast<float>(input[index]) - max_data) / sum_data);
+            }
+        }
+    } else if constexpr (std::is_same_v<T, float>) {
+        auto input = reinterpret_cast<const float *>(x);
+        auto output = reinterpret_cast<float *>(y);
+#pragma omp parallel for
+        for (int i = 0; i < othersize; i++) {
+            int tid = i % stride + (i - i % stride) * dimsize;
+            float max_data = -INFINITY;
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                max_data = fmax(max_data, input[index]);
+            }
+            float sum_data = 0.0f;
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                sum_data += std::exp(input[index] - max_data);
+            }
+            for (int j = 0; j < dimsize; j++) {
+                int index = tid + j * stride;
+                output[index] = std::exp(input[index] - max_data) / sum_data;
+            }
+        }
+    }
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream_) const {
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        softmax_cpu<fp16_t>(_info, x, y, _info.axis);
+        return INFINI_STATUS_SUCCESS;
+    case INFINI_DTYPE_F32:
+        softmax_cpu<float>(_info, x, y, _info.axis);
+        return INFINI_STATUS_SUCCESS;
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+} // namespace op::softmax::cpu
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.h b/src/infiniop/ops/softmax/cpu/softmax_cpu.h
new file mode 100644
index 000000000..49a9d9bb9
--- /dev/null
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.h
@@ -0,0 +1,8 @@
+#ifndef __SOFTMAX_CPU_H__
+#define __SOFTMAX_CPU_H__
+
+#include "../softmax.h"
+
+DESCRIPTOR(cpu)
+
+#endif // __SOFTMAX_CPU_H__
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
new file mode 100644
index 000000000..b5283a427
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
@@ -0,0 +1,52 @@
+#include "../../../devices/cuda/cuda_common.cuh"
+#include "softmax_cuda.cuh"
+#include "softmax_kernel.cuh"
+
+namespace op::softmax::cuda {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::cuda::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() {
+    delete _opaque;
+}
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t x,
+    int axis) {
+    auto dtype = y->dtype();
+    auto handle = reinterpret_cast<device::cuda::Handle *>(handle_);
+    auto result = SoftmaxInfo::create(y, x, axis);
+    CHECK_RESULT(result);
+    CHECK_SAME_SHAPE(y->shape(), x->shape());
+    CHECK_DTYPE(y->dtype(), x->dtype(), INFINI_DTYPE_F16, INFINI_DTYPE_F32);
+    *desc_ptr = new Descriptor(
+        dtype,
+        result.take(),
+        0,
+        new Opaque{handle->internal()},
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream_) const {
+
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        return softmax_dispatch<half>(_info, y, x, stream_);
+    case INFINI_DTYPE_F32:
+        return softmax_dispatch<float>(_info, y, x, stream_);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+} // namespace op::softmax::cuda
\ No newline at end of file
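A note on the group indexing used by softmax_cpu above: tid = i % stride + (i - i % stride) * dimsize locates the first element of reduction group i, and the j-th element of that group then sits at tid + j * stride. Here is a standalone C++ sketch of the same arithmetic (illustration only, not part of the patch; the shape and axis are arbitrary examples):

// Demonstrates the flat-offset walk softmax_cpu performs per group.
// For a contiguous tensor of shape [pre, dim, post] reduced over the middle
// axis, stride = post and there are pre * post independent groups
// (the "otherdim_size" of the patch).
#include <cstdio>

int main() {
    const int pre = 2, dim = 3, post = 4; // arbitrary example shape
    const int stride = post;
    for (int i = 0; i < pre * post; i++) {
        int tid = i % stride + (i - i % stride) * dim;
        std::printf("group %2d:", i);
        for (int j = 0; j < dim; j++) {
            std::printf(" %2d", tid + j * stride); // dim distinct flat offsets
        }
        std::printf("\n");
    }
    return 0;
}

Every flat offset in [0, pre * dim * post) is printed exactly once, which is why the outer loop over i covers the whole tensor.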
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
new file mode 100644
index 000000000..0a9c63187
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
@@ -0,0 +1,8 @@
+#ifndef __SOFTMAX_CUDA_CUH__
+#define __SOFTMAX_CUDA_CUH__
+
+#include "../softmax.h"
+
+DESCRIPTOR(cuda)
+
+#endif // __SOFTMAX_CUDA_CUH__
\ No newline at end of file
diff --git a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
new file mode 100644
index 000000000..f8b81780c
--- /dev/null
+++ b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
@@ -0,0 +1,239 @@
+#ifndef __SOFTMAX_CUDA_KERNEL_H__
+#define __SOFTMAX_CUDA_KERNEL_H__
+#include "../../../devices/cuda/cuda_kernel_common.cuh"
+#include "softmax_cuda.cuh"
+#include <cub/block/block_reduce.cuh>
+#include <cuda_fp16.h>
+
+struct __align__(8) MD {
+    float max;
+    float sum;
+};
+
+__device__ __forceinline__ MD reduce_for_md(MD a, MD b) {
+    bool is_a_bigger = a.max > b.max;
+    MD bigger = is_a_bigger ? a : b;
+    MD smaller = is_a_bigger ? b : a;
+    bigger.sum = bigger.sum + __expf(smaller.max - bigger.max) * smaller.sum;
+    return bigger;
+}
+
+template <typename T>
+struct SumOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return a + b;
+    }
+};
+
+template <>
+struct SumOp<half> {
+    __device__ __forceinline__ half operator()(const half &a, const half &b) const {
+        return __hadd(a, b);
+    }
+};
+
+template <typename T>
+struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
+    }
+};
+
+template <>
+struct MaxOp<half> {
+    __device__ __forceinline__ half operator()(const half &a, const half &b) const {
+        return __hmax(a, b);
+    }
+};
+
+template <typename T, template <typename> class ReduceOp, int thread_group_width = 32>
+__device__ __forceinline__ T warpReduce(T value) {
+    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+        value = ReduceOp<T>()(value, __shfl_xor_sync(0xffffffff, value, mask));
+    }
+    return value;
+}
+
+// Higher-dimensional softmax: e.g. with axis = 1 and an input of shape [a1, a2, a3, a4],
+// we max-reduce and sum-reduce over a1 * a3 * a4 groups, each a tensor of shape [a2].
+// So when the operator descriptor is created we need the stride between elements
+// along the reduction axis and the total number of length-a2 groups.
+// When the reduction axis is short, one warp can handle a whole group; with 1024 as
+// the cutoff, fewer than 1024 elements along the axis means one warp per group.
+// 1024 / 32 = 32
+// blockDim.x = 32 -> warp_size
+// blockDim.y = 32, i.e. 32 warps per block
+// dim3 block(32, 32);
+// threadIdx.y is the warp_id within the block
+// BLOCK_DIM_Y is the number of warps per block
+/*
+Element 0 is at position [i, 0, j]
+Element 1 is at position [i, 1, j]
+Element 19 is at position [i, 19, j]
+(tid + idx * BLOCK_DIM_X) * stride gives the linear offset of the index along the axis
+that is, we still need i and j
+i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
+j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
+then i linearized contributes i * stride * dimsize
+and j is added directly

+*/
+template <int elemPerThread, int BLOCK_DIM_X, int BLOCK_DIM_Y, typename T>
+__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
+    float dataPerThread[elemPerThread];
+    int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
+    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dimsize;
+    int tid = threadIdx.x;
+    if (global_warp_id >= otherdim_size) {
+        return;
+    }
+    __shared__ float group_max[BLOCK_DIM_X];
+    __shared__ float group_sum[BLOCK_DIM_X];
+    float thread_max = -INFINITY;
+    float thread_sum = 0.0f;
+    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+        dataPerThread[i] = static_cast<float>(x[(tid + i * BLOCK_DIM_X) * stride + group_offset]);
+        thread_max = max(thread_max, dataPerThread[i]);
+    }
+
+    thread_max = warpReduce<float, MaxOp>(thread_max);
+    if (tid == 0) {
+        group_max[threadIdx.y] = thread_max;
+    }
+
+    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+        dataPerThread[i] = __expf(dataPerThread[i] - group_max[threadIdx.y]);
+        thread_sum += dataPerThread[i];
+    }
+
+    thread_sum = warpReduce<float, SumOp>(thread_sum);
+    if (tid == 0) {
+        group_sum[threadIdx.y] = thread_sum;
+    }
+
+    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+        y[(tid + i * BLOCK_DIM_X) * stride + group_offset] = static_cast<T>(dataPerThread[i] * __fdividef(1.0f, group_sum[threadIdx.y]));
+    }
+}
+
+template <int elemPerThread, int BLOCK_DIM, typename T>
+__launch_bounds__(BLOCK_DIM)
+    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
+    // remain = dimsize - BLOCK_DIM * elemPerThread
+    int tid = threadIdx.x;
+    int block_offset = (blockIdx.x - blockIdx.x % stride) * dimsize + blockIdx.x % stride;
+    int remain = dimsize - (BLOCK_DIM - 1) * elemPerThread; // number of elements left for the last thread
+
+    MD md_partial;
+    md_partial.max = -INFINITY;
+    md_partial.sum = 0.0f;
+    MD input;
+    // tid ranges over [0, BLOCK_DIM - 1], so the last thread handles the remainder
+    if (tid < BLOCK_DIM - 1) {
+#pragma unroll
+        for (int i = 0; i < elemPerThread; i++) {
+            int index = (tid * elemPerThread + i) * stride + block_offset;
+            input.max = static_cast<float>(x[index]);
+            input.sum = 1.0f;
+            md_partial = reduce_for_md(md_partial, input);
+        }
+    } else {
+#pragma unroll
+        for (int i = 0; i < remain; i++) {
+            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            input.max = static_cast<float>(x[index]);
+            input.sum = 1.0f;
+            md_partial = reduce_for_md(md_partial, input);
+        }
+    }
+    typedef cub::BlockReduce<MD, BLOCK_DIM> BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+    __shared__ MD md_total;
+    MD md_block = BlockReduce(temp_storage).Reduce(md_partial, reduce_for_md);
+    if (threadIdx.x == 0) {
+        md_total = md_block;
+    }
+    __syncthreads();
+    if (tid < BLOCK_DIM - 1) {
+        for (int i = 0; i < elemPerThread; i++) {
+            int index = (tid * elemPerThread + i) * stride + block_offset;
+            y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
+        }
+    } else {
+        for (int i = 0; i < remain; i++) {
+            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
+        }
+    }
+}
+
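Softmax_block_impl above relies on the online-softmax merge in reduce_for_md: two partial (max, sum-of-exp) states can be combined by rescaling the state with the smaller max. A host-side C++ check of that identity (illustration only, not part of the patch; the sample values are arbitrary):

#include <algorithm>
#include <cmath>
#include <cstdio>

struct State { float max, sum; };

// Same merge rule as reduce_for_md, in plain C++.
State merge(State a, State b) {
    State big = a.max > b.max ? a : b;
    State small = a.max > b.max ? b : a;
    big.sum += std::exp(small.max - big.max) * small.sum;
    return big;
}

int main() {
    float v[] = {0.3f, -1.2f, 2.5f, 0.0f, 1.7f, -0.4f};
    // One pass: fold each element in as (max = v, sum = 1).
    State st{-INFINITY, 0.0f};
    for (float x : v) st = merge(st, State{x, 1.0f});
    // Two passes: max first, then sum of shifted exponentials.
    float m = -INFINITY, s = 0.0f;
    for (float x : v) m = std::max(m, x);
    for (float x : v) s += std::exp(x - m);
    std::printf("one pass: max=%f sum=%f\n", st.max, st.sum);
    std::printf("two pass: max=%f sum=%f\n", m, s);
    return 0;
}

Both lines print the same max and, up to rounding, the same sum, which is what lets each thread fold its elements in a single pass before the cub block reduction combines the per-thread states.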
+template <typename T>
+infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
+    int dimsize = info.dimsize;
+    int stride = info.stride;
+    int otherdim_size = info.otherdim_size;
+    if (dimsize <= 1024) {
+        dim3 block(32, 32); // BLOCK_DIM_X=32, BLOCK_DIM_Y=32
+        int num_blocks = (otherdim_size + block.y - 1) / block.y;
+        dim3 grid(num_blocks, 1, 1);
+        int elemPerThread = (dimsize + 31) / 32; // number of elements each thread needs to handle
+        elemPerThread = min(elemPerThread, 32);  // cap the maximum
+        if (elemPerThread <= 1) {
+            Softmax_warp_impl<1, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 2) {
+            Softmax_warp_impl<2, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 4) {
+            Softmax_warp_impl<4, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 8) {
+            Softmax_warp_impl<8, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 16) {
+            Softmax_warp_impl<16, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else {
+            Softmax_warp_impl<32, 32, 32, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        }
+    } else if (dimsize > 1024) {
+        int block_size = 1024;
+        int elemPerThread = (dimsize + block_size - 1) / block_size; // number of elements each thread needs to handle
+        elemPerThread = min(elemPerThread, 32);                      // cap the maximum at 32
+        dim3 block(block_size);
+        dim3 grid(otherdim_size);
+        if (elemPerThread <= 1) {
+            Softmax_block_impl<1, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 2) {
+            Softmax_block_impl<2, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 4) {
+            Softmax_block_impl<4, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 8) {
+            Softmax_block_impl<8, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else if (elemPerThread <= 16) {
+            Softmax_block_impl<16, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        } else {
+            Softmax_block_impl<32, 1024, T>
+                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
+                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+        }
+    }
+    return INFINI_STATUS_SUCCESS;
+}
+
+#endif // __SOFTMAX_CUDA_KERNEL_H__
\ No newline at end of file
diff --git a/src/infiniop/ops/softmax/info.h b/src/infiniop/ops/softmax/info.h
new file mode 100644
index 000000000..a40dc066a
--- /dev/null
+++ b/src/infiniop/ops/softmax/info.h
@@ -0,0 +1,47 @@
+#ifndef __CONV_INFO_H__
+#define __CONV_INFO_H__
+
+#include "../../../utils.h"
+#include "../../operator.h"
+#include "../../tensor.h"
+#include <vector>
+
+namespace op::softmax {
+class SoftmaxInfo {
+public:
+    int axis;
+    int otherdim_size;
+    int stride;
+    int size;
+    int dimsize;
+
+    static utils::Result<SoftmaxInfo> create(
+        infiniopTensorDescriptor_t y_desc,
+        infiniopTensorDescriptor_t x_desc,
+        int axis) {
+
+        if (y_desc->ndim() != x_desc->ndim()) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        SoftmaxInfo info;
+        info.axis = axis;
+        info.size = 1;
+        info.otherdim_size = 1;
+        info.stride = 1;
+        info.dimsize = static_cast<int>(x_desc->dim(axis));
+        int ndim = y_desc->ndim();
+        for (int i = ndim - 1; i >= 0; i--) {
+            info.size *= static_cast<int>(y_desc->dim(i));
+        }
+        info.stride = 1;
+        for (int i = axis + 1; i < ndim; i++) {
+            info.stride *= static_cast<int>(x_desc->dim(i));
+        }
+        info.otherdim_size = info.size / info.dimsize;
+        return utils::Result<SoftmaxInfo>(info);
+    }
+};
+} // namespace op::softmax
+
+#endif // __CONV_INFO_H__
diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
new file mode 100644
index 000000000..d3df9a609
--- /dev/null
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -0,0 +1,109 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/softmax.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/softmax_cpu.h"
+#endif
+#ifdef ENABLE_CUDA_API
+#include "cuda/softmax_cuda.cuh"
+#endif
+
+__C __export infiniStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t handle,
+                                                            infiniopSoftmaxDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t y_desc,
+                                                            infiniopTensorDescriptor_t x_desc,
+                                                            int axis) {
+#define CREATE(CASE, NAMESPACE)                                                \
+    case CASE:                                                                 \
+        return op::softmax::NAMESPACE::Descriptor::create(                     \
+            handle,                                                            \
+            reinterpret_cast<op::softmax::NAMESPACE::Descriptor **>(desc_ptr), \
+            y_desc,                                                            \
+            x_desc, \
+            axis)
+    switch (handle->device) {
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CREATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef CREATE
+}
+
+__C infiniStatus_t
+infiniopGetSoftmaxWorkspaceSize(
+    infiniopSoftmaxDescriptor_t desc,
+    size_t *size) {
+
+#define GET(CASE, NAMESPACE)                                                                    \
+    case CASE:                                                                                  \
+        *size = reinterpret_cast<op::softmax::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        GET(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef GET
+}
+
+__C infiniStatus_t infiniopSoftmax(
+    infiniopSoftmaxDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream) {
+#define CALCULATE(CASE, NAMESPACE)                                                 \
+    case CASE:                                                                     \
+        return reinterpret_cast<const op::softmax::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size,                                 \
+                        y,                                                         \
+                        x,                                                         \
+                        stream)
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        CALCULATE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef CALCULATE
+}
+
+__C infiniStatus_t
+infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
+#define DELETE(CASE, NAMESPACE)                                                     \
+    case CASE:                                                                      \
+        delete reinterpret_cast<const op::softmax::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+#ifdef ENABLE_CUDA_API
+        DELETE(INFINI_DEVICE_NVIDIA, cuda);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef DELETE
+}
diff --git a/src/infiniop/ops/softmax/softmax.h b/src/infiniop/ops/softmax/softmax.h
new file mode 100644
index 000000000..664e6b6e4
--- /dev/null
+++ b/src/infiniop/ops/softmax/softmax.h
@@ -0,0 +1,49 @@
+#ifndef __SOFTMAX_H__
+#define __SOFTMAX_H__
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                     \
+                                                                  \
+    namespace op::softmax::NAMESPACE {                            \
+    class Descriptor final : public InfiniopDescriptor {          \
+        struct Opaque;                                            \
+        Opaque *_opaque;                                          \
+        infiniDtype_t _dtype;                                     \
+        SoftmaxInfo _info;                                        \
+        size_t _workspace_size;                                   \
+                                                                  \
+        Descriptor(                                               \
+            infiniDtype_t dtype,                                  \
+            SoftmaxInfo info,                                     \
+            size_t workspace_size_,                               \
+            Opaque *opaque,                                       \
+            infiniDevice_t device_type,                           \
+            int device_id)                                        \
+            : InfiniopDescriptor{device_type, device_id},         \
+              _opaque(opaque),                                    \
+              _dtype(dtype),                                      \
+              _info(info),                                        \
+              _workspace_size(workspace_size_) {}                 \
+                                                                  \
+    public:                                                       \
+        ~Descriptor();                                            \
+                                                                  \
+        size_t workspaceSize() const { return _workspace_size; }  \
+                                                                  \
+        static infiniStatus_t create(                             \
+            infiniopHandle_t handle,                              \
+            Descriptor **desc_ptr,                                \
+            infiniopTensorDescriptor_t y,                         \
+            infiniopTensorDescriptor_t x,                         \
+            int axis);                                            \
+                                                                  \
+        infiniStatus_t calculate(                                 \
+            void *workspace, size_t workspace_size,               \
+            void *y,                                              \
+            const void *x,                                        \
+            void *stream) const;                                  \
+    };                                                            \
+    }
+#endif // __CONV_H__
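SoftmaxInfo above reduces an n-d softmax to three integers: dimsize (the extent of the softmax axis), stride (the product of the dimensions after the axis, i.e. the step between consecutive elements along it in a contiguous layout), and otherdim_size (the number of independent groups). A small C++ sketch of the same computation (illustration only, not part of the patch; shape and axis are arbitrary examples):

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> shape = {8, 32, 64, 64}; // arbitrary example
    int axis = 2;
    int size = 1, stride = 1;
    for (int d : shape) size *= d;
    for (int i = axis + 1; i < static_cast<int>(shape.size()); i++) stride *= shape[i];
    int dimsize = shape[axis];
    int otherdim_size = size / dimsize;
    // Prints dimsize=64 stride=64 otherdim_size=16384
    std::printf("dimsize=%d stride=%d otherdim_size=%d\n", dimsize, stride, otherdim_size);
    return 0;
}

Note this assumes a contiguous row-major tensor: the descriptor's own strides are never consulted, which is also true of the SoftmaxInfo::create implementation above.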
diff --git a/test/infiniop/softmax.py b/test/infiniop/softmax.py
new file mode 100644
index 000000000..bf286de68
--- /dev/null
+++ b/test/infiniop/softmax.py
@@ -0,0 +1,227 @@
+import torch
+import ctypes
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64
+from libinfiniop import (
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    open_lib,
+    to_tensor,
+    get_test_devices,
+    check_error,
+    rearrange_if_needed,
+    test_operator,
+    get_args,
+    debug,
+    get_tolerance,
+    profile_operation,
+    create_workspace,
+)
+from enum import Enum, auto
+
+# ==============================================================================
+#  Configuration (Internal Use Only)
+# ==============================================================================
+# These are not meant to be imported from other modules
+_TEST_CASES_ = [
+    ((32, 20, 512), 0),
+    ((32, 20, 512), 1),
+    ((32, 20, 512), 2),
+
+    # 2D tensor tests
+    ((128, 256), 0),
+    ((128, 256), 1),
+    ((1024, 1024), 0),
+    ((1024, 1024), 1),
+
+    # 4D tensor tests
+    ((8, 32, 64, 64), 0),
+    ((8, 32, 64, 64), 1),
+    ((8, 32, 64, 64), 2),
+    ((8, 32, 64, 64), 3),
+
+    # 5D tensor tests
+    ((4, 16, 8, 32, 32), 0),
+    ((4, 16, 8, 32, 32), 1),
+    ((4, 16, 8, 32, 32), 2),
+    ((4, 16, 8, 32, 32), 3),
+    ((4, 16, 8, 32, 32), 4),
+
+    # small-size tests
+    ((2, 3), 0),
+    ((2, 3), 1),
+    ((1, 10), 0),
+    ((1, 10), 1),
+    ((10, 1), 0),
+    ((10, 1), 1),
+
+    ((1000,), 0),
+
+    ((7, 333, 777), 0),
+    ((7, 333, 777), 1),
+    ((7, 333, 777), 2),
+    ((13, 509, 251), 0),
+    ((13, 509, 251), 1),
+    ((13, 509, 251), 2),
+
+    ((64, 1024, 768), 0),
+    ((64, 1024, 768), 1),
+    ((64, 1024, 768), 2),
+    ((32, 2048, 512), 0),
+    ((32, 2048, 512), 1),
+    ((32, 2048, 512), 2),
+
+    ((1024, 1), 0),
+    ((1024, 1), 1),
+    ((1, 1024), 0),
+    ((1, 1024), 1),
+
+    ((32, 8, 512, 64), 1),
+    ((32, 8, 512, 64), 2),
+    ((32, 8, 512, 64), 3),
+]
+
+
+# Data types used for testing
+_TENSOR_DTYPES = [torch.float16, torch.float32]
+
+# Tolerance map for different data types
+_TOLERANCE_MAP = {
+    torch.float16: {"atol": 1e-3, "rtol": 1e-3},
+    torch.float32: {"atol": 1e-7, "rtol": 1e-7},
+}
+
+
+DEBUG = False
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+class SoftmaxDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopSoftmaxDescriptor_t = POINTER(SoftmaxDescriptor)
+
+
+def softmax(x, axis):
+    return torch.softmax(x, axis = axis).to(x.dtype)
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    shape,
+    axis,
+    dtype=torch.float16,
+    sync=None,
+):
+    print(
+        f"Testing softmax on {torch_device} with shape:{shape} "
+        f"dtype:{dtype}"
+    )
+
+    a = torch.randn(shape, dtype=dtype).to(torch_device) * 0.1
+    b = torch.empty_like(a)
+    ans = softmax(a, axis)
+
+    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopSoftmaxDescriptor_t()
+    check_error(
+        lib.infiniopCreateSoftmaxDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            b_tensor.descriptor,
+            a_tensor.descriptor,
+            c_int32(axis),
+        )
+    )
+
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    for tensor in [a_tensor, b_tensor]:
+        tensor.destroyDesc(lib)
+
+    workspace_size = c_uint64(0)
+    check_error(
+        lib.infiniopGetSoftmaxWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+    )
+    workspace = create_workspace(workspace_size.value, a.device)
+
+    def lib_softmax():
+        check_error(
+            lib.infiniopSoftmax(
+                descriptor,
+                workspace.data_ptr() if workspace is not None else None,
+                workspace_size.value,
+                b_tensor.data,
+                a_tensor.data,
+                None,
+            )
+        )
+
+    lib_softmax()
+
+    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
+    if DEBUG:
+        debug(b, ans, atol=atol, rtol=rtol)
+    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: softmax(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("    lib", lambda: lib_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(lib.infiniopDestroySoftmaxDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+    lib = open_lib()
+
+    lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
+    lib.infiniopCreateSoftmaxDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopSoftmaxDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_int32,
+    ]
+
+    lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
+    lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
+        infiniopSoftmaxDescriptor_t,
+        POINTER(c_uint64),
+    ]
+
+    lib.infiniopSoftmax.restype = c_int32
+    lib.infiniopSoftmax.argtypes = [
+        infiniopSoftmaxDescriptor_t,
+        c_void_p,
+        c_uint64,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
+    lib.infiniopDestroySoftmaxDescriptor.argtypes = [
+        infiniopSoftmaxDescriptor_t,
+    ]
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(lib, device, test, _TEST_CASES_, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
+

From 5fc6f559b06886e9a5d8d3d64828ad6a1a5b29bc Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Mon, 16 Jun 2025 15:51:16 +0800
Subject: [PATCH 2/8] Issue/259 fix problems in the CREATE macro definition

---
 src/infiniop/ops/softmax/operator.cc | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
index d3df9a609..79ce5b17b 100644
--- a/src/infiniop/ops/softmax/operator.cc
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -20,8 +20,9 @@ __C __export infiniStatus_t infiniopCreateSoftmaxDescriptor(infiniopHandle_t han
             handle,                                                            \
             reinterpret_cast<op::softmax::NAMESPACE::Descriptor **>(desc_ptr), \
             y_desc,                                                            \
-            x_desc, \
+            x_desc,                                                            \
             axis)
+
     switch (handle->device) {
 #ifdef ENABLE_CPU_API
         CREATE(INFINI_DEVICE_CPU, cpu);
@@ -93,7 +94,7 @@ infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
 #define DELETE(CASE, NAMESPACE)                                                     \
     case CASE:                                                                      \
         delete reinterpret_cast<const op::softmax::NAMESPACE::Descriptor *>(desc); \
-        return INFINI_STATUS_SUCCESS;
+        return INFINI_STATUS_SUCCESS
 
     switch (desc->device_type) {
 #ifdef ENABLE_CPU_API
@@ -106,4 +107,4 @@ infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
 #undef DELETE
-}
+}
\ No newline at end of file

From 7583cea15f4e52305e2c3289928a15e4717e8609 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Mon, 16 Jun 2025 16:01:11 +0800
Subject: [PATCH 3/8] Issue/259 fix an existing implicit type conversion

---
 src/infiniop/ops/softmax/info.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/infiniop/ops/softmax/info.h b/src/infiniop/ops/softmax/info.h
index a40dc066a..da7929262 100644
--- a/src/infiniop/ops/softmax/info.h
+++ b/src/infiniop/ops/softmax/info.h
@@ -30,7 +30,7 @@ class SoftmaxInfo {
         info.otherdim_size = 1;
         info.stride = 1;
         info.dimsize = static_cast<int>(x_desc->dim(axis));
-        int ndim = y_desc->ndim();
+        int ndim = static_cast<int>(y_desc->ndim());
         for (int i = ndim - 1; i >= 0; i--) {
             info.size *= static_cast<int>(y_desc->dim(i));
         }
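The patch that follows reworks the CPU path while keeping the subtract-the-max trick from patch 1. Why that subtraction matters in float32 can be seen in a few lines of standalone C++ (illustration only, not part of the patch; the values are arbitrary):

#include <cmath>
#include <cstdio>

int main() {
    float v[] = {100.0f, 101.0f, 102.0f};
    float naive = 0.0f;
    for (float x : v) naive += std::exp(x); // e^100 overflows float32 -> inf
    float m = 102.0f, stable = 0.0f;
    for (float x : v) stable += std::exp(x - m); // all arguments <= 0, no overflow
    std::printf("naive sum = %f\n", naive);                             // inf
    std::printf("stable softmax(102) = %f\n", std::exp(0.0f) / stable); // ~0.665
    return 0;
}

Since softmax is shift-invariant, softmax(x) equals softmax(x - max(x)), so the stable form changes nothing mathematically.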
From e843a513b774687ce568caf8f61065f153760e94 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Tue, 17 Jun 2025 01:07:38 +0800
Subject: [PATCH 4/8] Issue/259 abstract the softmax_cpu computation to reduce
 redundancy

---
 include/infiniop/ops/softmax.h              |  4 +-
 src/infiniop/ops/softmax/cpu/softmax_cpu.cc | 77 +++++++++----------
 .../ops/softmax/cuda/softmax_kernel.cuh     |  4 +-
 src/infiniop/ops/softmax/info.h             |  6 +-
 src/infiniop/ops/softmax/operator.cc        |  2 +-
 src/infiniop/ops/softmax/softmax.h          |  2 +-
 6 files changed, 47 insertions(+), 48 deletions(-)

diff --git a/include/infiniop/ops/softmax.h b/include/infiniop/ops/softmax.h
index e1f3fc4d4..e826cab19 100644
--- a/include/infiniop/ops/softmax.h
+++ b/include/infiniop/ops/softmax.h
@@ -1,5 +1,5 @@
-#ifndef __INFINIOP_MLP_API_H__
-#define __INFINIOP_MLP_API_H__
+#ifndef __INFINIOP_SOFTMAX_API_H__
+#define __INFINIOP_SOFTMAX_API_H__
 
 #include "../operator_descriptor.h"
 
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
index 47113cfe8..11ed94144 100644
--- a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
@@ -41,47 +41,46 @@ void softmax_cpu(const SoftmaxInfo &info,
     int dimsize = info.dimsize;
     int stride = info.stride;
     int othersize = info.otherdim_size;
-    if constexpr (std::is_same_v<T, fp16_t>) {
-        auto input = reinterpret_cast<const fp16_t *>(x);
-        auto output = reinterpret_cast<fp16_t *>(y);
-        for (int i = 0; i < othersize; i++) {
-            int tid = i % stride + (i - i % stride) * dimsize;
-            float max_data = -INFINITY;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                max_data = fmax(max_data, utils::cast<float>(input[index]));
-            }
-            float sum_data = 0.0f;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
-            }
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                output[index] = utils::cast<fp16_t>(std::exp(utils::cast<float>(input[index]) - max_data) / sum_data);
-            }
+    auto to_float = [](const T &val) -> float {
+        if constexpr (std::is_same_v<T, fp16_t>) {
+            return utils::cast<float>(val);
+        } else {
+            return val;
         }
-    } else if constexpr (std::is_same_v<T, float>) {
-        auto input = reinterpret_cast<const float *>(x);
-        auto output = reinterpret_cast<float *>(y);
-#pragma omp parallel for
-        for (int i = 0; i < othersize; i++) {
-            int tid = i % stride + (i - i % stride) * dimsize;
-            float max_data = -INFINITY;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                max_data = fmax(max_data, input[index]);
-            }
-            float sum_data = 0.0f;
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                sum_data += std::exp(input[index] - max_data);
-            }
-            for (int j = 0; j < dimsize; j++) {
-                int index = tid + j * stride;
-                output[index] = std::exp(input[index] - max_data) / sum_data;
-            }
+    };
+
+    auto from_float = [](float val) -> T {
+        if constexpr (std::is_same_v<T, fp16_t>) {
+            return utils::cast<fp16_t>(val);
+        } else {
+            return val;
+        }
+    };
+
+    auto input = reinterpret_cast<const T *>(x);
+    auto output = reinterpret_cast<T *>(y);
+
+    auto compute_softmax = [&](int i) {
+        int tid = i % stride + (i - i % stride) * dimsize;
+        float max_data = -INFINITY;
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            max_data = fmax(max_data, to_float(input[index]));
+        }
+        float sum_data = 0.0f;
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            sum_data += std::exp(to_float(input[index]) - max_data);
         }
+        for (int j = 0; j < dimsize; j++) {
+            int index = tid + j * stride;
+            float result = std::exp(to_float(input[index]) - max_data) / sum_data;
+            output[index] = from_float(result);
+        }
+    };
+#pragma omp parallel for
+    for (int i = 0; i < othersize; i++) {
+        compute_softmax(i);
     }
 }
 
diff --git a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
index f8b81780c..8e2ff9d88 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
+++ b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
@@ -1,5 +1,6 @@
 #ifndef __SOFTMAX_CUDA_KERNEL_H__
 #define __SOFTMAX_CUDA_KERNEL_H__
+
 #include "../../../devices/cuda/cuda_kernel_common.cuh"
 #include "softmax_cuda.cuh"
 #include <cub/block/block_reduce.cuh>
@@ -74,7 +75,6 @@ i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
 j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
 then i linearized contributes i * stride * dimsize
 and j is added directly
-
 */
 template <int elemPerThread, int BLOCK_DIM_X, int BLOCK_DIM_Y, typename T>
 __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
@@ -236,4 +236,4 @@ infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, c
     return INFINI_STATUS_SUCCESS;
 }
 
-#endif // __SOFTMAX_CUDA_KERNEL_H__
\ No newline at end of file
+#endif // __SOFTMAX_CUDA_KERNEL_H__
diff --git a/src/infiniop/ops/softmax/info.h b/src/infiniop/ops/softmax/info.h
index da7929262..c59253abf 100644
--- a/src/infiniop/ops/softmax/info.h
+++ b/src/infiniop/ops/softmax/info.h
@@ -1,5 +1,5 @@
-#ifndef __CONV_INFO_H__
-#define __CONV_INFO_H__
+#ifndef __SOFTMAX_INFO_H__
+#define __SOFTMAX_INFO_H__
 
 #include "../../../utils.h"
 #include "../../operator.h"
@@ -44,4 +44,4 @@ class SoftmaxInfo {
 };
 } // namespace op::softmax
 
-#endif // __CONV_INFO_H__
+#endif // __SOFTMAX_INFO_H__
diff --git a/src/infiniop/ops/softmax/operator.cc b/src/infiniop/ops/softmax/operator.cc
index 79ce5b17b..dcdcf1d3f 100644
--- a/src/infiniop/ops/softmax/operator.cc
+++ b/src/infiniop/ops/softmax/operator.cc
@@ -107,4 +107,4 @@ infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc) {
         return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
     }
 #undef DELETE
-}
\ No newline at end of file
+}
diff --git a/src/infiniop/ops/softmax/softmax.h b/src/infiniop/ops/softmax/softmax.h
index 664e6b6e4..a40e0cad2 100644
--- a/src/infiniop/ops/softmax/softmax.h
+++ b/src/infiniop/ops/softmax/softmax.h
@@ -46,4 +46,4 @@
             void *stream) const;                              \
     };                                                        \
     }
-#endif // __CONV_H__
+#endif // __SOFTMAX_H__
From 0ecfd1e6cba62b5b0655efca9eee067a74fe07e1 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Mon, 23 Jun 2025 22:22:47 +0800
Subject: [PATCH 5/8] Issue/259 abstract the softmax_cuda operator dispatch and
 normalize formatting

---
 include/infiniop/ops/softmax.h                |   1 +
 src/infiniop/ops/softmax/cpu/softmax_cpu.cc   |  41 ++---
 src/infiniop/ops/softmax/cuda/softmax_cuda.cu |   2 +-
 .../ops/softmax/cuda/softmax_kernel.cuh       | 157 +++++++++---------
 src/infiniop/ops/softmax/info.h               |  10 +-
 src/infiniop/ops/softmax/softmax.h            |   1 +
 test/infiniop/rope.py                         |   2 +-
 test/infiniop/softmax.py                      |   1 +
 8 files changed, 107 insertions(+), 108 deletions(-)

diff --git a/include/infiniop/ops/softmax.h b/include/infiniop/ops/softmax.h
index e826cab19..d06f61345 100644
--- a/include/infiniop/ops/softmax.h
+++ b/include/infiniop/ops/softmax.h
@@ -16,4 +16,5 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
 __C infiniStatus_t infiniopSoftmax(infiniopSoftmaxDescriptor_t desc, void *workspace, size_t workspace_size, void *y, const void *x, void *stream);
 
 __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t desc);
+
 #endif
diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
index 11ed94144..b1e6327ad 100644
--- a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
@@ -38,48 +38,35 @@ infiniStatus_t Descriptor::create(
 template <typename T>
 void softmax_cpu(const SoftmaxInfo &info,
                  const void *x, void *y, int axis) {
-    int dimsize = info.dimsize;
+    int dim_size = info.dim_size;
     int stride = info.stride;
-    int othersize = info.otherdim_size;
-    auto to_float = [](const T &val) -> float {
-        if constexpr (std::is_same_v<T, fp16_t>) {
-            return utils::cast<float>(val);
-        } else {
-            return val;
-        }
-    };
-
-    auto from_float = [](float val) -> T {
-        if constexpr (std::is_same_v<T, fp16_t>) {
-            return utils::cast<fp16_t>(val);
-        } else {
-            return val;
-        }
-    };
-
+    int other_size = info.other_size;
     auto input = reinterpret_cast<const T *>(x);
     auto output = reinterpret_cast<T *>(y);
 
     auto compute_softmax = [&](int i) {
-        int tid = i % stride + (i - i % stride) * dimsize;
+        int tid = i % stride + (i - i % stride) * dim_size;
+
         float max_data = -INFINITY;
-        for (int j = 0; j < dimsize; j++) {
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            max_data = fmax(max_data, to_float(input[index]));
+            max_data = fmax(max_data, utils::cast<float>(input[index]));
         }
+
         float sum_data = 0.0f;
-        for (int j = 0; j < dimsize; j++) {
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            sum_data += std::exp(to_float(input[index]) - max_data);
+            sum_data += std::exp(utils::cast<float>(input[index]) - max_data);
         }
-        for (int j = 0; j < dimsize; j++) {
+
+        for (int j = 0; j < dim_size; j++) {
             int index = tid + j * stride;
-            float result = std::exp(to_float(input[index]) - max_data) / sum_data;
-            output[index] = from_float(result);
+            float result = std::exp(utils::cast<float>(input[index]) - max_data) / sum_data;
+            output[index] = utils::cast<T>(result);
         }
     };
 #pragma omp parallel for
-    for (int i = 0; i < othersize; i++) {
+    for (int i = 0; i < other_size; i++) {
         compute_softmax(i);
     }
 }
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
index b5283a427..f94e32ab0 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
@@ -49,4 +49,4 @@ infiniStatus_t Descriptor::calculate(
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
 }
-} // namespace op::softmax::cuda
\ No newline at end of file
+} // namespace op::softmax::cuda
diff --git a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
index 8e2ff9d88..72847c029 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
+++ b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
@@ -47,9 +47,9 @@ struct MaxOp {
     }
 };
 
-template <typename T, template <typename> class ReduceOp, int thread_group_width = 32>
+template <typename T, template <typename> class ReduceOp, int THREAD_GROUP_WIDTH = 32>
 __device__ __forceinline__ T warpReduce(T value) {
-    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
+    for (int mask = THREAD_GROUP_WIDTH / 2; mask > 0; mask /= 2) {
         value = ReduceOp<T>()(value, __shfl_xor_sync(0xffffffff, value, mask));
     }
     return value;
@@ -73,23 +73,23 @@ __device__ __forceinline__ T warpReduce(T value) {
 so we still need i and j
 i is (blockIdx.x * blockDim.y + threadIdx.y) / stride
 j is (blockIdx.x * blockDim.y + threadIdx.y) % stride
-then i linearized contributes i * stride * dimsize
+then i linearized contributes i * stride * dim_size
 and j is added directly
 */
-template <int elemPerThread, int BLOCK_DIM_X, int BLOCK_DIM_Y, typename T>
-__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    float dataPerThread[elemPerThread];
+template <int ELEM_PER_THREAD, int BLOCK_DIM_X, int BLOCK_DIM_Y, typename T>
+__global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    float dataPerThread[ELEM_PER_THREAD];
     int global_warp_id = blockIdx.x * blockDim.y + threadIdx.y;
-    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dimsize;
+    int group_offset = global_warp_id % stride + (global_warp_id - global_warp_id % stride) * dim_size;
     int tid = threadIdx.x;
-    if (global_warp_id >= otherdim_size) {
+    if (global_warp_id >= other_size) {
         return;
     }
     __shared__ float group_max[BLOCK_DIM_X];
     __shared__ float group_sum[BLOCK_DIM_X];
     float thread_max = -INFINITY;
     float thread_sum = 0.0f;
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = static_cast<float>(x[(tid + i * BLOCK_DIM_X) * stride + group_offset]);
         thread_max = max(thread_max, dataPerThread[i]);
     }
@@ -99,7 +99,7 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, in
         group_max[threadIdx.y] = thread_max;
     }
 
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         dataPerThread[i] = __expf(dataPerThread[i] - group_max[threadIdx.y]);
         thread_sum += dataPerThread[i];
     }
@@ -109,18 +109,18 @@ __global__ void Softmax_warp_impl(const T *x, T *y, int stride, int dim_size, in
         group_sum[threadIdx.y] = thread_sum;
     }
 
-    for (int i = 0; tid + i * BLOCK_DIM_X < dimsize; i++) {
+    for (int i = 0; tid + i * BLOCK_DIM_X < dim_size; i++) {
         y[(tid + i * BLOCK_DIM_X) * stride + group_offset] = static_cast<T>(dataPerThread[i] * __fdividef(1.0f, group_sum[threadIdx.y]));
     }
 }
 
-template <int elemPerThread, int BLOCK_DIM, typename T>
+template <int ELEM_PER_THREAD, int BLOCK_DIM, typename T>
 __launch_bounds__(BLOCK_DIM)
-    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dimsize, int otherdim_size) {
-    // remain = dimsize - BLOCK_DIM * elemPerThread
+    __global__ void Softmax_block_impl(const T *x, T *y, int stride, int dim_size, int other_size) {
+    // remain = dim_size - BLOCK_DIM * ELEM_PER_THREAD
     int tid = threadIdx.x;
-    int block_offset = (blockIdx.x - blockIdx.x % stride) * dimsize + blockIdx.x % stride;
-    int remain = dimsize - (BLOCK_DIM - 1) * elemPerThread; // number of elements left for the last thread
+    int block_offset = (blockIdx.x - blockIdx.x % stride) * dim_size + blockIdx.x % stride;
+    int remain = dim_size - (BLOCK_DIM - 1) * ELEM_PER_THREAD;
 
     MD md_partial;
     md_partial.max = -INFINITY;
@@ -129,8 +129,8 @@ __launch_bounds__(BLOCK_DIM)
     // tid ranges over [0, BLOCK_DIM - 1], so the last thread handles the remainder
     if (tid < BLOCK_DIM - 1) {
 #pragma unroll
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
@@ -138,7 +138,7 @@ __launch_bounds__(BLOCK_DIM)
     } else {
 #pragma unroll
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             input.max = static_cast<float>(x[index]);
             input.sum = 1.0f;
             md_partial = reduce_for_md(md_partial, input);
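The warpReduce being renamed above exchanges values with __shfl_xor_sync in an XOR butterfly: after log2(width) rounds every lane holds the reduction of all 32 lanes. A CPU simulation of the same exchange pattern (illustration only, not part of the patch; the data is arbitrary):

#include <algorithm>
#include <cstdio>

int main() {
    const int W = 32; // warp width
    float lane[W], next[W];
    for (int i = 0; i < W; i++) lane[i] = float((i * 7) % 13); // arbitrary data
    for (int mask = W / 2; mask > 0; mask /= 2) {
        // Lane i reads lane i ^ mask, mirroring __shfl_xor_sync.
        for (int i = 0; i < W; i++) next[i] = std::max(lane[i], lane[i ^ mask]);
        std::copy(next, next + W, lane);
    }
    std::printf("all lanes hold %g\n", lane[0]); // 12, the max of the input
    return 0;
}

The same loop with addition instead of max yields the warp-wide sum, which is how SumOp and MaxOp share one warpReduce.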
@@ -153,85 +153,94 @@ __launch_bounds__(BLOCK_DIM)
     }
     __syncthreads();
     if (tid < BLOCK_DIM - 1) {
-        for (int i = 0; i < elemPerThread; i++) {
-            int index = (tid * elemPerThread + i) * stride + block_offset;
+        for (int i = 0; i < ELEM_PER_THREAD; i++) {
+            int index = (tid * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     } else {
         for (int i = 0; i < remain; i++) {
-            int index = ((BLOCK_DIM - 1) * elemPerThread + i) * stride + block_offset;
+            int index = ((BLOCK_DIM - 1) * ELEM_PER_THREAD + i) * stride + block_offset;
             y[index] = static_cast<T>(__expf(static_cast<float>(x[index]) - md_total.max) * __fdividef(1.0f, md_total.sum));
         }
     }
 }
 
 template <typename T>
-infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
-    int dimsize = info.dimsize;
-    int stride = info.stride;
-    int otherdim_size = info.otherdim_size;
-    if (dimsize <= 1024) {
-        dim3 block(32, 32); // BLOCK_DIM_X=32, BLOCK_DIM_Y=32
-        int num_blocks = (otherdim_size + block.y - 1) / block.y;
-        dim3 grid(num_blocks, 1, 1);
-        int elemPerThread = (dimsize + 31) / 32; // number of elements each thread needs to handle
-        elemPerThread = min(elemPerThread, 32);  // cap the maximum
+void dispatchSoftmaxKernel(
+    const void *x, void *y,
+    int stride, int dim_size, int other_size,
+    void *stream, bool use_warp_impl) {
+
+    int elemPerThread;
+    dim3 grid, block;
+
+    if (use_warp_impl) {
+        block = dim3(32, 32);
+        grid = dim3((other_size + block.y - 1) / block.y, 1, 1);
+        elemPerThread = min((dim_size + 31) / 32, 32);
+
+#define LAUNCH_WARP_KERNEL(ELEM_PER_THREAD)                           \
+    Softmax_warp_impl<ELEM_PER_THREAD, 32, 32, T>                     \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_warp_impl<1, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_warp_impl<2, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_warp_impl<4, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_warp_impl<8, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(8);
         } else if (elemPerThread <= 16) {
-            Softmax_warp_impl<16, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(16);
         } else {
-            Softmax_warp_impl<32, 32, 32, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_WARP_KERNEL(32);
         }
-    } else if (dimsize > 1024) {
-        int block_size = 1024;
-        int elemPerThread = (dimsize + block_size - 1) / block_size; // number of elements each thread needs to handle
-        elemPerThread = min(elemPerThread, 32);                      // cap the maximum at 32
-        dim3 block(block_size);
-        dim3 grid(otherdim_size);
+
+#undef LAUNCH_WARP_KERNEL
+
+    } else {
+        // Block implementation for dim_size > 1024
+        constexpr int BLOCK_SIZE = 1024;
+        block = dim3(BLOCK_SIZE);
+        grid = dim3(other_size);
+        elemPerThread = min((dim_size + BLOCK_SIZE - 1) / BLOCK_SIZE, 32);
+
+#define LAUNCH_BLOCK_KERNEL(ELEM_PER_THREAD)                          \
+    Softmax_block_impl<ELEM_PER_THREAD, BLOCK_SIZE, T>                \
+        <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>( \
+            reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), \
+            stride, dim_size, other_size)
+
         if (elemPerThread <= 1) {
-            Softmax_block_impl<1, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(1);
         } else if (elemPerThread <= 2) {
-            Softmax_block_impl<2, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(2);
         } else if (elemPerThread <= 4) {
-            Softmax_block_impl<4, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(4);
         } else if (elemPerThread <= 8) {
-            Softmax_block_impl<8, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(8);
         } else if (elemPerThread <= 16) {
-            Softmax_block_impl<16, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(16);
         } else {
-            Softmax_block_impl<32, 1024, T>
-                <<<grid, block, 0, reinterpret_cast<cudaStream_t>(stream)>>>(
-                    reinterpret_cast<const T *>(x), reinterpret_cast<T *>(y), stride, dimsize, otherdim_size);
+            LAUNCH_BLOCK_KERNEL(32);
         }
+
+#undef LAUNCH_BLOCK_KERNEL
+    }
+}
+
+template <typename T>
+infiniStatus_t softmax_dispatch(const op::softmax::SoftmaxInfo &info, void *y, const void *x, void *stream) {
+    int dim_size = info.dim_size;
+    int stride = info.stride;
+    int other_size = info.other_size;
+    if (dim_size <= 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, true);
+    } else if (dim_size > 1024) {
+        dispatchSoftmaxKernel<T>(x, y, stride, dim_size, other_size, stream, false);
     }
     return INFINI_STATUS_SUCCESS;
 }
diff --git a/src/infiniop/ops/softmax/info.h b/src/infiniop/ops/softmax/info.h
index c59253abf..67be5e14d 100644
--- a/src/infiniop/ops/softmax/info.h
+++ b/src/infiniop/ops/softmax/info.h
@@ -10,10 +10,10 @@ namespace op::softmax {
 class SoftmaxInfo {
 public:
     int axis;
-    int otherdim_size;
+    int other_size;
     int stride;
     int size;
-    int dimsize;
+    int dim_size;
 
     static utils::Result<SoftmaxInfo> create(
         infiniopTensorDescriptor_t y_desc,
@@ -27,9 +27,9 @@ class SoftmaxInfo {
         SoftmaxInfo info;
         info.axis = axis;
         info.size = 1;
-        info.otherdim_size = 1;
+        info.other_size = 1;
         info.stride = 1;
-        info.dimsize = static_cast<int>(x_desc->dim(axis));
+        info.dim_size = static_cast<int>(x_desc->dim(axis));
         int ndim = static_cast<int>(y_desc->ndim());
         for (int i = ndim - 1; i >= 0; i--) {
             info.size *= static_cast<int>(y_desc->dim(i));
@@ -38,7 +38,7 @@ class SoftmaxInfo {
         for (int i = axis + 1; i < ndim; i++) {
             info.stride *= static_cast<int>(x_desc->dim(i));
         }
-        info.otherdim_size = info.size / info.dimsize;
+        info.other_size = info.size / info.dim_size;
         return utils::Result<SoftmaxInfo>(info);
     }
 };
diff --git a/src/infiniop/ops/softmax/softmax.h b/src/infiniop/ops/softmax/softmax.h
index a40e0cad2..dc0520f77 100644
--- a/src/infiniop/ops/softmax/softmax.h
+++ b/src/infiniop/ops/softmax/softmax.h
@@ -46,4 +46,5 @@
             void *stream) const;                              \
     };                                                        \
     }
+
 #endif // __SOFTMAX_H__
diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py
index 99d90dd2c..c3383ef59 100644
--- a/test/infiniop/rope.py
+++ b/test/infiniop/rope.py
@@ -224,4 +224,4 @@ def lib_rope():
     for device in get_test_devices(args):
         test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
 
-    print("\033[92mTest passed!\033[0m")
+    print("\033[92mTest passed!\033[0m")
\ No newline at end of file
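To summarize the dispatch refactor above: the policy itself is unchanged, one warp per row for dim_size <= 1024 and one 1024-thread block per row beyond that, with the per-thread element count rounded up and capped at 32. A host-side C++ mirror of the selection logic (illustration only, not part of the patch; the row lengths are arbitrary):

#include <algorithm>
#include <cstdio>

int main() {
    for (int dim_size : {64, 1000, 1024, 4096, 40000}) {
        bool use_warp_impl = dim_size <= 1024;
        int lanes = use_warp_impl ? 32 : 1024;
        int elemPerThread = std::min((dim_size + lanes - 1) / lanes, 32);
        std::printf("dim_size=%6d -> %s kernel, elemPerThread=%d\n",
                    dim_size, use_warp_impl ? "warp " : "block", elemPerThread);
    }
    return 0;
}

The if/else ladder in the kernel header then rounds elemPerThread up to the next compiled template instantiation (1, 2, 4, 8, 16 or 32).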
diff --git a/test/infiniop/softmax.py b/test/infiniop/softmax.py
index bf286de68..5103c2eac 100644
--- a/test/infiniop/softmax.py
+++ b/test/infiniop/softmax.py
@@ -23,6 +23,7 @@
 # ==============================================================================
 # These are not meant to be imported from other modules
 _TEST_CASES_ = [
+    # shape reduce_axis
     ((32, 20, 512), 0),
     ((32, 20, 512), 1),
     ((32, 20, 512), 2),

From 56b490067922b5e31cfec2b001acf98a36e33236 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Mon, 23 Jun 2025 22:43:49 +0800
Subject: [PATCH 6/8] Issue/259 format code

---
 src/infiniop/ops/softmax/cuda/softmax_cuda.cuh | 2 +-
 test/infiniop/rope.py                          | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
index 0a9c63187..6031e113b 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cuh
@@ -5,4 +5,4 @@
 
 DESCRIPTOR(cuda)
 
-#endif // __SOFTMAX_CUDA_CUH__
\ No newline at end of file
+#endif // __SOFTMAX_CUDA_CUH__
diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py
index c3383ef59..7b24fdb7d 100644
--- a/test/infiniop/rope.py
+++ b/test/infiniop/rope.py
@@ -224,4 +224,5 @@ def lib_rope():
     for device in get_test_devices(args):
         test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
 
-    print("\033[92mTest passed!\033[0m")
\ No newline at end of file
+    print("\033[92mTest passed!\033[0m")
+    
\ No newline at end of file

From 6d643179fefb78f121616b36237f277f11d88584 Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Thu, 10 Jul 2025 00:36:33 +0800
Subject: [PATCH 7/8] Issue/259 revise the softmax operator test cases and add
 bf16

---
 src/infiniop/ops/softmax/cpu/softmax_cpu.cc   |   5 +-
 src/infiniop/ops/softmax/cuda/softmax_cuda.cu |   4 +-
 .../ops/softmax/cuda/softmax_kernel.cuh       |  14 ++
 test/infiniop/libinfiniop/op_register.py      |  32 +++
 test/infiniop/rope.py                         |   1 -
 test/infiniop/softmax.py                      | 223 +++++++-----------
 6 files changed, 142 insertions(+), 137 deletions(-)

diff --git a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
index b1e6327ad..ab18315a3 100644
--- a/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
+++ b/src/infiniop/ops/softmax/cpu/softmax_cpu.cc
@@ -18,7 +18,7 @@ infiniStatus_t Descriptor::create(
     const auto &x_shape = x->shape();
     const auto &y_shape = y->shape();
 
-    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32);
+    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
 
     CHECK_SAME_SHAPE(y_shape, x_shape);
 
@@ -84,6 +84,9 @@ infiniStatus_t Descriptor::calculate(
     case INFINI_DTYPE_F32:
         softmax_cpu<float>(_info, x, y, _info.axis);
         return INFINI_STATUS_SUCCESS;
+    case INFINI_DTYPE_BF16:
+        softmax_cpu<bf16_t>(_info, x, y, _info.axis);
+        return INFINI_STATUS_SUCCESS;
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
diff --git a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
index f94e32ab0..23f919595 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
+++ b/src/infiniop/ops/softmax/cuda/softmax_cuda.cu
@@ -23,7 +23,7 @@ infiniStatus_t Descriptor::create(
     auto result = SoftmaxInfo::create(y, x, axis);
     CHECK_RESULT(result);
     CHECK_SAME_SHAPE(y->shape(), x->shape());
-    CHECK_DTYPE(y->dtype(), x->dtype(), INFINI_DTYPE_F16, INFINI_DTYPE_F32);
+    CHECK_DTYPE(y->dtype(), x->dtype(), INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
     *desc_ptr = new Descriptor(
         dtype,
         result.take(),
@@ -45,6 +45,8 @@ infiniStatus_t Descriptor::calculate(
         return softmax_dispatch<half>(_info, y, x, stream_);
     case INFINI_DTYPE_F32:
         return softmax_dispatch<float>(_info, y, x, stream_);
+    case INFINI_DTYPE_BF16:
+        return softmax_dispatch<__nv_bfloat16>(_info, y, x, stream_);
     default:
         return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
diff --git a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
index 72847c029..54ab99af0 100644
--- a/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
+++ b/src/infiniop/ops/softmax/cuda/softmax_kernel.cuh
@@ -33,6 +33,13 @@ struct SumOp {
     }
 };
 
+template <>
+struct SumOp<__nv_bfloat16> {
+    __device__ __forceinline__ __nv_bfloat16 operator()(const __nv_bfloat16 &a, const __nv_bfloat16 &b) const {
+        return __hadd(a, b);
+    }
+};
+
 template <typename T>
 struct MaxOp {
     __device__ __forceinline__ T operator()(const T &a, const T &b) const {
@@ -47,6 +54,13 @@ struct MaxOp {
     }
 };
 
+template <>
+struct MaxOp<__nv_bfloat16> {
+    __device__ __forceinline__ __nv_bfloat16 operator()(const __nv_bfloat16 &a, const __nv_bfloat16 &b) const {
+        return __hmax(a, b);
+    }
+};
+
 template <typename T, template <typename> class ReduceOp, int THREAD_GROUP_WIDTH = 32>
 __device__ __forceinline__ T warpReduce(T value) {
     for (int mask = THREAD_GROUP_WIDTH / 2; mask > 0; mask /= 2) {
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index ff583d9c0..aa1380c9b 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -431,3 +431,35 @@ def swiglu_(lib):
     lib.infiniopDestroySwiGLUDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
     ]
+
+@OpRegister.operator
+def softmax_(lib):
+    lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
+    lib.infiniopCreateSoftmaxDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_int32,
+    ]
+
+    lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
+    lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopSoftmax.restype = c_int32
+    lib.infiniopSoftmax.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
+    lib.infiniopDestroySoftmaxDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
\ No newline at end of file
diff --git a/test/infiniop/rope.py b/test/infiniop/rope.py
index 7b24fdb7d..99d90dd2c 100644
--- a/test/infiniop/rope.py
+++ b/test/infiniop/rope.py
@@ -225,4 +225,3 @@ def lib_rope():
         test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
 
     print("\033[92mTest passed!\033[0m")
-    
\ No newline at end of file
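The bf16 support added in this patch leans on bfloat16 keeping float32's 8-bit exponent while truncating the significand to 8 bits, which is why the BF16 test tolerances below are the loosest. A C++ sketch of the precision loss (illustration only, not part of the patch; real conversions round to nearest rather than truncate as done here):

#include <cstdint>
#include <cstdio>
#include <cstring>

// float -> bf16 -> float by dropping the low 16 mantissa bits.
float bf16_roundtrip(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits);
    bits &= 0xFFFF0000u;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

int main() {
    for (float x : {0.1f, 1.2345f, 3.14159f}) {
        std::printf("%.7f -> %.7f\n", x, bf16_roundtrip(x));
    }
    return 0;
}

Roughly two to three significant decimal digits survive the round trip, consistent with the rtol of 1e-2 chosen for BF16 in the test file below.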
20, 512), 1), - ((32, 20, 512), 2), + # shape, reduce_axis, stride + ((32, 20, 512), 0, (20 * 512, 512, 1)), + ((32, 20, 512), 1, (20 * 512, 512, 1)), + ((32, 20, 512), 2, (20 * 512, 512, 1)), # 2D 张量测试 - ((128, 256), 0), - ((128, 256), 1), - ((1024, 1024), 0), - ((1024, 1024), 1), + ((128, 256), 0, (256, 1)), + ((128, 256), 1, (256, 1)), + ((1024, 1024), 0, (1024, 1)), + ((1024, 1024), 1, (1024, 1)), # 4D 张量测试 - ((8, 32, 64, 64), 0), - ((8, 32, 64, 64), 1), - ((8, 32, 64, 64), 2), - ((8, 32, 64, 64), 3), + ((8, 32, 64, 64), 0, (32 * 64 * 64, 64 * 64, 64, 1)), + ((8, 32, 64, 64), 1, (32 * 64 * 64, 64 * 64, 64, 1)), + ((8, 32, 64, 64), 2, (32 * 64 * 64, 64 * 64, 64, 1)), + ((8, 32, 64, 64), 3, (32 * 64 * 64, 64 * 64, 64, 1)), # 5D 张量测试 - ((4, 16, 8, 32, 32), 0), - ((4, 16, 8, 32, 32), 1), - ((4, 16, 8, 32, 32), 2), - ((4, 16, 8, 32, 32), 3), - ((4, 16, 8, 32, 32), 4), + ((4, 16, 8, 32, 32), 0, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)), + ((4, 16, 8, 32, 32), 1, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)), + ((4, 16, 8, 32, 32), 2, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)), + ((4, 16, 8, 32, 32), 3, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)), + ((4, 16, 8, 32, 32), 4, (16 * 8 * 32 * 32, 8 * 32 * 32, 32 * 32, 32, 1)), # 小尺寸测试 - ((2, 3), 0), - ((2, 3), 1), - ((1, 10), 0), - ((1, 10), 1), - ((10, 1), 0), - ((10, 1), 1), + ((2, 3), 0, (3, 1)), + ((2, 3), 1, (3, 1)), + ((1, 10), 0, (10, 1)), + ((1, 10), 1, (10, 1)), + ((10, 1), 0, (1, 1)), + ((10, 1), 1, (1, 1)), - ((1000,), 0), + ((1000,), 0, (1,)), - ((7, 333, 777), 0), - ((7, 333, 777), 1), - ((7, 333, 777), 2), - ((13, 509, 251), 0), - ((13, 509, 251), 1), - ((13, 509, 251), 2), + ((7, 333, 777), 0, (333 * 777, 777, 1)), + ((7, 333, 777), 1, (333 * 777, 777, 1)), + ((7, 333, 777), 2, (333 * 777, 777, 1)), + ((13, 509, 251), 0, (509 * 251, 251, 1)), + ((13, 509, 251), 1, (509 * 251, 251, 1)), + ((13, 509, 251), 2, (509 * 251, 251, 1)), - ((64, 1024, 768), 0), - ((64, 1024, 768), 1), - ((64, 1024, 768), 2), - ((32, 2048, 512), 0), - ((32, 2048, 512), 1), - ((32, 2048, 512), 2), - - ((1024, 1), 0), - ((1024, 1), 1), - ((1, 1024), 0), - ((1, 1024), 1), - - ((32, 8, 512, 64), 1), - ((32, 8, 512, 64), 2), - ((32, 8, 512, 64), 3), + ((64, 1024, 768), 0, (1024 * 768, 768, 1)), + ((64, 1024, 768), 1, (1024 * 768, 768, 1)), + ((64, 1024, 768), 2, (1024 * 768, 768, 1)), + ((32, 2048, 512), 0, (2048 * 512, 512, 1)), + ((32, 2048, 512), 1, (2048 * 512, 512, 1)), + ((32, 2048, 512), 2, (2048 * 512, 512, 1)), + + ((1024, 1), 0, (1, 1)), + ((1024, 1), 1, (1, 1)), + ((1, 1024), 0, (1024, 1)), + ((1, 1024), 1, (1024, 1)), + + ((32, 8, 512, 64), 1, (8 * 512 * 64, 512 * 64, 64, 1)), + ((32, 8, 512, 64), 2, (8 * 512 * 64, 512 * 64, 64, 1)), + ((32, 8, 512, 64), 3, (8 * 512 * 64, 512 * 64, 64, 1)), ] - # Data types used for testing -_TENSOR_DTYPES = [torch.float16, torch.float32] +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.F32, InfiniDtype.BF16] # Tolerance map for different data types _TOLERANCE_MAP = { - torch.float16: {"atol": 1e-3, "rtol": 1e-3}, - torch.float32: {"atol": 1e-7, "rtol": 1e-7}, + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-3}, + InfiniDtype.F32: {"atol": 1e-7, "rtol": 1e-6}, + InfiniDtype.BF16: {"atol": 1e-3, "rtol": 1e-2}, } - -DEBUG = False -PROFILE = False -NUM_PRERUN = 10 -NUM_ITERATIONS = 1000 - - -class SoftmaxDescriptor(Structure): - _fields_ = [("device", c_int32)] - - -infiniopSoftmaxDescriptor_t = POINTER(SoftmaxDescriptor) - - -def softmax(x, axis): - return torch.softmax(x, axis = axis).to(x.dtype) +def 
+    torch.softmax(x, axis=axis, out=y)
 
 
 def test(
-    lib,
     handle,
-    torch_device,
+    device,
     shape,
     axis,
+    stride,
-    dtype=torch.float16,
+    dtype=InfiniDtype.F16,
     sync=None,
 ):
+    x = TestTensor(shape, stride, dtype, device)
+    y = TestTensor(shape, stride, dtype, device, mode="zeros")
+
     print(
-        f"Testing softmax on {torch_device} with shape:{shape}"
+        f"Testing softmax on {InfiniDeviceNames[device]} with shape:{shape} stride:{stride} axis:{axis} "
-        f"dtype:{dtype}"
+        f"dtype:{InfiniDtypeNames[dtype]}"
     )
-
-    a = torch.randn(shape, dtype=dtype).to(torch_device) * 0.1
-    b = torch.empty_like(a)
-    ans = softmax(a, axis)
-
-    a_tensor, b_tensor = [to_tensor(tensor, lib) for tensor in [a, b]]
+    softmax(x.torch_tensor(), axis, y.torch_tensor())
 
     if sync is not None:
         sync()
 
-    descriptor = infiniopSoftmaxDescriptor_t()
+    descriptor = infiniopOperatorDescriptor_t()
     check_error(
-        lib.infiniopCreateSoftmaxDescriptor(
+        LIBINFINIOP.infiniopCreateSoftmaxDescriptor(
             handle,
             ctypes.byref(descriptor),
-            b_tensor.descriptor,
-            a_tensor.descriptor,
+            y.descriptor,
+            x.descriptor,
             c_int32(axis),
         )
     )
-
-    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
-    for tensor in [a_tensor, b_tensor]:
-        tensor.destroyDesc(lib)
-
     workspace_size = c_uint64(0)
     check_error(
-        lib.infiniopGetSoftmaxWorkspaceSize(descriptor, ctypes.byref(workspace_size))
+        LIBINFINIOP.infiniopGetSoftmaxWorkspaceSize(descriptor, ctypes.byref(workspace_size))
     )
-    workspace = create_workspace(workspace_size.value, a.device)
+    workspace = TestWorkspace(workspace_size.value, x.device)
 
     def lib_softmax():
         check_error(
-            lib.infiniopSoftmax(
+            LIBINFINIOP.infiniopSoftmax(
                 descriptor,
-                workspace.data_ptr() if workspace is not None else None,
+                workspace.data() if workspace is not None else None,
                 workspace_size.value,
-                b_tensor.data,
-                a_tensor.data,
+                y.data(),
+                x.data(),
                 None,
             )
         )
 
     lib_softmax()
 
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    for tensor in [x, y]:
+        tensor.destroy_desc()
+
     atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
     if DEBUG:
-        debug(b, ans, atol=atol, rtol=rtol)
-    assert torch.allclose(b, ans, atol=atol, rtol=rtol)
+        debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
+    assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol)
 
     # Profiling workflow
     if PROFILE:
         # fmt: off
-        profile_operation("PyTorch", lambda: softmax(a, b), torch_device, NUM_PRERUN, NUM_ITERATIONS)
-        profile_operation("    lib", lambda: lib_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("PyTorch", lambda: softmax(x.torch_tensor(), axis, y.torch_tensor()), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation("    lib", lambda: lib_softmax(), device, NUM_PRERUN, NUM_ITERATIONS)
         # fmt: on
-    check_error(lib.infiniopDestroySoftmaxDescriptor(descriptor))
+    check_error(LIBINFINIOP.infiniopDestroySoftmaxDescriptor(descriptor))
 
 
 if __name__ == "__main__":
     args = get_args()
-    lib = open_lib()
-
-    lib.infiniopCreateSoftmaxDescriptor.restype = c_int32
-    lib.infiniopCreateSoftmaxDescriptor.argtypes = [
-        infiniopHandle_t,
-        POINTER(infiniopSoftmaxDescriptor_t),
-        infiniopTensorDescriptor_t,
-        infiniopTensorDescriptor_t,
-        c_int32,
-    ]
-
-    lib.infiniopGetSoftmaxWorkspaceSize.restype = c_int32
-    lib.infiniopGetSoftmaxWorkspaceSize.argtypes = [
-        infiniopSoftmaxDescriptor_t,
-        POINTER(c_uint64),
-    ]
-
-    lib.infiniopSoftmax.restype = c_int32
-    lib.infiniopSoftmax.argtypes = [
-        infiniopSoftmaxDescriptor_t,
-        c_void_p,
-        c_uint64,
-        c_void_p,
-        c_void_p,
-        c_void_p,
-    ]
-
-    lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
-    lib.infiniopDestroySoftmaxDescriptor.argtypes = [
-        infiniopSoftmaxDescriptor_t,
-    ]
-
+
     # Configure testing options
     DEBUG = args.debug
     PROFILE = args.profile
@@ -222,7 +177,7 @@ def lib_softmax():
     NUM_ITERATIONS = args.num_iterations
 
     for device in get_test_devices(args):
-        test_operator(lib, device, test, _TEST_CASES_, _TENSOR_DTYPES)
+        test_operator(device, test, _TEST_CASES_, _TENSOR_DTYPES)
 
     print("\033[92mTest passed!\033[0m")

From 581fd623c9bbf7f03784be393377c977065790db Mon Sep 17 00:00:00 2001
From: Graylatzhou <1391087899@qq.com>
Date: Thu, 10 Jul 2025 15:59:08 +0800
Subject: [PATCH 8/8] Issue/259 format code

---
 include/infiniop.h                       | 1 +
 test/infiniop/libinfiniop/op_register.py | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/include/infiniop.h b/include/infiniop.h
index d51b8d92e..ab4c8c231 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -14,6 +14,7 @@
 #include "infiniop/ops/relu.h"
 #include "infiniop/ops/rms_norm.h"
 #include "infiniop/ops/rope.h"
+#include "infiniop/ops/softmax.h"
 #include "infiniop/ops/sub.h"
 #include "infiniop/ops/swiglu.h"
 #include "infiniop/tensor_descriptor.h"
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index aa1380c9b..0ba71d5e0 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -462,4 +462,4 @@ def softmax_(lib):
     lib.infiniopDestroySoftmaxDescriptor.restype = c_int32
     lib.infiniopDestroySoftmaxDescriptor.argtypes = [
         infiniopOperatorDescriptor_t,
-    ]
\ No newline at end of file
+    ]
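
For reference, the reduction the BF16 MaxOp/SumOp specializations above feed into is a numerically stable softmax along one axis: shift each slice by its max, exponentiate, then normalize by the slice sum. Below is a minimal PyTorch sketch of that computation, assuming fp32 accumulation for the reduced-precision dtypes (the usual approach for fp16/bf16 softmax); `softmax_ref` is a hypothetical helper for illustration only, not part of this patch or the library:

    import torch

    def softmax_ref(x: torch.Tensor, axis: int) -> torch.Tensor:
        # Accumulate in fp32 so fp16/bf16 inputs keep precision.
        x32 = x.float()
        m = x32.amax(dim=axis, keepdim=True)  # per-slice max (cf. MaxOp)
        e = torch.exp(x32 - m)                # shifted exp cannot overflow
        return (e / e.sum(dim=axis, keepdim=True)).to(x.dtype)  # cf. SumOp

    # Each slice along `axis` should sum to 1 within bf16 tolerance,
    # mirroring the (32, 20, 512) axis=1 case in the test table above.
    x = torch.randn(32, 20, 512, dtype=torch.bfloat16)
    y = softmax_ref(x, axis=1)
    assert torch.allclose(y.sum(dim=1).float(), torch.ones(32, 512), atol=1e-2)

The max-shift is what makes the looser fp16/bf16 tolerances in _TOLERANCE_MAP sufficient: without it, exp() overflows for large logits in reduced precision.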