6 changes: 3 additions & 3 deletions csrc/activation.cu
@@ -40,7 +40,7 @@ void silu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
auto stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
uint32_t vec_size = get_vec_size_128b<c_type>();
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
@@ -72,7 +72,7 @@ void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
auto stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
uint32_t vec_size = get_vec_size_128b<c_type>();
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
@@ -103,7 +103,7 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input, bool enable_pdl) {
auto stream = at::cuda::getCurrentCUDAStream();

DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
uint32_t vec_size = get_vec_size_128b<c_type>();
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
2 changes: 1 addition & 1 deletion flashinfer/jit/activation.py
@@ -41,7 +41,7 @@
const c10::cuda::OptionalCUDAGuard device_guard(out.device());
auto stream = at::cuda::getCurrentCUDAStream();
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input.scalar_type(), c_type, [&] {
uint32_t vec_size = 16 / sizeof(c_type);
uint32_t vec_size = get_vec_size_128b<c_type>();
cudaLaunchConfig_t config;
config.gridDim = num_tokens;
config.blockDim = std::min(d / vec_size, 1024U);
2 changes: 1 addition & 1 deletion include/flashinfer/activation.cuh
@@ -27,7 +27,7 @@ namespace activation {

template <typename T, float (*Activation)(const float&)>
__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) {
constexpr uint32_t vec_size = 16 / sizeof(T);
constexpr uint32_t vec_size = get_vec_size_128b<T>();
const int64_t token_idx = blockIdx.x;
const int64_t thread_idx = threadIdx.x;
const int64_t stride = blockDim.x;
6 changes: 3 additions & 3 deletions include/flashinfer/cp_async.cuh
@@ -154,7 +154,7 @@ __device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
} else {
load_128b<prefetch_mode>(smem_ptr, gmem_ptr);
load_128b<prefetch_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T));
load_128b<prefetch_mode>(smem_ptr + get_vec_size_128b<T>(), gmem_ptr + get_vec_size_128b<T>());
}
}

@@ -177,8 +177,8 @@ __device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool p
pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
} else {
pred_load_128b<prefetch_mode, fill_mode>(smem_ptr, gmem_ptr, predicate);
pred_load_128b<prefetch_mode, fill_mode>(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T),
predicate);
pred_load_128b<prefetch_mode, fill_mode>(smem_ptr + get_vec_size_128b<T>(),
gmem_ptr + get_vec_size_128b<T>(), predicate);
}
}

8 changes: 4 additions & 4 deletions include/flashinfer/norm.cuh
@@ -108,7 +108,7 @@ template <typename T>
cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
@@ -236,7 +236,7 @@ template <typename T>
cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
@@ -273,7 +273,7 @@ template <typename T>
cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d,
uint32_t stride_input, uint32_t stride_output, float eps = 1e-5,
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
@@ -308,7 +308,7 @@ template <typename T>
cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d,
uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5,
bool enable_pdl = false, cudaStream_t stream = 0) {
const uint32_t vec_size = std::gcd(16 / sizeof(T), d);
const uint32_t vec_size = std::gcd(get_vec_size_128b<T>(), d);

const uint32_t block_size = std::min<uint32_t>(1024, d / vec_size);
const uint32_t num_warps = ceil_div(block_size, 32);
25 changes: 25 additions & 0 deletions include/flashinfer/utils.cuh
@@ -15,12 +15,17 @@
*/
#ifndef FLASHINFER_UTILS_CUH_
#define FLASHINFER_UTILS_CUH_
#include <cuda.h>
Contributor comment (medium):

The <cuda.h> header seems redundant here. The CUDA_VERSION macro, which is used for conditional compilation, is also available through <cuda_runtime.h>, which is already included. Removing this unnecessary include helps in keeping file dependencies clean and minimal.
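Whether <cuda_runtime.h> by itself exposes CUDA_VERSION can vary across toolkit versions, so it is worth verifying before dropping the include. A minimal standalone probe (hypothetical, not part of this PR) settles it for a given installation:

// probe.cu -- checks which header defines CUDA_VERSION on this toolkit
#include <cuda_runtime.h>

#ifdef CUDA_VERSION
#pragma message("CUDA_VERSION is already visible via <cuda_runtime.h>")
#else
#pragma message("CUDA_VERSION still requires <cuda.h>")
#endif

int main() { return 0; }  // nothing to run; the answer is printed at compile time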

#include <cuda_bf16.h>
#include <cuda_device_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>

#if CUDA_VERSION >= 12080
#include <cuda_fp4.h>
#endif

#include <cstdint>
#include <iostream>
#include <type_traits>
@@ -289,6 +294,26 @@ inline std::pair<int, int> GetCudaComputeCapability() {
return std::make_pair(major, minor);
}

/*!
* \brief Calculate the vector size for 128-bit alignment given data type T.
* \tparam T The data type
* \return The vector size (number of elements of type T that make up 128 bits)
* \note For most types, this is 16 / sizeof(T). Special case: __nv_fp4_e2m1
* is a subbyte type padded to 1 byte, so we return 32 (128 bits / 4 bits per element).
*/
template <typename T>
__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() {
#if CUDA_VERSION >= 12080
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>) {
return 32; // 128 bits / 4 bits per element = 32 elements
} else {
return 16 / sizeof(T);
}
#else
return 16 / sizeof(T);
#endif
}
Contributor comment on lines +304 to +315 (high):

This function can be simplified to improve readability and reduce code duplication. Additionally, it currently only handles the __nv_fp4_e2m1 type for FP4. The other FP4 type, __nv_fp4_e2m0, should also be handled to make the utility more robust.

Here's a suggested refactoring that addresses both points:

template <typename T>
__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() {
#if CUDA_VERSION >= 12080
  if constexpr (std::is_same_v<T, __nv_fp4_e2m1> || std::is_same_v<T, __nv_fp4_e2m0>) {
    return 32;  // 128 bits / 4 bits per element = 32 elements
  }
#endif
  return 16 / sizeof(T);
}

The if constexpr with a return inside makes the else branch unnecessary. The default case return 16 / sizeof(T); can be shared for both CUDA versions and for types other than FP4.

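Either form of the helper yields the same element counts, so the mapping described in the doc comment can be pinned down with compile-time checks. A minimal sketch follows (a hypothetical standalone translation unit, assuming the helper lives in namespace flashinfer alongside the other utilities in utils.cuh and that the include path points at include/):

// vec_size_check.cu -- compile-time sanity checks for get_vec_size_128b
#include <flashinfer/utils.cuh>

static_assert(flashinfer::get_vec_size_128b<float>() == 4, "128 bits / 32-bit elements");
static_assert(flashinfer::get_vec_size_128b<__half>() == 8, "128 bits / 16-bit elements");
static_assert(flashinfer::get_vec_size_128b<__nv_bfloat16>() == 8, "128 bits / 16-bit elements");
static_assert(flashinfer::get_vec_size_128b<__nv_fp8_e4m3>() == 16, "128 bits / 8-bit elements");
#if CUDA_VERSION >= 12080
static_assert(flashinfer::get_vec_size_128b<__nv_fp4_e2m1>() == 32, "128 bits / 4-bit elements");
#endif

int main() { return 0; }  // nothing runs; the checks fail the build if the mapping changes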


template <typename T>
inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") {
std::vector<T> host_array(size);