#include "../../../devices/maca/maca_kernel_common.h"
#include "infinicore.h"
#include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh>
#include <hccub/device/device_scan.cuh>

namespace op::random_sample::maca {

// ↓↓↓ Thin wrappers around the cub API: fewer template parameters, easier to call.

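// All three wrappers follow cub's two-phase convention: when workspace_ptr is
// nullptr, the call does no work and only writes the required temporary
// storage size into workspace_len; a second call with a real buffer runs the
// algorithm.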
template <class T>
static hcError_t argMax_(
    cub::KeyValuePair<int, T> *kv_pair,
    const T *logits,
    int n,
    void *workspace_ptr,
    size_t &workspace_len,
    hcStream_t stream) {
    return cub::DeviceReduce::ArgMax(
        workspace_ptr, workspace_len,
        logits, kv_pair, n,
        stream);
}

template <class Tval, class Tidx>
static hcError_t radixSort(
    void *workspace_ptr, size_t &workspace_len,
    const Tval *key_in, Tval *key_out,
    const Tidx *val_in, Tidx *val_out,
    int n,
    hcStream_t stream) {
    return cub::DeviceRadixSort::SortPairsDescending(
        workspace_ptr, workspace_len,
        key_in, key_out,
        val_in, val_out,
        n,
        0, sizeof(Tval) * 8,
        stream);
}

template <class T>
static hcError_t inclusiveSum(
    void *workspace_ptr, size_t &workspace_len,
    T *data, int n,
    hcStream_t stream) {
    return cub::DeviceScan::InclusiveSum(
        workspace_ptr, workspace_len,
        data, data, n,
        stream);
}

// ↑↑↑ Thin wrappers around the cub API.
// ↓↓↓ Workspace size calculation.

// Align a size up to a multiple of 256 bytes.
static constexpr size_t align256(size_t size) {
    return (size + 255) & (~255);
}

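// e.g. align256(300) == 512 and align256(256) == 256: adding 255 and masking
// off the low 8 bits rounds up to the next 256-byte boundary.

// Workspace layout for the random path (each region aligned to 256 bytes):
//   [ indices | sorted | indices_out | cub temp storage ]
// The argmax path instead reserves its first 256 bytes for the kv pair and
// gives the rest to cub.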
template <class Tidx, class Tval>
utils::Result<size_t> calculateWorkspace(size_t n_) {
    const auto n = static_cast<int>(n_);

    size_t argmax;
    CHECK_MACA(argMax_<Tval>(
        nullptr, nullptr, n,
        nullptr, argmax,
        nullptr));
    // The first 256 bytes are reserved for the kv pair
    argmax += 256;

    // indices
    size_t size_random = align256(sizeof(Tidx) * n);
    // sorted
    size_random += align256(sizeof(Tval) * n);
    // indices_out
    size_random += align256(sizeof(Tidx) * n);
    // cub device api
    size_t size_radix_sort;
    CHECK_MACA((radixSort<Tval, Tidx>(
        nullptr, size_radix_sort,
        nullptr, nullptr,
        nullptr, nullptr,
        n,
        nullptr)));

    size_t size_inclusive_sum;
    CHECK_MACA(inclusiveSum<Tval>(
        nullptr, size_inclusive_sum,
        nullptr, n,
        nullptr));
    size_random += cub::Max()(size_radix_sort, size_inclusive_sum);

    return utils::Result<size_t>(cub::Max()(argmax, size_random));
}

// ↑↑↑ Workspace size calculation.
// ↓↓↓ Map fp16_t to half via template specialization.

template <class Tval>
struct CudaTval {
    using Type = Tval;
};

template <>
struct CudaTval<fp16_t> {
    using Type = half;
};
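// e.g. typename CudaTval<fp16_t>::Type is half, so the kernels below can use
// native half arithmetic; every other value type maps to itself.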

// ↑↑↑ Map fp16_t to half via template specialization.
// ↓↓↓ Small kernels used by the sampling process.

// The cub::DeviceReduce::ArgMax shipped with maca toolkit 11.x only accepts a
// cub::KeyValuePair<int, Tval> output; this kernel extracts the index from it.
template <class Tidx, class Tval>
static __global__ void castIdx(Tidx *result, const cub::KeyValuePair<int, Tval> *kv_pair) {
    *result = kv_pair->key;
}

// Fill the index array that the sort expects.
template <class Tidx>
static __global__ void fillIndices(Tidx *indices, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        indices[i] = i;
    }
}

// The softmax used by random sample reduces to a simple element-wise mapping:
// since the data is already sorted descending, the maximum is the first element.
// The first element is read by every block, so it must not be overwritten here.
template <class T>
static __global__ void partialSoftmaxKernel(
    T *__restrict__ data, int n,
    float temperature) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (0 < i && i < n) {
        float max = __ldg(data);
        data[i] = (T)expf(((float)data[i] - max) / temperature);
    }
}
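// Note that no division by the sum of exponentials happens here: the sampling
// kernel below compares unnormalized prefix sums against a threshold scaled by
// the same total, so the normalization factor cancels out.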

// Write 1 to the first element, i.e. exp(0).
template <class T>
static __global__ void setSoftmaxMaxKernel(
    T *__restrict__ data) {
    *data = 1;
}

// Sample with a plain for loop over the prefix sums.
// This kernel exists only to avoid copying the data back to the CPU.
template <class Tval, class Tidx>
static __global__ void randomSampleKernel(
    Tidx *__restrict__ result,
    const Tval *__restrict__ sorted,
    const Tidx *__restrict__ indices_out,
    size_t n,
    float random, float topp, size_t topk) {
    topk = cub::Min()(topk, n);
    auto p = (Tval)(random * cub::Min()(topp * (float)sorted[n - 1], (float)sorted[topk - 1]));
    for (size_t i = 0;; ++i) {
        if (sorted[i] >= p) {
            *result = indices_out[i];
            return;
        }
    }
}
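// After inclusiveSum, sorted[i] holds the cumulative (unnormalized) probability
// mass of the top i+1 tokens and sorted[n - 1] the total. The threshold p
// applies both truncations at once: top-p caps the mass at topp * total, and
// top-k caps it at the mass of the first topk tokens; the first prefix sum
// reaching p selects the sampled token.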

// ↑↑↑ Small kernels used by the sampling process.

struct Algo {
    int block_size;

    template <class Tidx, class Tval_>
    infiniStatus_t argmax(
        void *workspace, size_t workspace_size,
        void *result, const void *probs, size_t n,
        void *stream_) const {

        using Tval = typename CudaTval<Tval_>::Type;

        auto stream = (hcStream_t)stream_;
        auto logits = (Tval *)probs;
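        // Reserve the first 256 bytes of the workspace for the kv pair and
        // hand the remainder to cub as temporary storage, mirroring
        // calculateWorkspace above.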
        auto kv_pair = (cub::KeyValuePair<int, Tval> *)workspace;
        workspace = (void *)((char *)workspace + 256);
        workspace_size -= 256;

        CHECK_MACA(argMax_(
            kv_pair,
            logits,
            n,
            workspace,
            workspace_size, stream));
        castIdx<<<1, 1, 0, stream>>>((Tidx *)result, kv_pair);

        return INFINI_STATUS_SUCCESS;
    }

    template <class Tidx, class Tval_>
    infiniStatus_t random(
        void *workspace_, size_t workspace_size,
        void *result_, const void *probs, size_t n,
        float random_val, float topp, int topk, float temperature,
        void *stream_) const {

        using Tval = typename CudaTval<Tval_>::Type;

        auto stream = (hcStream_t)stream_;
        auto logits = (Tval *)probs;
        auto result = (Tidx *)result_;

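        // Carve the three 256-aligned regions (indices, sorted, indices_out)
        // out of the workspace; whatever remains becomes cub temp storage.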
        auto workspace = reinterpret_cast<size_t>(workspace_);
        auto workspace_end = workspace + workspace_size;

        auto indices = reinterpret_cast<Tidx *>(workspace);
        workspace += align256(sizeof(Tidx) * n);

        auto sorted = reinterpret_cast<Tval *>(workspace);
        workspace += align256(sizeof(Tval) * n);

        auto indices_out = reinterpret_cast<Tidx *>(workspace);
        workspace += align256(sizeof(Tidx) * n);

        workspace_ = reinterpret_cast<void *>(workspace);
        workspace_size = workspace_end - workspace;

        auto block = cub::Min()((size_t)block_size, n);
        auto grid = (n + block - 1) / block;
        // sort
        fillIndices<<<grid, block, 0, stream>>>(indices, n);
        CHECK_MACA(radixSort(
            workspace_, workspace_size,
            logits, sorted,
            indices, indices_out,
            n,
            stream));
        // softmax
        partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
        setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
        // sum
        CHECK_MACA(inclusiveSum(
            workspace_, workspace_size,
            sorted, n,
            stream));
        // sample
        randomSampleKernel<<<1, 1, 0, stream>>>(
            result,
            sorted, indices_out, n,
            random_val, topp, topk);
        return INFINI_STATUS_SUCCESS;
    }
};
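
// A hypothetical call site, for illustration only (the framework that owns
// Algo, the chosen block size, and the type parameters are assumptions):
//
//   Algo algo{256};
//   algo.random<int32_t, fp16_t>(workspace, workspace_size,
//                                result, probs, n,
//                                random_val, topp, topk, temperature, stream);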

} // namespace op::random_sample::maca