diff --git a/csrc/activation.cu b/csrc/activation.cu
index 0e9cef7c41..d03f0a29b2 100644
--- a/csrc/activation.cu
+++ b/csrc/activation.cu
@@ -40,7 +40,7 @@ void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
   auto stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
-    uint32_t vec_size = 16 / sizeof(c_type);
+    uint32_t vec_size = get_vec_size_128b<c_type>();
     cudaLaunchConfig_t config;
     config.gridDim = num_tokens;
     config.blockDim = std::min(d / vec_size, 1024U);
@@ -72,7 +72,7 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
   auto stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
-    uint32_t vec_size = 16 / sizeof(c_type);
+    uint32_t vec_size = get_vec_size_128b<c_type>();
     cudaLaunchConfig_t config;
     config.gridDim = num_tokens;
     config.blockDim = std::min(d / vec_size, 1024U);
@@ -103,7 +103,7 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
   auto stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
-    uint32_t vec_size = 16 / sizeof(c_type);
+    uint32_t vec_size = get_vec_size_128b<c_type>();
     cudaLaunchConfig_t config;
     config.gridDim = num_tokens;
     config.blockDim = std::min(d / vec_size, 1024U);
diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py
index 4d78616e5c..3c2a125fbd 100644
--- a/flashinfer/jit/activation.py
+++ b/flashinfer/jit/activation.py
@@ -41,7 +41,7 @@ const c10::cuda::OptionalCUDAGuard device_guard(out.device());
   auto stream = at::cuda::getCurrentCUDAStream();
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
-    uint32_t vec_size = 16 / sizeof(c_type);
+    uint32_t vec_size = get_vec_size_128b<c_type>();
     cudaLaunchConfig_t config;
     config.gridDim = num_tokens;
     config.blockDim = std::min(d / vec_size, 1024U);
diff --git a/include/flashinfer/activation.cuh b/include/flashinfer/activation.cuh
index 6e9f029923..3ebe21448b 100644
--- a/include/flashinfer/activation.cuh
+++ b/include/flashinfer/activation.cuh
@@ -27,7 +27,7 @@ namespace activation {
 template <typename T, float (*Activation)(const float&)>
 __global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
-  constexpr uint32_t vec_size = 16 / sizeof(T);
+  constexpr uint32_t vec_size = get_vec_size_128b<T>();
   const int64_t token_idx = blockIdx.x;
   const int64_t thread_idx = threadIdx.x;
   const int64_t stride = blockDim.x;
diff --git a/include/flashinfer/cp_async.cuh b/include/flashinfer/cp_async.cuh
index bd59cc58e3..50e0226f19 100644
--- a/include/flashinfer/cp_async.cuh
+++ b/include/flashinfer/cp_async.cuh
@@ -154,7 +154,7 @@ __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
     load_128b(smem_ptr, gmem_ptr);
   } else {
     load_128b(smem_ptr, gmem_ptr);
-    load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T));
+    load_128b(smem_ptr + get_vec_size_128b<T>(), gmem_ptr + get_vec_size_128b<T>());
   }
 }
@@ -177,8 +177,8 @@ __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool p
     pred_load_128b(smem_ptr, gmem_ptr, predicate);
   } else {
     pred_load_128b(smem_ptr, gmem_ptr, predicate);
-    pred_load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T),
-                   predicate);
+    pred_load_128b(smem_ptr + get_vec_size_128b<T>(),
+                   gmem_ptr + get_vec_size_128b<T>(), predicate);
   }
 }
diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh
index f2c91138b3..9ba1c542c9 100644
--- a/include/flashinfer/norm.cuh
+++ b/include/flashinfer/norm.cuh
@@ -108,7 +108,7 @@ template <typename T>
 cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                     uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
                     bool enable_pdl = false, cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
@@ -236,7 +236,7 @@ template <typename T>
 cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                             uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
                             bool enable_pdl = false, cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
@@ -273,7 +273,7 @@ template <typename T>
 cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
                          uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
                          bool enable_pdl = false, cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
@@ -308,7 +308,7 @@ template <typename T>
 cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
                                  uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
                                  bool enable_pdl = false, cudaStream_t stream = 0) {
-  const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
+  const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);
   const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
   const uint32_t num_warps = ceil_div(block_size, 32);
diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh
index 5b26d7beaf..347874c5f3 100644
--- a/include/flashinfer/utils.cuh
+++ b/include/flashinfer/utils.cuh
@@ -15,12 +15,17 @@
  */
 #ifndef FLASHINFER_UTILS_CUH_
 #define FLASHINFER_UTILS_CUH_
+#include <cuda.h>
 #include
 #include
 #include
 #include
 #include
+#if CUDA_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif
+
 #include
 #include
 #include
@@ -289,6 +294,26 @@ inline std::pair<int, int> GetCudaComputeCapability() {
   return std::make_pair(major, minor);
 }
 
+/*!
+ * \brief Calculate the vector size for 128-bit alignment given data type T.
+ * \tparam T The data type
+ * \return The vector size (number of elements of type T that make up 128 bits)
+ * \note For most types, this is 16 / sizeof(T). Special case: __nv_fp4_e2m1
+ * is a subbyte type padded to 1 byte, so we return 32 (128 bits / 4 bits per element).
+ */
+template <typename T>
+__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() {
+#if CUDA_VERSION >= 12080
+  if constexpr (std::is_same_v<T, __nv_fp4_e2m1>) {
+    return 32;  // 128 bits / 4 bits per element = 32 elements
+  } else {
+    return 16 / sizeof(T);
+  }
+#else
+  return 16 / sizeof(T);
+#endif
+}
+
 template <typename T>
 inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") {
   std::vector<T> host_array(size);
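
Reviewer sketch, not part of the patch: the standalone CUDA translation unit below mirrors the new get_vec_size_128b helper from include/flashinfer/utils.cuh and spot-checks the per-128-bit element counts it yields, plus the block-size math used in csrc/activation.cu. The file name, the sample value of d, and the use of <cuda_fp8.h> are illustrative assumptions; the __nv_fp4_e2m1 branch assumes CUDA 12.8+ and nvcc with -std=c++17, matching the #if guard in the patch.

// vec_size_sketch.cu (hypothetical file, compile with: nvcc -std=c++17 vec_size_sketch.cu)
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <type_traits>

// Mirror of the helper added to include/flashinfer/utils.cuh in this patch.
template <typename T>
__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() {
#if CUDA_VERSION >= 12080
  if constexpr (std::is_same_v<T, __nv_fp4_e2m1>) {
    return 32;  // sub-byte type: 128 bits / 4 bits per element
  } else {
    return 16 / sizeof(T);
  }
#else
  return 16 / sizeof(T);
#endif
}

// A 128-bit vector holds 8 halves, 8 bfloat16s, 4 floats, 16 fp8 values ...
static_assert(get_vec_size_128b<half>() == 8);
static_assert(get_vec_size_128b<__nv_bfloat16>() == 8);
static_assert(get_vec_size_128b<float>() == 4);
static_assert(get_vec_size_128b<__nv_fp8_e4m3>() == 16);
#if CUDA_VERSION >= 12080
// ... and 32 fp4 values. The plain 16 / sizeof(__nv_fp4_e2m1) would give 16,
// because the fp4 storage type is padded to one byte; hence the special case.
static_assert(get_vec_size_128b<__nv_fp4_e2m1>() == 32);
#endif

int main() {
  // Launch-shape math analogous to csrc/activation.cu: d is the per-token
  // width handled by one thread block, capped at 1024 threads.
  const uint32_t d = 4096;  // illustrative value
  const uint32_t vec_size = get_vec_size_128b<half>();
  const uint32_t block_dim = std::min<uint32_t>(d / vec_size, 1024U);
  std::printf("vec_size=%u block_dim=%u\n", (unsigned)vec_size, (unsigned)block_dim);
  return 0;
}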