diff --git a/README.md b/README.md index 653657bb7..190884507 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] | `--sugon-dcu=[y\|n]` | 是否编译曙光 DCU 接口实现 | n | `--kunlun-xpu=[y\|n]` | 是否编译昆仑 XPU 接口实现 | n | `--ccl=[y\|n]` | 是否编译 InfiniCCL 通信库接口实现 | n - +| `--prebuild=[y\|n]` | 是否提前编译cub包,以减少编译耗时 | n ### 手动安装 1. 项目配置 @@ -75,8 +75,9 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] # 英伟达 # 可以指定 CUDA 路径, 一般环境变量为 `CUDA_HOME` 或者 `CUDA_ROOT` # window系统:--cuda="%CUDA_HOME%" - # linux系统:--cuda=$CUDA_HOME - xmake f --nv-gpu=true --cuda=$CUDA_HOME -cv + # linux系统 :--cuda=$CUDA_HOME + # 如果提前编译cub包,要将 prebuild 设置为 true + xmake f --nv-gpu=true --cuda=$CUDA_HOME --prebuild=false -cv # 寒武纪 xmake f --cambricon-mlu=true -cv diff --git a/src/infinicub/include/cub_algorithms_nvidia.cuh b/src/infinicub/include/cub_algorithms_nvidia.cuh new file mode 100644 index 000000000..7dabe01f3 --- /dev/null +++ b/src/infinicub/include/cub_algorithms_nvidia.cuh @@ -0,0 +1,462 @@ +#ifndef __CUB_ALGORITHMS_NVIDIA_CUH__ +#define __CUB_ALGORITHMS_NVIDIA_CUH__ + +#ifdef ENABLE_CUDA_API + +#include +#include +#include +#include + +namespace infini_cub::nvidai::DeviceReduce { + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const half *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const float *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const double *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); +} // namespace infini_cub::nvidai::DeviceReduce + +namespace infini_cub::nvidai::DeviceReduce { + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + __nv_bfloat16 *data, + int n, + cudaStream_t stream); + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + half *data, + int n, + cudaStream_t stream); + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + float *data, + int n, + cudaStream_t stream); + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + double *data, + int n, + cudaStream_t stream); +} // namespace infini_cub::nvidai::DeviceReduce + +namespace infini_cub::nvidai::DeviceRadixSort { +using uchar = uint8_t; +using ushort = uint16_t; +using uint = uint32_t; +using ulong = uint64_t; + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); +#endif + +} // namespace infini_cub::nvidai::DeviceRadixSort +#endif // ENABLE_CUDA_API +#endif // __CUB_ALGORITHMS_NVIDIA_CUH__ diff --git a/src/infinicub/src/cub_algorithms_nvidia.cu b/src/infinicub/src/cub_algorithms_nvidia.cu new file mode 100644 index 000000000..da96b6507 --- /dev/null +++ b/src/infinicub/src/cub_algorithms_nvidia.cu @@ -0,0 +1,759 @@ +#ifdef ENABLE_CUDA_API + +#include "cub_algorithms_nvidia.cuh" +#include +#include +#include +#include + +namespace infini_cub::nvidai::DeviceReduce { + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const half *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const float *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const double *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} +} // namespace infini_cub::nvidai::DeviceReduce + +namespace infini_cub::nvidai::DeviceReduce { + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + __nv_bfloat16 *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + half *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + float *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + double *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} +} // namespace infini_cub::nvidai::DeviceReduce + +namespace infini_cub::nvidai::DeviceRadixSort { + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} +#endif + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending(workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +#ifdef _WIN64 +cudaError SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending(workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} +#endif + +} // namespace infini_cub::nvidai::DeviceRadixSort + +#endif // ENABLE_CUDA_API \ No newline at end of file diff --git a/src/infinicub/xmake.lua b/src/infinicub/xmake.lua new file mode 100644 index 000000000..932afdc7c --- /dev/null +++ b/src/infinicub/xmake.lua @@ -0,0 +1,43 @@ +local CUDA_ROOT = os.getenv("CUDA_ROOT") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH") +local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH") +if CUDA_ROOT ~= nil then + add_includedirs(CUDA_ROOT .. "/include") +end +if CUDNN_ROOT ~= nil then + add_includedirs(CUDNN_ROOT .. "/include") +end + +if is_plat("windows") then + set_runtimes("MD") +end + +-- 静态库目标 +target("infinicub") + set_kind("static") + set_policy("build.cuda.devlink", true) + set_toolchains("cuda") + add_links("cublas", "cudnn") + add_cugencodes("native") + + add_defines("ENABLE_CUDA_API") + + if is_plat("windows") then + add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") + add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX") + add_cxxflags("/FS") + if CUDNN_ROOT ~= nil then + add_linkdirs(CUDNN_ROOT .. "\\lib\\x64") + end + else + add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror") + add_cuflags("-Xcompiler=-fPIC") + add_cuflags("--extended-lambda") + add_culdflags("-Xcompiler=-fPIC") + add_cxxflags("-fPIC") + end + + set_languages("cxx17") + add_includedirs("include") + + add_files("src/*.cu") +target_end() diff --git a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh index a0a349c94..ae013d57f 100644 --- a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh +++ b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh @@ -1,9 +1,13 @@ #include "../../../devices/cuda/cuda_kernel_common.cuh" #include "infinicore.h" + +#ifdef ENABLE_INFINI_CUB +#include "../../../../infinicub/include/cub_algorithms_nvidia.cuh" +#else #include #include #include - +#endif namespace op::random_sample::cuda { // ↓↓↓ 重新封装 cub api,减少模板参数,方便调用 @@ -16,10 +20,18 @@ static cudaError argMax_( void *workspace_ptr, size_t &workspace_len, cudaStream_t stream) { + +#ifdef ENABLE_INFINI_CUB + return infini_cub::nvidai::DeviceReduce::ArgMax( + workspace_ptr, workspace_len, + logits, kv_pair, n, + stream); +#else return cub::DeviceReduce::ArgMax( workspace_ptr, workspace_len, logits, kv_pair, n, stream); +#endif } template @@ -29,6 +41,15 @@ static cudaError radixSort( const Tidx *val_in, Tidx *val_out, int n, cudaStream_t stream) { + +#ifdef ENABLE_INFINI_CUB + return infini_cub::nvidai::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + stream); +#else return cub::DeviceRadixSort::SortPairsDescending( workspace_ptr, workspace_len, key_in, key_out, @@ -36,6 +57,7 @@ static cudaError radixSort( n, 0, sizeof(Tval) * 8, stream); +#endif } template @@ -43,10 +65,18 @@ static cudaError inclusiveSum( void *workspace_ptr, size_t &workspace_len, T *data, int n, cudaStream_t stream) { + +#ifdef ENABLE_INFINI_CUB + return infini_cub::nvidai::DeviceReduce::InclusiveSum( + workspace_ptr, workspace_len, + data, n, + stream); +#else return cub::DeviceScan::InclusiveSum( workspace_ptr, workspace_len, data, data, n, stream); +#endif } // ↑↑↑ 重新封装 cub api,减少模板参数,方便调用 diff --git a/xmake.lua b/xmake.lua index 96002818b..5bb46cc94 100644 --- a/xmake.lua +++ b/xmake.lua @@ -18,6 +18,13 @@ if is_plat("windows") then add_cxflags("/utf-8", {force = true}) end +-- infinicub +option("prebuild") + set_default(false) + set_showmenu(true) + set_description("Enable or disable cub package") +option_end() + -- CPU option("cpu") set_default(true) diff --git a/xmake/cub.lua b/xmake/cub.lua new file mode 100644 index 000000000..914c923c4 --- /dev/null +++ b/xmake/cub.lua @@ -0,0 +1,20 @@ +package("infinicub") + set_description("Build infinicub library.") + + add_versions("1.0.0", "commit-hash-or-sha256-for-v1.0.0") + add_versions("1.0.1", "commit-hash-or-sha256-for-v1.0.1") + add_versions("1.0.2", "commit-hash-or-sha256-for-v1.0.2") + add_versions("1.0.3", "commit-hash-or-sha256-for-v1.0.3") + add_versions("1.0.4", "commit-hash-or-sha256-for-v1.0.4") + + set_sourcedir(path.join(os.scriptdir(), "../src/infinicub")) + + local dir = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "USERPROFILE") .. "/.infini") + set_installdir(path.join(dir,"packages/infinicub/", get_config("plat"), get_config("arch"), get_config("mode"))) + + add_configs("shared", {default = false, type = "boolean", readonly = true}) + on_install(function (package) + local configs = {} + import("package.tools.xmake").install(package, configs) + end) +package_end() diff --git a/xmake/cuda.lua b/xmake/cuda.lua index b9369fa7b..6d0d41664 100644 --- a/xmake/cuda.lua +++ b/xmake/cuda.lua @@ -8,6 +8,9 @@ if CUDNN_ROOT ~= nil then add_includedirs(CUDNN_ROOT .. "/include") end + +add_requires("infinicub 1.0.0", {optional = true, configs = {shared = false, defines="ENABLE_CUDA_API" }}) + target("infiniop-cuda") set_kind("static") add_deps("infini-utils") @@ -18,6 +21,12 @@ target("infiniop-cuda") add_links("cublas", "cudnn") add_cugencodes("native") + if has_config("prebuild") and has_config("nv-gpu") then + add_defines("ENABLE_INFINI_CUB") + includes("cub.lua") + add_packages("infinicub") + end + if is_plat("windows") then add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX")