From b96af842619c9966f1c158ffb2918a6290b944df Mon Sep 17 00:00:00 2001 From: xgqdut2016 Date: Fri, 5 Dec 2025 14:56:29 +0800 Subject: [PATCH 1/7] issue/664: success linear --- include/infiniop/ops/linear.h | 40 +++ include/infiniop/ops/quant.h | 28 ++ src/infiniop/ops/linear/cuda/kernel.cuh | 40 +++ src/infiniop/ops/linear/info.h | 79 +++++ src/infiniop/ops/linear/linear.h | 54 ++++ .../ops/linear/nvidia/linear_nvidia.cu | 193 ++++++++++++ .../ops/linear/nvidia/linear_nvidia.cuh | 7 + src/infiniop/ops/linear/operator.cc | 117 ++++++++ src/infiniop/ops/quant/cuda/kernel.cuh | 277 ++++++++++++++++++ src/infiniop/ops/quant/info.h | 60 ++++ src/infiniop/ops/quant/nvidia/quant_nvidia.cu | 118 ++++++++ .../ops/quant/nvidia/quant_nvidia.cuh | 7 + src/infiniop/ops/quant/operator.cc | 98 +++++++ src/infiniop/ops/quant/quant.h | 40 +++ test/infiniop/libinfiniop/op_register.py | 83 ++++++ test/infiniop/linear.py | 266 +++++++++++++++++ test/infiniop/quant.py | 211 +++++++++++++ 17 files changed, 1718 insertions(+) create mode 100644 include/infiniop/ops/linear.h create mode 100644 include/infiniop/ops/quant.h create mode 100644 src/infiniop/ops/linear/cuda/kernel.cuh create mode 100644 src/infiniop/ops/linear/info.h create mode 100644 src/infiniop/ops/linear/linear.h create mode 100644 src/infiniop/ops/linear/nvidia/linear_nvidia.cu create mode 100644 src/infiniop/ops/linear/nvidia/linear_nvidia.cuh create mode 100644 src/infiniop/ops/linear/operator.cc create mode 100644 src/infiniop/ops/quant/cuda/kernel.cuh create mode 100644 src/infiniop/ops/quant/info.h create mode 100644 src/infiniop/ops/quant/nvidia/quant_nvidia.cu create mode 100644 src/infiniop/ops/quant/nvidia/quant_nvidia.cuh create mode 100644 src/infiniop/ops/quant/operator.cc create mode 100644 src/infiniop/ops/quant/quant.h create mode 100644 test/infiniop/linear.py create mode 100644 test/infiniop/quant.py diff --git a/include/infiniop/ops/linear.h b/include/infiniop/ops/linear.h new file mode 100644 index 000000000..06f599aea --- /dev/null +++ b/include/infiniop/ops/linear.h @@ -0,0 +1,40 @@ +#ifndef __INFINIOP_LINEAR_API_H__ +#define __INFINIOP_LINEAR_API_H__ + +#include "../operator_descriptor.h" + +typedef InfiniopDescriptor *infiniopLinearDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLinearDescriptor(infiniopHandle_t handle, + infiniopLinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t bias_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t x_scale_desc, + infiniopTensorDescriptor_t x_zero_desc, + infiniopTensorDescriptor_t weights_desc, + infiniopTensorDescriptor_t weights_scale_desc, + infiniopTensorDescriptor_t weights_zero_desc, + float alpha, + float beta); + +__C __export infiniStatus_t infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLinear(infiniopLinearDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *d, + const void *c, + const void *bias, + const void *x, + const void *x_scale, + const void *x_zero, + const void *weights, + const void *weights_scale, + const void *weights_zero, + void *stream); + +__C __export infiniStatus_t infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc); + +#endif diff --git a/include/infiniop/ops/quant.h b/include/infiniop/ops/quant.h new file mode 100644 index 000000000..90027c04a --- /dev/null +++ b/include/infiniop/ops/quant.h @@ -0,0 +1,28 @@ +#ifndef __INFINIOP_QUANT_API_H__ 
+#define __INFINIOP_QUANT_API_H__ + +#include "../operator_descriptor.h" + +typedef InfiniopDescriptor *infiniopQuantDescriptor_t; + +__C __export infiniStatus_t infiniopCreateQuantDescriptor(infiniopHandle_t handle, + infiniopQuantDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_packed_desc, + infiniopTensorDescriptor_t x_scale_desc, + infiniopTensorDescriptor_t x_zero_desc, + infiniopTensorDescriptor_t x_desc); + +__C __export infiniStatus_t infiniopGetQuantWorkspaceSize(infiniopQuantDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopQuant(infiniopQuantDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *x_packed, + void *x_scale, + void *x_zero, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroyQuantDescriptor(infiniopQuantDescriptor_t desc); + +#endif diff --git a/src/infiniop/ops/linear/cuda/kernel.cuh b/src/infiniop/ops/linear/cuda/kernel.cuh new file mode 100644 index 000000000..da9a7c41d --- /dev/null +++ b/src/infiniop/ops/linear/cuda/kernel.cuh @@ -0,0 +1,40 @@ +#ifndef __LINEAR_KERNEL_CUH__ +#define __LINEAR_KERNEL_CUH__ + +template +__device__ void postKernel(Tdata *y, int32_t *y_packed, const Tdata *c, const Tdata *bias, const int8_t *x_packed, const Tdata *x_scale, const Tdata *x_zero, const int8_t *w_packed, const Tdata *w_scale, const Tdata *w_zero, int M, int K, int N, float alpha, float beta) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) { + return; + } + int idx = row * N + col; + float output1 = ((float)x_scale[row] * (float)w_scale[col] * ((float)y_packed[idx] + K * (float)x_zero[row] * (float)w_zero[col])); + float output2 = 0.0f; + float output3 = 0.0f; + float tmp2 = (float)x_scale[row] * (float)w_scale[col] * (float)w_zero[col]; + float tmp3 = (float)x_scale[row] * (float)x_zero[row] * (float)w_scale[col]; + for (int ind = 0; ind < K; ind++) { + output2 += tmp2 * (float)x_packed[row * K + ind]; + output3 += tmp3 * (float)w_packed[ind * N + col]; + } + float output = alpha * (output1 - output2 - output3) + beta * (float)c[idx] + (float)bias[col]; + + y[idx] = static_cast(output); +} + +template +__device__ void postSymKernel(Tdata *y, int32_t *y_packed, const Tdata *c, const Tdata *bias, const int8_t *x_packed, const Tdata *x_scale, const int8_t *w_packed, const Tdata *w_scale, int M, int K, int N, float alpha, float beta) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= M || col >= N) { + return; + } + int idx = row * N + col; + float output1 = (float)x_scale[row] * (float)w_scale[col] * ((float)y_packed[idx]); + + float output = alpha * output1 + beta * (float)c[idx] + (float)bias[col]; + + y[idx] = static_cast(output); +} +#endif // __LINEAR_KERNEL_CUH__ diff --git a/src/infiniop/ops/linear/info.h b/src/infiniop/ops/linear/info.h new file mode 100644 index 000000000..866125d86 --- /dev/null +++ b/src/infiniop/ops/linear/info.h @@ -0,0 +1,79 @@ +#ifndef __LINEAR_INFO_H__ +#define __LINEAR_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" + +namespace op::linear { + +class LinearInfo { +private: + LinearInfo() = default; + +public: + infiniDtype_t dtype, packed_type; + size_t M, K, N; + float alpha, beta; + + static utils::Result createLinearInfo( + infiniopTensorDescriptor_t d_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t bias_desc, + infiniopTensorDescriptor_t 
x_desc, + infiniopTensorDescriptor_t x_scale_desc, + infiniopTensorDescriptor_t x_zero_desc, + infiniopTensorDescriptor_t weights_desc, + infiniopTensorDescriptor_t weights_scale_desc, + infiniopTensorDescriptor_t weights_zero_desc, + float alpha, + float beta) { + + CHECK_OR_RETURN( + d_desc != nullptr && c_desc != nullptr && bias_desc != nullptr && x_desc != nullptr && x_scale_desc != nullptr && weights_desc != nullptr && weights_scale_desc != nullptr, + INFINI_STATUS_NULL_POINTER); + + const infiniDtype_t dtype = d_desc->dtype(); + const infiniDtype_t packed_type = x_desc->dtype(); + CHECK_OR_RETURN(dtype == c_desc->dtype() && dtype == bias_desc->dtype() && dtype == x_scale_desc->dtype() && dtype == weights_scale_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(packed_type == weights_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + CHECK_DTYPE(packed_type, INFINI_DTYPE_I8); + CHECK_OR_RETURN(bias_desc->ndim() == 1, + INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(d_desc->ndim() == 2 + && c_desc->ndim() == 2 + && x_desc->ndim() == 2 + && x_scale_desc->ndim() == 2 + && weights_desc->ndim() == 2 + && weights_scale_desc->ndim() == 2, + INFINI_STATUS_BAD_TENSOR_SHAPE); + + size_t M = d_desc->dim(0); + size_t N = d_desc->dim(1); + size_t K = x_desc->dim(1); + CHECK_OR_RETURN(N == bias_desc->dim(0), + INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_OR_RETURN(M == x_desc->dim(0) + || M == x_scale_desc->dim(0) + || 1 == x_scale_desc->dim(1) + || 1 == weights_scale_desc->dim(0) + || N == weights_scale_desc->dim(1), + INFINI_STATUS_BAD_TENSOR_SHAPE); + + return utils::Result(LinearInfo{ + dtype, + packed_type, + M, + K, + N, + alpha, + beta}); + } +}; + +} // namespace op::linear + +#endif // __LINEAR_INFO_H__ diff --git a/src/infiniop/ops/linear/linear.h b/src/infiniop/ops/linear/linear.h new file mode 100644 index 000000000..1c0ac51a4 --- /dev/null +++ b/src/infiniop/ops/linear/linear.h @@ -0,0 +1,54 @@ +#ifndef __LINEAR_H__ +#define __LINEAR_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::linear::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + LinearInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(Opaque *opaque, LinearInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), _info(info), _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t minWorkspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t d_desc, \ + infiniopTensorDescriptor_t c_desc, \ + infiniopTensorDescriptor_t bias_desc, \ + infiniopTensorDescriptor_t x_desc, \ + infiniopTensorDescriptor_t x_scale_desc, \ + infiniopTensorDescriptor_t x_zero_desc, \ + infiniopTensorDescriptor_t weights_desc, \ + infiniopTensorDescriptor_t weights_scale_desc, \ + infiniopTensorDescriptor_t weights_zero_desc, \ + float alpha, \ + float beta); \ + template \ + infiniStatus_t launchKernel(const LinearInfo &info, Tdata *y, \ + const Tdata *c, const Tdata *bias, const int8_t *x_packed, \ + const Tdata *x_scale, const Tdata *x_zero, const int8_t *w_packed, \ + const Tdata *w_scale, const Tdata *w_zero, void *stream, void *workspace) const; \ + \ + infiniStatus_t calculate( \ + void 
*workspace, size_t workspace_size, \ + void *d, const void *c, const void *bias, const void *x, \ + const void *x_scale, const void *x_zero, const void *weights, \ + const void *weights_scale, const void *weights_zero, void *stream) const; \ + }; \ + } + +#endif // __LINEAR_H__ \ No newline at end of file diff --git a/src/infiniop/ops/linear/nvidia/linear_nvidia.cu b/src/infiniop/ops/linear/nvidia/linear_nvidia.cu new file mode 100644 index 000000000..d46da86be --- /dev/null +++ b/src/infiniop/ops/linear/nvidia/linear_nvidia.cu @@ -0,0 +1,193 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "linear_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../../../reduce/cuda/reduce.cuh" +#include +#include + +#include "../cuda/kernel.cuh" + +#if defined ENABLE_NVIDIA_API +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +void int8Gemm( + const int8_t *x_packed, const int8_t *w_packed, int32_t *y_packed, + int M, int N, int K, cudaStream_t stream) { + using ElementA = int8_t; + using ElementB = int8_t; + using ElementC = int32_t; + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + // Use SIMT opclass to avoid tensor-op interleaved layout requirements + using Gemm = cutlass::gemm::device::Gemm< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementC, // accumulator type + cutlass::arch::OpClassSimt, + cutlass::arch::Sm75>; + + Gemm gemm_op; + + cutlass::gemm::GemmCoord problem_size(M, N, K); + + typename Gemm::Arguments args{ + problem_size, + {x_packed, K}, + {w_packed, N}, + {y_packed, N}, + {y_packed, N}, + {1, 0}}; + + cutlass::Status status = gemm_op.initialize(args, nullptr, stream); + if (status != cutlass::Status::kSuccess) { + printf("[CUTLASS SIMT] initialize failed: %d\n", int(status)); + return; + } + status = gemm_op(); + if (status != cutlass::Status::kSuccess) { + printf("[CUTLASS SIMT] run failed: %d\n", int(status)); + return; + } +} +#endif + +template +INFINIOP_CUDA_KERNEL post( + Tdata *y, int32_t *y_packed, const Tdata *c, const Tdata *bias, const int8_t *x_packed, const Tdata *x_scale, const Tdata *x_zero, const int8_t *w_packed, const Tdata *w_scale, const Tdata *w_zero, int M, int K, int N, float alpha, float beta) { + postKernel(y, y_packed, c, bias, x_packed, x_scale, x_zero, w_packed, w_scale, w_zero, M, K, N, alpha, beta); +} + +template +INFINIOP_CUDA_KERNEL postSym( + Tdata *y, int32_t *y_packed, const Tdata *c, const Tdata *bias, const int8_t *x_packed, const Tdata *x_scale, const int8_t *w_packed, const Tdata *w_scale, int M, int K, int N, float alpha, float beta) { + postSymKernel(y, y_packed, c, bias, x_packed, x_scale, w_packed, w_scale, M, K, N, alpha, beta); +} + +namespace op::linear::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, Descriptor **desc_ptr, + infiniopTensorDescriptor_t d_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t bias_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t x_scale_desc, + infiniopTensorDescriptor_t x_zero_desc, + infiniopTensorDescriptor_t weights_desc, + infiniopTensorDescriptor_t weights_scale_desc, + infiniopTensorDescriptor_t 
weights_zero_desc, + float alpha, + float beta) { + auto handle = reinterpret_cast(handle_); + auto info = LinearInfo::createLinearInfo(d_desc, c_desc, bias_desc, x_desc, x_scale_desc, x_zero_desc, weights_desc, weights_scale_desc, weights_zero_desc, alpha, beta); + CHECK_RESULT(info); + size_t workspace_size = c_desc->dim(0) * c_desc->dim(1) * sizeof(int32_t); + *desc_ptr = new Descriptor( + new Opaque{handle->internal()}, + info.take(), workspace_size, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t Descriptor::launchKernel(const LinearInfo &info, Tdata *y, const Tdata *c, const Tdata *bias, const int8_t *x_packed, const Tdata *x_scale, const Tdata *x_zero, const int8_t *w_packed, const Tdata *w_scale, const Tdata *w_zero, void *stream_, void *workspace) const { + cudaStream_t stream = (cudaStream_t)stream_; + int M = (int)info.M; + int K = (int)info.K; + int N = (int)info.N; + float alpha = info.alpha; + float beta = info.beta; + char *workspace_ptr = reinterpret_cast(workspace); + int32_t *y_packed = reinterpret_cast(workspace_ptr); +#if defined ENABLE_NVIDIA_API + int8Gemm(x_packed, w_packed, y_packed, M, N, K, stream); +#elif defined ENABLE_QY_API + const int32_t alpha_I = 1; + const int32_t beta_I = 0; + CHECK_STATUS(this->_opaque->internal->useCublas( + stream, + [&](cublasHandle_t handle) { + CHECK_CUBLAS(cublasGemmEx( + handle, + CUBLAS_OP_N, // A = w_packed, column-major view + CUBLAS_OP_N, // B = x_packed, column-major view + N, // m = N + M, // n = M + K, // k = K + &alpha_I, + w_packed, CUDA_R_8I, N, // lda = m = N + x_packed, CUDA_R_8I, K, // ldb = k = K + &beta_I, + y_packed, CUDA_R_32I, N, // ldc = m = N + CUBLAS_COMPUTE_32I, + CUBLAS_GEMM_DEFAULT)); + return INFINI_STATUS_SUCCESS; + })); +#endif + constexpr unsigned int BLOCK_SIZE_x = 32; + constexpr unsigned int BLOCK_SIZE_y = 32; + + int num_block_x = (N + BLOCK_SIZE_x - 1) / BLOCK_SIZE_x; + int num_block_y = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; + dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1); + dim3 grid_dim(num_block_x, num_block_y, 1); + if (x_zero == nullptr && w_zero == nullptr) { + postSym<<>>(y, y_packed, c, bias, x_packed, x_scale, w_packed, w_scale, M, K, N, alpha, beta); + } else { + post<<>>(y, y_packed, c, bias, x_packed, x_scale, x_zero, w_packed, w_scale, w_zero, M, K, N, alpha, beta); + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *d, + const void *c, + const void *bias, + const void *x, + const void *x_scale, + const void *x_zero, + const void *weights, + const void *weights_scale, + const void *weights_zero, + void *stream) const { +#define CALCULATE_LINEAR(BLOCK_SIZE, TDATA) \ + launchKernel(_info, (TDATA *)d, (const TDATA *)c, (const TDATA *)bias, (const int8_t *)x, (const TDATA *)x_scale, (const TDATA *)x_zero, (const int8_t *)weights, (const TDATA *)weights_scale, (const TDATA *)weights_zero, stream, workspace) +#define CALCULATE_LINEAR_WITH_BLOCK_SIZE(BLOCK_SIZE) \ + { \ + if (_info.dtype == INFINI_DTYPE_F16) \ + return CALCULATE_LINEAR(BLOCK_SIZE, half); \ + else if (_info.dtype == INFINI_DTYPE_F32) \ + return CALCULATE_LINEAR(BLOCK_SIZE, float); \ + else if (_info.dtype == INFINI_DTYPE_BF16) \ + return CALCULATE_LINEAR(BLOCK_SIZE, __nv_bfloat16); \ + else \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; \ + } + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024) + } else if 
(_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512) + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096) + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::linear::nvidia diff --git a/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh b/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh new file mode 100644 index 000000000..fdc3ddf64 --- /dev/null +++ b/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh @@ -0,0 +1,7 @@ +#ifndef __LINEAR_NVIDIA_API_H__ +#define __LINEAR_NVIDIA_API_H__ +#include "../linear.h" + +DESCRIPTOR(nvidia) + +#endif // __LINEAR_NVIDIA_API_H__ diff --git a/src/infiniop/ops/linear/operator.cc b/src/infiniop/ops/linear/operator.cc new file mode 100644 index 000000000..c069c3bd5 --- /dev/null +++ b/src/infiniop/ops/linear/operator.cc @@ -0,0 +1,117 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/linear.h" + +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) +#include "nvidia/linear_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateLinearDescriptor(infiniopHandle_t handle, + infiniopLinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t bias_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t x_scale_desc, + infiniopTensorDescriptor_t x_zero_desc, + infiniopTensorDescriptor_t weights_desc, + infiniopTensorDescriptor_t weights_scale_desc, + infiniopTensorDescriptor_t weights_zero_desc, + float alpha, + float beta) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::linear::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + d_desc, \ + c_desc, \ + bias_desc, \ + x_desc, \ + x_scale_desc, \ + x_zero_desc, \ + weights_desc, \ + weights_scale_desc, \ + weights_zero_desc, \ + alpha, \ + beta); + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size) { + switch (desc->device_type) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->minWorkspaceSize(); \ + return INFINI_STATUS_SUCCESS; +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__C infiniStatus_t infiniopLinear(infiniopLinearDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *d, + const void *c, + const void *bias, + const void *x, + const void *x_scale, + const void *x_zero, + const void *weights, + const void *weights_scale, + const void *weights_zero, + void *stream) { +#define CACULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, d, c, bias, x, x_scale, x_zero, weights, weights_scale, weights_zero, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CACULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_QY_API + CACULATE(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef 
CACULATE +} + +__C infiniStatus_t infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/src/infiniop/ops/quant/cuda/kernel.cuh b/src/infiniop/ops/quant/cuda/kernel.cuh new file mode 100644 index 000000000..e59ba4d0c --- /dev/null +++ b/src/infiniop/ops/quant/cuda/kernel.cuh @@ -0,0 +1,277 @@ +#ifndef __QUANT_KERNEL_CUH__ +#define __QUANT_KERNEL_CUH__ + +#include +__device__ inline int round_half_away_from_zero(float x) { + float ax = fabsf(x); + float r = floorf(ax + 0.5f); + return (x >= 0.0f) ? (int)r : -(int)r; +} + +template +__device__ void blockQuantKernel( + int8_t *x_packed, Tdata *x_scale, Tdata *x_zero, const Tdata *x, + int M, int K) { + int row = blockIdx.x; + int tid = row * K; + + // ---- 1. reduce max ---- + float local_max = op::common_cuda::reduce_op::max( + x + tid, K); + + __shared__ float global_max_f; + if (threadIdx.x == 0) { + global_max_f = local_max; + } + __syncthreads(); + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // ---- 2. reduce min ---- + float thread_min = __FLT_MAX__; + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + thread_min = fminf(thread_min, (float)x[tid + ind]); + } + float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min()); + + __shared__ float global_min_f; + if (threadIdx.x == 0) { + global_min_f = local_min; + } + __syncthreads(); + + // ---- 3. 使用 float(匹配 python)计算 scale/zero ---- + float global_max = global_max_f; + float global_min = global_min_f; + + float scale = (global_max - global_min) / 255.0f; + if (scale < 1e-8f) { + scale = 1e-8f; + } + + float inv_scale = 1.0f / scale; + float zero = -global_min * inv_scale - 128.0f; + + // 写回 scale, zero + x_scale[row] = (Tdata)scale; + x_zero[row] = (Tdata)zero; + + // ---- 4. 使用 float + half-away-from-zero(与 Python 完全一致)---- + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + + float v = (float)x[tid + ind]; + float qf = v * inv_scale + zero; + + int q = round_half_away_from_zero(qf); + + if (q > 127) { + q = 127; + } + if (q < -128) { + q = -128; + } + + x_packed[tid + ind] = (int8_t)q; + } +} + +template +__device__ void blockQuantSymKernel( + int8_t *x_packed, Tdata *x_scale, const Tdata *x, + int M, int K) { + int row = blockIdx.x; + int tid = row * K; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + // ---- 2. reduce min ---- + float thread_max = -__FLT_MAX__; + for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) { + thread_max = fmaxf(thread_max, fabs((float)x[tid + ind])); + } + float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max()); + + __shared__ float global_max_f; + if (threadIdx.x == 0) { + global_max_f = local_max; + } + __syncthreads(); + + // ---- 3. 使用 float(匹配 python)计算 scale/zero ---- + float global_max = global_max_f; + + float scale = global_max / 127.0f; + if (scale < 1e-8f) { + scale = 1e-8f; + } + + float inv_scale = 1.0f / scale; + + // 写回 scale, zero + x_scale[row] = (Tdata)scale; + + // ---- 4. 
quantize with float + half-away-from-zero rounding (matches the Python reference exactly) ----
+    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
+
+        float v = (float)x[tid + ind];
+        float qf = v * inv_scale;
+
+        int q = round_half_away_from_zero(qf);
+
+        if (q > 127) {
+            q = 127;
+        }
+        if (q < -128) {
+            q = -128;
+        }
+
+        x_packed[tid + ind] = (int8_t)q;
+    }
+}
+
+template <typename T>
+struct MaxOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return max(a, b);
+    }
+};
+template <typename T>
+struct MinOp {
+    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
+        return min(a, b);
+    }
+};
+template