-
Notifications
You must be signed in to change notification settings - Fork 531
misc: fix vector size calculation for fp4 #1702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,12 +15,17 @@ | |
*/ | ||
#ifndef FLASHINFER_UTILS_CUH_ | ||
#define FLASHINFER_UTILS_CUH_ | ||
#include <cuda.h> | ||
#include <cuda_bf16.h> | ||
#include <cuda_device_runtime_api.h> | ||
#include <cuda_fp16.h> | ||
#include <cuda_fp8.h> | ||
#include <cuda_runtime.h> | ||
|
||
#if CUDA_VERSION >= 12080 | ||
#include <cuda_fp4.h> | ||
#endif | ||
|
||
#include <cstdint> | ||
#include <iostream> | ||
#include <type_traits> | ||
|
@@ -289,6 +294,26 @@ inline std::pair<int, int> GetCudaComputeCapability() { | |
return std::make_pair(major, minor); | ||
} | ||
|
||
/*! | ||
* \brief Calculate the vector size for 128-bit alignment given data type T. | ||
* \tparam T The data type | ||
* \return The vector size (number of elements of type T that make up 128 bits) | ||
* \note For most types, this is 16 / sizeof(T). Special case: __nv_fp4_e2m1 | ||
* is a subbyte type padded to 1 byte, so we return 32 (128 bits / 4 bits per element). | ||
*/ | ||
template <typename T> | ||
__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() { | ||
#if CUDA_VERSION >= 12080 | ||
if constexpr (std::is_same_v<T, __nv_fp4_e2m1>) { | ||
return 32; // 128 bits / 4 bits per element = 32 elements | ||
} else { | ||
return 16 / sizeof(T); | ||
} | ||
#else | ||
return 16 / sizeof(T); | ||
#endif | ||
} | ||
Comment on lines
+304
to
+315
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function can be simplified to improve readability and reduce code duplication. Additionally, it currently only handles the Here's a suggested refactoring that addresses both points: template <typename T>
__host__ __device__ __forceinline__ constexpr size_t get_vec_size_128b() {
#if CUDA_VERSION >= 12080
if constexpr (std::is_same_v<T, __nv_fp4_e2m1> || std::is_same_v<T, __nv_fp4_e2m0>) {
return 32; // 128 bits / 4 bits per element = 32 elements
}
#endif
return 16 / sizeof(T);
} The
|
||
|
||
template <typename T> | ||
inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { | ||
std::vector<T> host_array(size); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
<cuda.h>
header seems redundant here. TheCUDA_VERSION
macro, which is used for conditional compilation, is also available through<cuda_runtime.h>
, which is already included. Removing this unnecessary include helps in keeping file dependencies clean and minimal.