Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions csrc/flashinfer_xqa_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

#include "pytorch_extension_utils.h"

void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize,
double qScale, at::Tensor output,
void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads,
int64_t slidingWinSize, double qScale, at::Tensor output,
#if LOW_PREC_OUTPUT
at::Tensor rcpOutScale,
#endif
Expand Down
145 changes: 145 additions & 0 deletions csrc/xqa/gmma.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/

#pragma once
#include "cuda_hint.cuh"
#include "mha_stdheaders.cuh"
#include "utils.cuh"
#ifndef __CUDACC__
#include <cuda_runtime.h>
#endif
#include <cuda_fp16.h>
#include <cuda_fp8.h>

namespace gmma {

enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 };

struct MatDesc {
uint64_t addr : 16;
uint64_t dimKOffset : 16;
uint64_t dimMNOffset : 16;
uint64_t pad0 : 1;
uint64_t baseOffset : 3;
uint64_t pad1 : 10;
SwizzleMode swizzle : 2;

enum class Raw : uint64_t {};

[[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const {
MatDesc ret = *this;
ret.addr = encode(__cvta_generic_to_shared(data));
return ret;
}

static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; }

__device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); }

__device__ inline Raw const& raw() const {
static_assert(sizeof(MatDesc) == 8);
return reinterpret_cast<Raw const&>(*this);
}

static __device__ inline MatDesc fromRaw(Raw const& raw) {
return reinterpret_cast<MatDesc const&>(raw);
}
};

static_assert(sizeof(MatDesc) == 8);

[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) {
assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
MatDesc::Raw ret = base;
auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
return ret;
}

__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
uint32_t dimMNByteOffset, void const* patternStartAddr,
SwizzleMode swizzleMode) {
uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
uint32_t const baseAlign = [&]() -> uint32_t {
switch (swizzleMode) {
case SwizzleMode::kNONE:
return 1;
case SwizzleMode::k128:
return 1024;
case SwizzleMode::k64:
return 512;
case SwizzleMode::k32:
return 256;
}
asm volatile("trap;\n");
return 0;
}();
uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
return MatDesc{
/*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
/*dimKOffset=*/MatDesc::encode(dimKByteOffset),
/*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
/*pad0=*/0,
/*baseOffset=*/baseOffset,
/*pad1=*/0,
/*swizzle=*/swizzleMode,
};
}

__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset,
uint32_t dimMNByteOffset, SwizzleMode swizzleMode) {
return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
}

inline constexpr uint32_t instM = 64;

template <typename MathElem>
inline constexpr uint32_t instK = 32 / sizeof(MathElem);

inline constexpr uint32_t instNBase = 8;

// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
// acc is used as both input and output.
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA,
MatDesc::Raw descB, bool accHasVal);
template <typename InputElem, uint32_t n, bool transA = false, bool transB = false>
__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2],
uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);

__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); }

__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); }

template <uint32_t targetNbInFlightGroups>
__device__ inline void wait_group() {
asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups));
}

template <bool swizzle, typename T, uint32_t rows, uint32_t cols, bool alignedForSwizzle>
constexpr SwizzleMode getSwizzleMode(Array2D<T, rows, cols, alignedForSwizzle> const&) {
constexpr auto rowBytes = Array2D<T, rows, cols, alignedForSwizzle>::rowBytes;
if constexpr (!swizzle) {
return SwizzleMode::kNONE;
}
if constexpr (rowBytes % 128 == 0) {
return SwizzleMode::k128;
} else if constexpr (rowBytes == 64) {
return SwizzleMode::k64;
} else {
static_assert(rowBytes == 32);
return SwizzleMode::k32;
}
}
} // namespace gmma

#include "gmma_impl.cuh"
Loading