From ee7a40ac06037eda1b51a2b08d1280a8e6ceca0c Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Wed, 9 Jul 2025 00:53:43 +0800 Subject: [PATCH 1/3] issue/306 - reduce compile time by intruducing custom packages. --- README.md | 7 +- src/infinicub/include/cub_algorithms.cuh | 453 +++++++++++ src/infinicub/src/cub_algorithms.cu | 748 ++++++++++++++++++ src/infinicub/xmake.lua | 41 + .../cuda/random_sample_kernel.cuh | 34 +- xmake.lua | 7 + xmake/cub.lua | 17 + xmake/cuda.lua | 9 + 8 files changed, 1311 insertions(+), 5 deletions(-) create mode 100644 src/infinicub/include/cub_algorithms.cuh create mode 100644 src/infinicub/src/cub_algorithms.cu create mode 100644 src/infinicub/xmake.lua create mode 100644 xmake/cub.lua diff --git a/README.md b/README.md index 653657bb7..190884507 100644 --- a/README.md +++ b/README.md @@ -51,7 +51,7 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] | `--sugon-dcu=[y\|n]` | 是否编译曙光 DCU 接口实现 | n | `--kunlun-xpu=[y\|n]` | 是否编译昆仑 XPU 接口实现 | n | `--ccl=[y\|n]` | 是否编译 InfiniCCL 通信库接口实现 | n - +| `--prebuild=[y\|n]` | 是否提前编译cub包,以减少编译耗时 | n ### 手动安装 1. 项目配置 @@ -75,8 +75,9 @@ python scripts/install.py [XMAKE_CONFIG_FLAGS] # 英伟达 # 可以指定 CUDA 路径, 一般环境变量为 `CUDA_HOME` 或者 `CUDA_ROOT` # window系统:--cuda="%CUDA_HOME%" - # linux系统:--cuda=$CUDA_HOME - xmake f --nv-gpu=true --cuda=$CUDA_HOME -cv + # linux系统 :--cuda=$CUDA_HOME + # 如果提前编译cub包,要将 prebuild 设置为 true + xmake f --nv-gpu=true --cuda=$CUDA_HOME --prebuild=false -cv # 寒武纪 xmake f --cambricon-mlu=true -cv diff --git a/src/infinicub/include/cub_algorithms.cuh b/src/infinicub/include/cub_algorithms.cuh new file mode 100644 index 000000000..c2280edbc --- /dev/null +++ b/src/infinicub/include/cub_algorithms.cuh @@ -0,0 +1,453 @@ +#ifndef __CUB_ALGORITHMS_CUH__ +#define __CUB_ALGORITHMS_CUH__ + +#include +#include +#include +#include + +namespace infini_cub { + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const half *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const float *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const double *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream); + +} // namespace infini_cub + +namespace infini_cub { + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + __nv_bfloat16 *data, + int n, + cudaStream_t stream); + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + half *data, + int n, + cudaStream_t stream); + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + float *data, + int n, + cudaStream_t stream); + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + double *data, + int n, + cudaStream_t stream); +} // namespace infini_cub + +namespace infini_cub { +using uchar = uint8_t; +using ushort = uint16_t; +using uint = uint32_t; +using ulong = uint64_t; + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream); + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream); + + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream); + +} // namespace infini_cub + +#endif // __CUB_ALGORITHMS_CUH__ diff --git a/src/infinicub/src/cub_algorithms.cu b/src/infinicub/src/cub_algorithms.cu new file mode 100644 index 000000000..c61ac14c9 --- /dev/null +++ b/src/infinicub/src/cub_algorithms.cu @@ -0,0 +1,748 @@ +#include "cub_algorithms.cuh" +#include +#include +#include +#include + +namespace infini_cub { + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const half *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const float *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} + +cudaError cub_DeviceReduce_ArgMax( + void *workspace_ptr, + size_t &workspace_len, + const double *logits, + cub::KeyValuePair *kv_pair, + int n, + cudaStream_t stream) { + return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); +} +} // namespace infini_cub + +namespace infini_cub { + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + __nv_bfloat16 *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + half *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + float *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} + +cudaError cub_DeviceScan_InclusiveSum( + void *workspace_ptr, + size_t &workspace_len, + double *data, + int n, + cudaStream_t stream) { + return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); +} +} // namespace infini_cub + +namespace infini_cub { +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const __nv_bfloat16 *key_in, + __nv_bfloat16 *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const half *key_in, + half *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(half) * 8, + stream); +} + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const float *key_in, + float *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(float) * 8, + stream); +} + +// -------------------------------------------------------------- +// -------------------------------------------------------------- +// -------------------------------------------------------------- +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uchar *val_in, + uchar *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int8_t *val_in, + int8_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ushort *val_in, + ushort *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const short *val_in, + short *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const uint *val_in, + uint *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const int *val_in, + int *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ulong *val_in, + ulong *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const long *val_in, + long *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending(workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} + +cudaError cub_DeviceRadixSort_SortPairsDescending( + void *workspace_ptr, + size_t &workspace_len, + const double *key_in, + double *key_out, + const ptrdiff_t *val_in, + ptrdiff_t *val_out, + int n, + cudaStream_t stream) { + return cub::DeviceRadixSort::SortPairsDescending(workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + 0, sizeof(double) * 8, + stream); +} +} // namespace infini_cub diff --git a/src/infinicub/xmake.lua b/src/infinicub/xmake.lua new file mode 100644 index 000000000..c7d9e5dc5 --- /dev/null +++ b/src/infinicub/xmake.lua @@ -0,0 +1,41 @@ +local CUDA_ROOT = os.getenv("CUDA_ROOT") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH") +local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH") +if CUDA_ROOT ~= nil then + add_includedirs(CUDA_ROOT .. "/include") +end +if CUDNN_ROOT ~= nil then + add_includedirs(CUDNN_ROOT .. "/include") +end + +if is_plat("windows") then + set_runtimes("MD") +end + +-- 静态库目标 +target("infinicub") + set_kind("static") + set_policy("build.cuda.devlink", true) + set_toolchains("cuda") + add_links("cublas", "cudnn") + add_cugencodes("native") + + if is_plat("windows") then + add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") + add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX") + add_cxxflags("/FS") + if CUDNN_ROOT ~= nil then + add_linkdirs(CUDNN_ROOT .. "\\lib\\x64") + end + else + add_cuflags("-Xcompiler=-Wall", "-Xcompiler=-Werror") + add_cuflags("-Xcompiler=-fPIC") + add_cuflags("--extended-lambda") + add_culdflags("-Xcompiler=-fPIC") + add_cxxflags("-fPIC") + end + + set_languages("cxx17") + add_includedirs("include") + + add_files("src/cub_algorithms.cu") +target_end() diff --git a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh index a0a349c94..1a1a62d27 100644 --- a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh +++ b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh @@ -1,9 +1,13 @@ #include "../../../devices/cuda/cuda_kernel_common.cuh" #include "infinicore.h" + +#ifdef ENABLE_INFINI_CUB +#include "../../../../infinicub/include/cub_algorithms.cuh" +#else #include #include #include - +#endif namespace op::random_sample::cuda { // ↓↓↓ 重新封装 cub api,减少模板参数,方便调用 @@ -16,10 +20,18 @@ static cudaError argMax_( void *workspace_ptr, size_t &workspace_len, cudaStream_t stream) { + +#ifdef ENABLE_INFINI_CUB + return infini_cub::cub_DeviceReduce_ArgMax( + workspace_ptr, workspace_len, + logits, kv_pair, n, + stream); +#else return cub::DeviceReduce::ArgMax( workspace_ptr, workspace_len, logits, kv_pair, n, stream); +#endif } template @@ -29,6 +41,15 @@ static cudaError radixSort( const Tidx *val_in, Tidx *val_out, int n, cudaStream_t stream) { + +#ifdef ENABLE_INFINI_CUB + return infini_cub::cub_DeviceRadixSort_SortPairsDescending( + workspace_ptr, workspace_len, + key_in, key_out, + val_in, val_out, + n, + stream); +#else return cub::DeviceRadixSort::SortPairsDescending( workspace_ptr, workspace_len, key_in, key_out, @@ -36,6 +57,7 @@ static cudaError radixSort( n, 0, sizeof(Tval) * 8, stream); +#endif } template @@ -43,10 +65,18 @@ static cudaError inclusiveSum( void *workspace_ptr, size_t &workspace_len, T *data, int n, cudaStream_t stream) { - return cub::DeviceScan::InclusiveSum( + +#ifdef ENABLE_INFINI_CUB + return infini_cub::cub_DeviceScan_InclusiveSum( workspace_ptr, workspace_len, + data, n, + stream); +#else + return cub::DeviceScan::InclusiveSum( + workspace_ptr, workspace_len, data, data, n, stream); +#endif } // ↑↑↑ 重新封装 cub api,减少模板参数,方便调用 diff --git a/xmake.lua b/xmake.lua index 96002818b..5bb46cc94 100644 --- a/xmake.lua +++ b/xmake.lua @@ -18,6 +18,13 @@ if is_plat("windows") then add_cxflags("/utf-8", {force = true}) end +-- infinicub +option("prebuild") + set_default(false) + set_showmenu(true) + set_description("Enable or disable cub package") +option_end() + -- CPU option("cpu") set_default(true) diff --git a/xmake/cub.lua b/xmake/cub.lua new file mode 100644 index 000000000..397d348bd --- /dev/null +++ b/xmake/cub.lua @@ -0,0 +1,17 @@ +package("infinicub") + set_description("Build infinicub library.") + + add_versions("1.0.0", "commit-hash-or-sha256-for-v1.0.0") + add_versions("1.0.1", "commit-hash-or-sha256-for-v1.0.1") + + set_sourcedir(path.join(os.scriptdir(), "../src/infinicub")) + + local dir = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "USERPROFILE") .. "/.infini") + set_installdir(path.join(dir,"packages/infinicub/", get_config("plat"), get_config("arch"), get_config("mode"))) + + add_configs("shared", {default = false, type = "boolean", readonly = true}) + on_install(function (package) + local configs = {} + import("package.tools.xmake").install(package, configs) + end) +package_end() diff --git a/xmake/cuda.lua b/xmake/cuda.lua index b9369fa7b..408ea1a9e 100644 --- a/xmake/cuda.lua +++ b/xmake/cuda.lua @@ -8,6 +8,9 @@ if CUDNN_ROOT ~= nil then add_includedirs(CUDNN_ROOT .. "/include") end + +add_requires("infinicub 1.0.0", {optional = true, configs = {shared = false}}) + target("infiniop-cuda") set_kind("static") add_deps("infini-utils") @@ -18,6 +21,12 @@ target("infiniop-cuda") add_links("cublas", "cudnn") add_cugencodes("native") + if has_config("prebuild") then + add_defines("ENABLE_INFINI_CUB") + includes("cub.lua") + add_packages("infinicub") + end + if is_plat("windows") then add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") add_cuflags("-Xcompiler=/W3", "-Xcompiler=/WX") From dc4e577a2c80b8112256e5fa3fec57baa1b2ec76 Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Wed, 9 Jul 2025 01:14:39 +0800 Subject: [PATCH 2/3] =?UTF-8?q?issue/306=20-=20=E5=A4=84=E7=90=86win?= =?UTF-8?q?=E5=92=8Clinux=E7=9A=84=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E4=B8=8D=E4=B8=80=E8=87=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicub/include/cub_algorithms.cuh | 13 ++++++++++--- src/infinicub/src/cub_algorithms.cu | 9 +++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/infinicub/include/cub_algorithms.cuh b/src/infinicub/include/cub_algorithms.cuh index c2280edbc..0af6845b4 100644 --- a/src/infinicub/include/cub_algorithms.cuh +++ b/src/infinicub/include/cub_algorithms.cuh @@ -159,6 +159,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -168,6 +169,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( ptrdiff_t *val_out, int n, cudaStream_t stream); +#endif // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -252,6 +254,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -261,6 +264,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( ptrdiff_t *val_out, int n, cudaStream_t stream); +#endif // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -344,6 +348,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -353,7 +358,8 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( ptrdiff_t *val_out, int n, cudaStream_t stream); - +#endif + // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -437,7 +443,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); - +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -447,7 +453,8 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( ptrdiff_t *val_out, int n, cudaStream_t stream); - +#endif + } // namespace infini_cub #endif // __CUB_ALGORITHMS_CUH__ diff --git a/src/infinicub/src/cub_algorithms.cu b/src/infinicub/src/cub_algorithms.cu index c61ac14c9..28d6bf5f2 100644 --- a/src/infinicub/src/cub_algorithms.cu +++ b/src/infinicub/src/cub_algorithms.cu @@ -235,6 +235,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -252,6 +253,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( 0, sizeof(half) * 8, stream); } +#endif // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -400,6 +402,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -417,6 +420,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( 0, sizeof(half) * 8, stream); } +#endif // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -565,6 +569,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -582,6 +587,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( 0, sizeof(float) * 8, stream); } +#endif // -------------------------------------------------------------- // -------------------------------------------------------------- @@ -729,6 +735,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } +#ifdef _WIN64 cudaError cub_DeviceRadixSort_SortPairsDescending( void *workspace_ptr, size_t &workspace_len, @@ -745,4 +752,6 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( 0, sizeof(double) * 8, stream); } +#endif + } // namespace infini_cub From 1fd22e8715df615ba47d00811ea130c2d1d09652 Mon Sep 17 00:00:00 2001 From: pengcheng888 <1033693766@qq.com> Date: Wed, 9 Jul 2025 16:21:06 +0800 Subject: [PATCH 3/3] =?UTF-8?q?issue/306=20-=20=E4=BF=AE=E6=94=B9=E4=BA=86?= =?UTF-8?q?=E5=91=BD=E5=90=8D=E7=A9=BA=E9=97=B4=E3=80=81=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=AE=8F=E3=80=81=E6=96=87=E4=BB=B6=E9=87=8D=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...gorithms.cuh => cub_algorithms_nvidia.cuh} | 112 +++++++++--------- ...algorithms.cu => cub_algorithms_nvidia.cu} | 110 ++++++++--------- src/infinicub/xmake.lua | 4 +- .../cuda/random_sample_kernel.cuh | 12 +- xmake/cub.lua | 3 + xmake/cuda.lua | 4 +- 6 files changed, 127 insertions(+), 118 deletions(-) rename src/infinicub/include/{cub_algorithms.cuh => cub_algorithms_nvidia.cuh} (78%) rename src/infinicub/src/{cub_algorithms.cu => cub_algorithms_nvidia.cu} (87%) diff --git a/src/infinicub/include/cub_algorithms.cuh b/src/infinicub/include/cub_algorithms_nvidia.cuh similarity index 78% rename from src/infinicub/include/cub_algorithms.cuh rename to src/infinicub/include/cub_algorithms_nvidia.cuh index 0af6845b4..7dabe01f3 100644 --- a/src/infinicub/include/cub_algorithms.cuh +++ b/src/infinicub/include/cub_algorithms_nvidia.cuh @@ -1,14 +1,16 @@ -#ifndef __CUB_ALGORITHMS_CUH__ -#define __CUB_ALGORITHMS_CUH__ +#ifndef __CUB_ALGORITHMS_NVIDIA_CUH__ +#define __CUB_ALGORITHMS_NVIDIA_CUH__ + +#ifdef ENABLE_CUDA_API #include #include #include #include -namespace infini_cub { +namespace infini_cub::nvidai::DeviceReduce { -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *logits, @@ -16,7 +18,7 @@ cudaError cub_DeviceReduce_ArgMax( int n, cudaStream_t stream); -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const half *logits, @@ -24,7 +26,7 @@ cudaError cub_DeviceReduce_ArgMax( int n, cudaStream_t stream); -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const float *logits, @@ -32,54 +34,53 @@ cudaError cub_DeviceReduce_ArgMax( int n, cudaStream_t stream); -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const double *logits, cub::KeyValuePair *kv_pair, int n, cudaStream_t stream); +} // namespace infini_cub::nvidai::DeviceReduce -} // namespace infini_cub - -namespace infini_cub { +namespace infini_cub::nvidai::DeviceReduce { -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, __nv_bfloat16 *data, int n, cudaStream_t stream); -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, half *data, int n, cudaStream_t stream); -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, float *data, int n, cudaStream_t stream); -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, double *data, int n, cudaStream_t stream); -} // namespace infini_cub +} // namespace infini_cub::nvidai::DeviceReduce -namespace infini_cub { +namespace infini_cub::nvidai::DeviceRadixSort { using uchar = uint8_t; using ushort = uint16_t; using uint = uint32_t; using ulong = uint64_t; -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -89,7 +90,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -99,7 +100,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -109,7 +110,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -119,7 +120,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -129,7 +130,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -139,7 +140,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -149,7 +150,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -160,7 +161,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( cudaStream_t stream); #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -174,7 +175,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -184,7 +185,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -194,7 +195,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -204,7 +205,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -214,7 +215,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -224,7 +225,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -234,7 +235,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -244,7 +245,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -255,7 +256,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( cudaStream_t stream); #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -265,10 +266,11 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); #endif + // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -278,7 +280,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -288,7 +290,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -298,7 +300,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -308,7 +310,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -318,7 +320,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -328,7 +330,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -338,7 +340,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -349,7 +351,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( cudaStream_t stream); #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -363,7 +365,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -373,7 +375,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -383,7 +385,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -393,7 +395,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -403,7 +405,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -413,7 +415,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -423,7 +425,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -433,7 +435,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( int n, cudaStream_t stream); -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -444,7 +446,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( cudaStream_t stream); #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -455,6 +457,6 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( cudaStream_t stream); #endif -} // namespace infini_cub - -#endif // __CUB_ALGORITHMS_CUH__ +} // namespace infini_cub::nvidai::DeviceRadixSort +#endif // ENABLE_CUDA_API +#endif // __CUB_ALGORITHMS_NVIDIA_CUH__ diff --git a/src/infinicub/src/cub_algorithms.cu b/src/infinicub/src/cub_algorithms_nvidia.cu similarity index 87% rename from src/infinicub/src/cub_algorithms.cu rename to src/infinicub/src/cub_algorithms_nvidia.cu index 28d6bf5f2..da96b6507 100644 --- a/src/infinicub/src/cub_algorithms.cu +++ b/src/infinicub/src/cub_algorithms_nvidia.cu @@ -1,12 +1,14 @@ -#include "cub_algorithms.cuh" +#ifdef ENABLE_CUDA_API + +#include "cub_algorithms_nvidia.cuh" #include #include #include #include -namespace infini_cub { +namespace infini_cub::nvidai::DeviceReduce { -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *logits, @@ -17,7 +19,7 @@ cudaError cub_DeviceReduce_ArgMax( return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); } -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const half *logits, @@ -27,7 +29,7 @@ cudaError cub_DeviceReduce_ArgMax( return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); } -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const float *logits, @@ -37,7 +39,7 @@ cudaError cub_DeviceReduce_ArgMax( return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); } -cudaError cub_DeviceReduce_ArgMax( +cudaError ArgMax( void *workspace_ptr, size_t &workspace_len, const double *logits, @@ -46,11 +48,11 @@ cudaError cub_DeviceReduce_ArgMax( cudaStream_t stream) { return cub::DeviceReduce::ArgMax(workspace_ptr, workspace_len, logits, kv_pair, n, stream); } -} // namespace infini_cub +} // namespace infini_cub::nvidai::DeviceReduce -namespace infini_cub { +namespace infini_cub::nvidai::DeviceReduce { -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, __nv_bfloat16 *data, @@ -59,7 +61,7 @@ cudaError cub_DeviceScan_InclusiveSum( return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); } -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, half *data, @@ -68,7 +70,7 @@ cudaError cub_DeviceScan_InclusiveSum( return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); } -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, float *data, @@ -77,7 +79,7 @@ cudaError cub_DeviceScan_InclusiveSum( return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); } -cudaError cub_DeviceScan_InclusiveSum( +cudaError InclusiveSum( void *workspace_ptr, size_t &workspace_len, double *data, @@ -85,13 +87,11 @@ cudaError cub_DeviceScan_InclusiveSum( cudaStream_t stream) { return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data, n, stream); } -} // namespace infini_cub +} // namespace infini_cub::nvidai::DeviceReduce -namespace infini_cub { -// -------------------------------------------------------------- -// -------------------------------------------------------------- -// -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +namespace infini_cub::nvidai::DeviceRadixSort { + +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -109,7 +109,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -127,7 +127,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -145,7 +145,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -163,7 +163,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -181,7 +181,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -199,7 +199,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -217,7 +217,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -236,7 +236,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( } #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const __nv_bfloat16 *key_in, @@ -258,7 +258,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -276,7 +276,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -294,7 +294,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -312,7 +312,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -330,7 +330,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -348,7 +348,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -366,7 +366,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -384,7 +384,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -403,7 +403,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( } #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const half *key_in, @@ -425,7 +425,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -443,7 +443,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -461,7 +461,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -479,7 +479,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -497,7 +497,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -515,7 +515,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -533,7 +533,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -551,7 +551,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -570,7 +570,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( } #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const float *key_in, @@ -592,7 +592,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( // -------------------------------------------------------------- // -------------------------------------------------------------- // -------------------------------------------------------------- -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -610,7 +610,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -628,7 +628,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -646,7 +646,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -664,7 +664,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -682,7 +682,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -700,7 +700,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -718,7 +718,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( stream); } -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -736,7 +736,7 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( } #ifdef _WIN64 -cudaError cub_DeviceRadixSort_SortPairsDescending( +cudaError SortPairsDescending( void *workspace_ptr, size_t &workspace_len, const double *key_in, @@ -754,4 +754,6 @@ cudaError cub_DeviceRadixSort_SortPairsDescending( } #endif -} // namespace infini_cub +} // namespace infini_cub::nvidai::DeviceRadixSort + +#endif // ENABLE_CUDA_API \ No newline at end of file diff --git a/src/infinicub/xmake.lua b/src/infinicub/xmake.lua index c7d9e5dc5..932afdc7c 100644 --- a/src/infinicub/xmake.lua +++ b/src/infinicub/xmake.lua @@ -18,6 +18,8 @@ target("infinicub") set_toolchains("cuda") add_links("cublas", "cudnn") add_cugencodes("native") + + add_defines("ENABLE_CUDA_API") if is_plat("windows") then add_cuflags("-Xcompiler=/utf-8", "--expt-relaxed-constexpr", "--allow-unsupported-compiler") @@ -37,5 +39,5 @@ target("infinicub") set_languages("cxx17") add_includedirs("include") - add_files("src/cub_algorithms.cu") + add_files("src/*.cu") target_end() diff --git a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh index 1a1a62d27..ae013d57f 100644 --- a/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh +++ b/src/infiniop/ops/random_sample/cuda/random_sample_kernel.cuh @@ -2,7 +2,7 @@ #include "infinicore.h" #ifdef ENABLE_INFINI_CUB -#include "../../../../infinicub/include/cub_algorithms.cuh" +#include "../../../../infinicub/include/cub_algorithms_nvidia.cuh" #else #include #include @@ -22,7 +22,7 @@ static cudaError argMax_( cudaStream_t stream) { #ifdef ENABLE_INFINI_CUB - return infini_cub::cub_DeviceReduce_ArgMax( + return infini_cub::nvidai::DeviceReduce::ArgMax( workspace_ptr, workspace_len, logits, kv_pair, n, stream); @@ -43,7 +43,7 @@ static cudaError radixSort( cudaStream_t stream) { #ifdef ENABLE_INFINI_CUB - return infini_cub::cub_DeviceRadixSort_SortPairsDescending( + return infini_cub::nvidai::DeviceRadixSort::SortPairsDescending( workspace_ptr, workspace_len, key_in, key_out, val_in, val_out, @@ -65,15 +65,15 @@ static cudaError inclusiveSum( void *workspace_ptr, size_t &workspace_len, T *data, int n, cudaStream_t stream) { - + #ifdef ENABLE_INFINI_CUB - return infini_cub::cub_DeviceScan_InclusiveSum( + return infini_cub::nvidai::DeviceReduce::InclusiveSum( workspace_ptr, workspace_len, data, n, stream); #else return cub::DeviceScan::InclusiveSum( - workspace_ptr, workspace_len, + workspace_ptr, workspace_len, data, data, n, stream); #endif diff --git a/xmake/cub.lua b/xmake/cub.lua index 397d348bd..914c923c4 100644 --- a/xmake/cub.lua +++ b/xmake/cub.lua @@ -3,6 +3,9 @@ package("infinicub") add_versions("1.0.0", "commit-hash-or-sha256-for-v1.0.0") add_versions("1.0.1", "commit-hash-or-sha256-for-v1.0.1") + add_versions("1.0.2", "commit-hash-or-sha256-for-v1.0.2") + add_versions("1.0.3", "commit-hash-or-sha256-for-v1.0.3") + add_versions("1.0.4", "commit-hash-or-sha256-for-v1.0.4") set_sourcedir(path.join(os.scriptdir(), "../src/infinicub")) diff --git a/xmake/cuda.lua b/xmake/cuda.lua index 408ea1a9e..6d0d41664 100644 --- a/xmake/cuda.lua +++ b/xmake/cuda.lua @@ -9,7 +9,7 @@ if CUDNN_ROOT ~= nil then end -add_requires("infinicub 1.0.0", {optional = true, configs = {shared = false}}) +add_requires("infinicub 1.0.0", {optional = true, configs = {shared = false, defines="ENABLE_CUDA_API" }}) target("infiniop-cuda") set_kind("static") @@ -21,7 +21,7 @@ target("infiniop-cuda") add_links("cublas", "cudnn") add_cugencodes("native") - if has_config("prebuild") then + if has_config("prebuild") and has_config("nv-gpu") then add_defines("ENABLE_INFINI_CUB") includes("cub.lua") add_packages("infinicub")