
Commit 2f20af7

Merge pull request #258 from InfiniTensor/issue/36
issue/36 - Migrate cuda random sample to metax
2 parents 8e96d62 + 7707049 commit 2f20af7

File tree

4 files changed, +379 −0 lines changed
Lines changed: 254 additions & 0 deletions
@@ -0,0 +1,254 @@
#include "../../../devices/maca/maca_kernel_common.h"
#include "infinicore.h"
#include <hccub/device/device_radix_sort.cuh>
#include <hccub/device/device_reduce.cuh>
#include <hccub/device/device_scan.cuh>

namespace op::random_sample::maca {

// ↓↓↓ Re-wrap the cub API to reduce template parameters and simplify the call sites

template <class T>
static hcError_t argMax_(
    cub::KeyValuePair<int, T> *kv_pair,
    const T *logits,
    int n,
    void *workspace_ptr,
    size_t &workspace_len,
    hcStream_t stream) {
    return cub::DeviceReduce::ArgMax(
        workspace_ptr, workspace_len,
        logits, kv_pair, n,
        stream);
}

template <class Tval, class Tidx>
static hcError_t radixSort(
    void *workspace_ptr, size_t &workspace_len,
    const Tval *key_in, Tval *key_out,
    const Tidx *val_in, Tidx *val_out,
    int n,
    hcStream_t stream) {
    return cub::DeviceRadixSort::SortPairsDescending(
        workspace_ptr, workspace_len,
        key_in, key_out,
        val_in, val_out,
        n,
        0, sizeof(Tval) * 8,
        stream);
}

template <class T>
static hcError_t inclusiveSum(
    void *workspace_ptr, size_t &workspace_len,
    T *data, int n,
    hcStream_t stream) {
    return cub::DeviceScan::InclusiveSum(
        workspace_ptr, workspace_len,
        data, data, n,
        stream);
}

// ↑↑↑ Re-wrap the cub API to reduce template parameters and simplify the call sites
// ↓↓↓ Workspace size calculation

// Align a size up to a 256-byte boundary
static constexpr size_t align256(size_t size) {
    return (size + 255) & (~255);
}

template <class Tidx, class Tval>
utils::Result<size_t> calculateWorkspace(size_t n_) {
    const auto n = static_cast<int>(n_);

    size_t argmax;
    CHECK_MACA(argMax_<Tval>(
        nullptr, nullptr, n,
        nullptr, argmax,
        nullptr));
    // The first 256 bytes hold the kv pair
    argmax += 256;

    // indices
    size_t size_random = align256(sizeof(Tidx) * n);
    // sorted
    size_random += align256(sizeof(Tval) * n);
    // indices_out
    size_random += align256(sizeof(Tidx) * n);
    // cub device API
    size_t size_radix_sort;
    CHECK_MACA((radixSort<Tval, Tidx>(
        nullptr, size_radix_sort,
        nullptr, nullptr,
        nullptr, nullptr,
        n,
        nullptr)));

    size_t size_inclusive_sum;
    CHECK_MACA(inclusiveSum<Tval>(
        nullptr, size_inclusive_sum,
        nullptr, n,
        nullptr));
    size_random += cub::Max()(size_radix_sort, size_inclusive_sum);

    return utils::Result<size_t>(cub::Max()(argmax, size_random));
}

// ↑↑↑ Workspace size calculation
// ↓↓↓ Map fp16_t to half via template specialization

template <class Tval>
struct CudaTval {
    using Type = Tval;
};

template <>
struct CudaTval<fp16_t> {
    using Type = half;
};

// ↑↑↑ Map fp16_t to half via template specialization
// ↓↓↓ Small kernels used by the sampling process

// The cub::DeviceReduce::ArgMax shipped with maca toolkit 11.x only accepts
// cub::KeyValuePair<int, Tval> as output. This kernel extracts the index from it.
template <class Tidx, class Tval>
static __global__ void castIdx(Tidx *result, const cub::KeyValuePair<int, Tval> *kv_pair) {
    *result = kv_pair->key;
}

// Fill the index array required by the sort
template <class Tidx>
static __global__ void fillIndices(Tidx *indices, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        indices[i] = i;
    }
}

// The softmax used by random sample can be simplified to a basic elementwise mapping.
// Since the data is already sorted, the maximum is the first element.
// The first element is read by multiple blocks, so it must not be overwritten here.
template <class T>
static __global__ void partialSoftmaxKernel(
    T *__restrict__ data, int n,
    float temperature) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (0 < i && i < n) {
        float max = __ldg(data);
        data[i] = (T)expf(((float)data[i] - max) / temperature);
    }
}

// Write the first element as 1, i.e. exp(0)
template <class T>
static __global__ void setSoftmaxMaxKernel(
    T *__restrict__ data) {
    *data = 1;
}

// Sample by scanning with a plain for loop.
// This kernel exists only to avoid copying the data back to the cpu.
template <class Tval, class Tidx>
static __global__ void randomSampleKernel(
    Tidx *__restrict__ result,
    const Tval *__restrict__ sorted,
    const Tidx *__restrict__ indices_out,
    size_t n,
    float random, float topp, size_t topk) {
    topk = cub::Min()(topk, n);
    auto p = (Tval)(random * cub::Min()(topp * (float)sorted[n - 1], (float)sorted[topk - 1]));
    for (size_t i = 0;; ++i) {
        if ((sorted[i]) >= p) {
            *result = indices_out[i];
            return;
        }
    }
}

// ↑↑↑ Small kernels used by the sampling process

struct Algo {
    int block_size;

    template <class Tidx, class Tval_>
    infiniStatus_t argmax(
        void *workspace, size_t workspace_size,
        void *result, const void *probs, size_t n,
        void *stream_) const {

        using Tval = typename CudaTval<Tval_>::Type;

        auto stream = (hcStream_t)stream_;
        auto logits = (Tval *)probs;
        auto kv_pair = (cub::KeyValuePair<int, Tval> *)workspace;
        workspace = (void *)((char *)workspace + 256);
        workspace_size -= 256;

        argMax_(
            kv_pair,
            logits,
            n,
            workspace,
            workspace_size, stream);
        castIdx<<<1, 1, 0, stream>>>((Tidx *)result, kv_pair);

        return INFINI_STATUS_SUCCESS;
    }

    template <class Tidx, class Tval_>
    infiniStatus_t random(
        void *workspace_, size_t workspace_size,
        void *result_, const void *probs, size_t n,
        float random_val, float topp, int topk, float temperature,
        void *stream_) const {

        using Tval = typename CudaTval<Tval_>::Type;

        auto stream = (hcStream_t)stream_;
        auto logits = (Tval *)probs;
        auto result = (Tidx *)result_;

        auto workspace = reinterpret_cast<size_t>(workspace_);
        auto workspace_end = workspace + workspace_size;

        auto indices = reinterpret_cast<Tidx *>(workspace);
        workspace += align256(sizeof(Tidx) * n);

        auto sorted = reinterpret_cast<Tval *>(workspace);
        workspace += align256(sizeof(Tval) * n);

        auto indices_out = reinterpret_cast<Tidx *>(workspace);
        workspace += align256(sizeof(Tidx) * n);

        workspace_ = reinterpret_cast<void *>(workspace);
        workspace_size = workspace_end - workspace;

        auto block = cub::Min()((size_t)block_size, n);
        auto grid = (n + block - 1) / block;
        // sort
        fillIndices<<<grid, block, 0, stream>>>(indices, n);
        CHECK_MACA(radixSort(
            workspace_, workspace_size,
            logits, sorted,
            indices, indices_out,
            n,
            stream));
        // softmax
        partialSoftmaxKernel<<<grid, block, 0, stream>>>(sorted, n, temperature);
        setSoftmaxMaxKernel<<<1, 1, 0, stream>>>(sorted);
        // sum
        CHECK_MACA(inclusiveSum(
            workspace_, workspace_size,
            sorted, n,
            stream));
        // sample
        randomSampleKernel<<<1, 1, 0, stream>>>(
            result,
            sorted, indices_out, n,
            random_val, topp, topk);
        return INFINI_STATUS_SUCCESS;
    }
};

} // namespace op::random_sample::maca
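The wrappers above follow cub's two-phase calling convention: invoking a device-wide algorithm with a null workspace pointer only writes the required temporary-storage size, and a second call with a real buffer does the work. calculateWorkspace runs only the query phase for each algorithm and takes the maximum, so one buffer can serve whichever path runs later. Below is a minimal sketch of the pattern, written against upstream CUB on vanilla CUDA (hccub mirrors this interface on MACA); the function name and in-place scan are illustrative, not part of this commit:

#include <cub/device/device_scan.cuh>
#include <cuda_runtime.h>

// Two-phase CUB call: first query the workspace size, then run for real.
// d_data is assumed to already hold n floats on the device.
cudaError_t inclusiveSumTwoPhase(float *d_data, int n, cudaStream_t stream) {
    void *d_temp = nullptr;
    size_t temp_bytes = 0;
    // Phase 1: null workspace pointer -> only temp_bytes is written.
    cudaError_t err = cub::DeviceScan::InclusiveSum(
        d_temp, temp_bytes, d_data, d_data, n, stream);
    if (err != cudaSuccess) return err;
    // Allocate exactly what CUB asked for.
    err = cudaMalloc(&d_temp, temp_bytes);
    if (err != cudaSuccess) return err;
    // Phase 2: the same call with a real buffer performs the scan in place.
    err = cub::DeviceScan::InclusiveSum(
        d_temp, temp_bytes, d_data, d_data, n, stream);
    cudaFree(d_temp);
    return err;
}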
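Algo::random chains these kernels into one pipeline: radix-sort the logits descending, exponentiate each element relative to the maximum (scaled by temperature), take an inclusive prefix sum to get an unnormalized CDF, then walk the CDF until it crosses a threshold derived from the random draw, top-p, and top-k. The following host-side sketch restates the same arithmetic for a single already-sorted vector; it is purely illustrative, assuming n > 0 and topk >= 1, and exists to make randomSampleKernel easy to check:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Host-side sketch of the sampling math in Algo::random, given a
// descending-sorted logits vector. Returns a position in sorted order.
size_t sampleSorted(const std::vector<float> &sorted_logits,
                    float random_val, float topp, size_t topk, float temperature) {
    size_t n = sorted_logits.size();
    // Partial softmax: exp((x - max) / T); the maximum itself maps to exp(0) = 1.
    std::vector<float> cdf(n);
    float max = sorted_logits[0];
    for (size_t i = 0; i < n; ++i)
        cdf[i] = std::exp((sorted_logits[i] - max) / temperature);
    // Inclusive prefix sum turns the weights into an unnormalized CDF.
    for (size_t i = 1; i < n; ++i)
        cdf[i] += cdf[i - 1];
    // Truncate by top-p (nucleus mass) and top-k, then scale the random
    // draw into the surviving mass, exactly like randomSampleKernel.
    topk = std::min(topk, n);
    float p = random_val * std::min(topp * cdf[n - 1], cdf[topk - 1]);
    for (size_t i = 0; i < n; ++i)
        if (cdf[i] >= p) return i;
    return n - 1;
}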
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
#ifndef __RANDOM_SAMPLE_MACA_H__
#define __RANDOM_SAMPLE_MACA_H__

#include "../random_sample.h"

DESCRIPTOR(maca)

#endif // __RANDOM_SAMPLE_MACA_H__
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
#include "../../../devices/maca/common_maca.h"
#include "../../../devices/maca/maca_handle.h"
#include "../info.h"
#include "random_sample_kernel.h"
#include "random_sample_maca.h"

namespace op::random_sample::maca {

struct Descriptor::Opaque {
    std::shared_ptr<device::maca::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc) {
    auto handle = reinterpret_cast<device::maca::Handle *>(handle_);

    auto result = RandomSampleInfo::create(result_desc, probs_desc);
    CHECK_RESULT(result);

    auto info = result.take();
    size_t workspace_size;

#define CASE_P(CASE, Tidx, Tval)                                        \
    case CASE: {                                                        \
        auto workspace_result = calculateWorkspace<Tidx, Tval>(info.n); \
        CHECK_RESULT(workspace_result);                                 \
        workspace_size = workspace_result.take();                       \
    } break

#define CASE_I(CASE, Tidx)                              \
    case CASE:                                          \
        switch (info.dt_p) {                            \
            CASE_P(INFINI_DTYPE_F16, Tidx, half);       \
            CASE_P(INFINI_DTYPE_F32, Tidx, float);      \
            CASE_P(INFINI_DTYPE_F64, Tidx, double);     \
        default:                                        \
            abort();                                    \
        }                                               \
        break

    switch (info.dt_i) {
        CASE_I(INFINI_DTYPE_I8, int8_t);
        CASE_I(INFINI_DTYPE_I16, int16_t);
        CASE_I(INFINI_DTYPE_I32, int32_t);
        CASE_I(INFINI_DTYPE_I64, int64_t);
        CASE_I(INFINI_DTYPE_U8, uint8_t);
        CASE_I(INFINI_DTYPE_U16, uint16_t);
        CASE_I(INFINI_DTYPE_U32, uint32_t);
        CASE_I(INFINI_DTYPE_U64, uint64_t);
    default:
        abort();
    }

#undef CASE_I
#undef CASE_P

    *desc_ptr = new Descriptor(
        info,
        workspace_size,
        new Opaque{handle->internal()},
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

size_t Descriptor::minWorkspaceSize() const {
    return _min_workspace_size;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *result,
    const void *probs,
    float random_val,
    float topp,
    int topk,
    float temperature,
    void *stream) const {

    if (workspace_size < _min_workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    auto block_size = _opaque->internal->blockSizeX();

    Calculate::calculate<Algo>(
        Algo{block_size}, _info, workspace, workspace_size,
        result, probs,
        random_val, topp, topk, temperature,
        stream);

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::random_sample::maca
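End to end, the descriptor follows the usual infiniop lifecycle: create from the tensor descriptors (which also computes the minimum workspace), query minWorkspaceSize, and call calculate with a workspace at least that large. A hedged caller sketch under stated assumptions: the handle, tensor descriptors, device buffers, and stream come from surrounding setup code, and hcMalloc/hcFree are assumed names for the MACA device allocator, which this commit does not show:

// Hypothetical caller; handle, result_desc, probs_desc, d_result, d_probs
// and stream are assumed to exist via the infiniop/MACA setup APIs.
infiniStatus_t runRandomSample(
    infiniopHandle_t handle,
    infiniopTensorDescriptor_t result_desc,
    infiniopTensorDescriptor_t probs_desc,
    void *d_result, const void *d_probs,
    float random_val, float topp, int topk, float temperature,
    void *stream) {
    op::random_sample::maca::Descriptor *desc = nullptr;
    auto status = op::random_sample::maca::Descriptor::create(
        handle, &desc, result_desc, probs_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    // calculate() rejects anything smaller than minWorkspaceSize().
    size_t workspace_size = desc->minWorkspaceSize();
    void *workspace = nullptr;
    hcMalloc(&workspace, workspace_size); // assumed MACA allocator name

    status = desc->calculate(
        workspace, workspace_size,
        d_result, d_probs,
        random_val, topp, topk, temperature,
        stream);

    hcFree(workspace); // assumed
    delete desc;
    return status;
}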
