|
#include "../../utilities/cuda/cublaslt_utils.cuh"
#include "cuda_kernel.hh"
#include "hardware/functions.h"
#include <stdexcept>
#include <string>
| 4 | + |
| 5 | +namespace refactor::kernel { |
| 6 | + using K = AttentionCuda; |
| 7 | + using namespace cublas; |
| 8 | + |
| 9 | + RoutineWorkspace K::lower(Resources &res) const { |
| 10 | + auto handle = res.fetchOrStore<CublasLtContext>()->handle; |
| 11 | + |
| 12 | + constexpr auto ROW_MAJOR = CUBLASLT_ORDER_ROW; |
| 13 | + constexpr auto COL_MAJOR = CUBLASLT_ORDER_COL; |
| 14 | + |
| 15 | + if (!info.cacheLen) { |
| 16 | + if (info.nHead == info.nKVHead) { |
| 17 | + // RAII for closure |
| 18 | + struct Descriptors { |
| 19 | + MatMulDescriptor mul; |
| 20 | + MatrixDescriptor q, k, v, att; |
| 21 | + cublasLtMatmulAlgo_t algoQK, algoAV; |
| 22 | + size_t attSize, workspaceSizeQK, workspaceSizeAV; |
| 23 | + |
| 24 | + Descriptors(CublasLtContext const &context, |
| 25 | + cublasComputeType_t compute, |
| 26 | + AttentionInfo info) |
| 27 | + : mul(compute, CUDA_R_32F), |
| 28 | + q(MatrixLayout{ |
| 29 | + .dataType = dataTypeConvert(info.dataType), |
| 30 | + .rows = static_cast<uint64_t>(info.seqLen), |
| 31 | + .cols = static_cast<uint64_t>(info.headDim), |
| 32 | + .majorStride = static_cast<int64_t>(info.headDim), |
| 33 | + .order = ROW_MAJOR, |
| 34 | + .batchCount = static_cast<int32_t>(info.batch * info.nHead), |
| 35 | + .batchStride = static_cast<int64_t>(info.seqLen * info.headDim), |
| 36 | + }), |
| 37 | + k(MatrixLayout{ |
| 38 | + .dataType = dataTypeConvert(info.dataType), |
| 39 | + .rows = static_cast<uint64_t>(info.headDim), |
| 40 | + .cols = static_cast<uint64_t>(info.seqLen), |
| 41 | + .majorStride = static_cast<int64_t>(info.headDim), |
| 42 | + .order = COL_MAJOR, |
| 43 | + .batchCount = static_cast<int32_t>(info.batch * info.nHead), |
| 44 | + .batchStride = static_cast<int64_t>(info.seqLen * info.headDim), |
| 45 | + }), |
| 46 | + v(MatrixLayout{ |
| 47 | + .dataType = dataTypeConvert(info.dataType), |
| 48 | + .rows = static_cast<uint64_t>(info.seqLen), |
| 49 | + .cols = static_cast<uint64_t>(info.headDim), |
| 50 | + .majorStride = static_cast<int64_t>(info.headDim), |
| 51 | + .order = ROW_MAJOR, |
| 52 | + .batchCount = static_cast<int32_t>(info.batch * info.nHead), |
| 53 | + .batchStride = static_cast<int64_t>(info.seqLen * info.headDim), |
| 54 | + }), |
| 55 | + att(MatrixLayout{ |
| 56 | + .dataType = dataTypeConvert(info.dataType), |
| 57 | + .rows = static_cast<uint64_t>(info.seqLen), |
| 58 | + .cols = static_cast<uint64_t>(info.seqLen), |
| 59 | + .majorStride = static_cast<int64_t>(info.seqLen), |
| 60 | + .order = ROW_MAJOR, |
| 61 | + .batchCount = static_cast<int32_t>(info.batch * info.nHead), |
| 62 | + .batchStride = static_cast<int64_t>(info.seqLen * info.seqLen), |
| 63 | + }), |
| 64 | + attSize(info.batch * info.nHead * info.seqLen * info.seqLen * info.dataType.size()) { |
| 65 | + auto [algoQK_, workspaceSizeQK_] = tune(context.handle, mul, q, k, att); |
| 66 | + auto [algoAV_, workspaceSizeAV_] = tune(context.handle, mul, att, v, q); |
| 67 | + algoQK = algoQK_; |
| 68 | + algoAV = algoAV_; |
| 69 | + workspaceSizeQK = workspaceSizeQK_; |
| 70 | + workspaceSizeAV = workspaceSizeAV_; |
| 71 | + } |
| 72 | + }; |
| 73 | + |
| 74 | + auto const &context = *res.fetchOrStore<CublasLtContext>(); |
| 75 | + auto d = std::make_shared<Descriptors>(context, CUBLAS_COMPUTE_32F, info); |
| 76 | + auto workspaceSize = d->attSize; |
| 77 | + workspaceSize = hardware::alignBytes(workspaceSize, 256); |
| 78 | + workspaceSize += d->workspaceSizeQK; |
| 79 | + workspaceSize = hardware::alignBytes(workspaceSize, 256); |
| 80 | + workspaceSize += d->workspaceSizeAV; |
| 81 | + workspaceSize = hardware::alignBytes(workspaceSize, 256); |
| 82 | + |
| 83 | + auto routine = [d = std::move(d), info = this->info]// |
| 84 | + (Resources & res, void *workspace, void const *const *inputs, void *const *outputs) { |
| 85 | + auto handle = res.fetchOrStore<CublasLtContext>()->handle; |
| 86 | + auto q = inputs[0]; |
| 87 | + auto k = inputs[1]; |
| 88 | + auto v = inputs[2]; |
| 89 | + auto o = outputs[0]; |
| 90 | + auto att = workspace; |
| 91 | + auto workspaceQK = reinterpret_cast<uint8_t *>(workspace) + hardware::alignBytes(d->attSize, 256); |
| 92 | + auto workspaceAV = workspaceQK + hardware::alignBytes(d->workspaceSizeQK, 256); |
| 93 | + |
| 94 | + float alpha = 1, beta = 0; |
| 95 | + cublasLtMatmul( |
| 96 | + handle, d->mul.get(), |
| 97 | + &alpha, |
| 98 | + q, d->q.get(), |
| 99 | + k, d->k.get(), |
| 100 | + &beta, |
| 101 | + att, d->att.get(), |
| 102 | + att, d->att.get(), |
| 103 | + &d->algoQK, |
| 104 | + workspaceQK, d->workspaceSizeQK, |
| 105 | + cudaStreamLegacy); |
| 106 | + |
| 107 | + // TODO inline mask && softmax |
| 108 | + |
| 109 | + cublasLtMatmul( |
| 110 | + handle, d->mul.get(), |
| 111 | + &alpha, |
| 112 | + att, d->att.get(), |
| 113 | + v, d->v.get(), |
| 114 | + &beta, |
| 115 | + o, d->q.get(), |
| 116 | + o, d->q.get(), |
| 117 | + &d->algoAV, |
| 118 | + workspaceAV, d->workspaceSizeAV, |
| 119 | + cudaStreamLegacy); |
| 120 | + }; |
| 121 | + return {std::move(routine), workspaceSize}; |
| 122 | + } |
| 123 | + } |
| 124 | + TODO(""); |
| 125 | + } |
| 126 | + |
| 127 | +}// namespace refactor::kernel |
0 commit comments