From 4403f7faf060af89ae725d2e72fd5ce89a9a52bf Mon Sep 17 00:00:00 2001 From: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 25 Sep 2025 00:55:34 -0700 Subject: [PATCH 1/3] add xqa fp8 mha Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/flashinfer_xqa_ops.cu | 4 +- csrc/xqa/gmma.cuh | 178 ++ csrc/xqa/gmma_impl.cuh | 4586 ++++++++++++++++++++++++++++++++++++ csrc/xqa/mha.cu | 8 +- csrc/xqa/mha.h | 14 + csrc/xqa/mha_sm90.cu | 3504 +++++++++++++++++++++++++++ csrc/xqa/tensorMap.cpp | 93 + csrc/xqa/tensorMap.h | 12 + csrc/xqa/tma.h | 319 +++ csrc/xqa/utils.cuh | 8 +- csrc/xqa/xqa_wrapper.cu | 48 +- flashinfer/aot.py | 18 +- flashinfer/xqa.py | 52 +- tests/test_xqa.py | 41 +- 14 files changed, 8835 insertions(+), 50 deletions(-) create mode 100644 csrc/xqa/gmma.cuh create mode 100644 csrc/xqa/gmma_impl.cuh create mode 100644 csrc/xqa/mha_sm90.cu create mode 100644 csrc/xqa/tensorMap.cpp create mode 100644 csrc/xqa/tensorMap.h create mode 100644 csrc/xqa/tma.h diff --git a/csrc/flashinfer_xqa_ops.cu b/csrc/flashinfer_xqa_ops.cu index 87a614d778..8aff9c007d 100644 --- a/csrc/flashinfer_xqa_ops.cu +++ b/csrc/flashinfer_xqa_ops.cu @@ -16,8 +16,8 @@ #include "pytorch_extension_utils.h" -void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, at::Tensor output, +void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, + int64_t slidingWinSize, double qScale, at::Tensor output, #if LOW_PREC_OUTPUT at::Tensor rcpOutScale, #endif diff --git a/csrc/xqa/gmma.cuh b/csrc/xqa/gmma.cuh new file mode 100644 index 0000000000..f5148d8b9b --- /dev/null +++ b/csrc/xqa/gmma.cuh @@ -0,0 +1,178 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#ifndef __CUDACC__
+#include <cuda_runtime.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace gmma
+{
+
+enum class SwizzleMode : uint64_t
+{
+    kNONE = 0,
+    k128 = 1,
+    k64 = 2,
+    k32 = 3
+};
+
+struct MatDesc
+{
+    uint64_t addr : 16;
+    uint64_t dimKOffset : 16;
+    uint64_t dimMNOffset : 16;
+    uint64_t pad0 : 1;
+    uint64_t baseOffset : 3;
+    uint64_t pad1 : 10;
+    SwizzleMode swizzle : 2;
+
+    enum class Raw : uint64_t
+    {
+    };
+
+    [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const
+    {
+        MatDesc ret = *this;
+        ret.addr = encode(__cvta_generic_to_shared(data));
+        return ret;
+    }
+
+    static __device__ inline uint32_t encode(uint32_t val)
+    {
+        return (val & 0x3FFFFU) >> 4;
+    }
+
+    __device__ inline bool operator==(MatDesc const& other) const
+    {
+        return raw() == other.raw();
+    }
+
+    __device__ inline Raw const& raw() const
+    {
+        static_assert(sizeof(MatDesc) == 8);
+        return reinterpret_cast<Raw const&>(*this);
+    }
+
+    static __device__ inline MatDesc fromRaw(Raw const& raw)
+    {
+        return reinterpret_cast<MatDesc const&>(raw);
+    }
+};
+
+static_assert(sizeof(MatDesc) == 8);
+
+[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data)
+{
+    assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0);
+    MatDesc::Raw ret = base;
+    auto& u32x2 = reinterpret_cast<uint32_t(&)[2]>(ret);
+    u32x2[0] += static_cast<uint32_t>(__cvta_generic_to_shared(data)) >> 4;
+    return ret;
+}
+
+__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset,
+    void const* patternStartAddr, SwizzleMode swizzleMode)
+{
+    uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr);
+    uint32_t const baseAlign = [&]() -> uint32_t
+    {
+        switch (swizzleMode)
+        {
+        case SwizzleMode::kNONE: return 1;
+        case SwizzleMode::k128: return 1024;
+        case SwizzleMode::k64: return 512;
+        case SwizzleMode::k32: return 256;
+        }
+        asm volatile("trap;\n");
+        return 0;
+    }();
+    uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7));
+    return MatDesc{
+        /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)),
+        /*dimKOffset=*/MatDesc::encode(dimKByteOffset),
+        /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset),
+        /*pad0=*/0,
+        /*baseOffset=*/baseOffset,
+        /*pad1=*/0,
+        /*swizzle=*/swizzleMode,
+    };
+}
+
+__device__ inline MatDesc makeMatDesc(
+    void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, SwizzleMode swizzleMode)
+{
+    return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode);
+}
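A minimal sketch of how the descriptor helpers above are meant to be used from device
code; the tile shape, alignment, byte offsets, and all names here are illustrative
assumptions, not values taken from this patch:

    __global__ void exampleBuildDesc()
    {
        // 64 rows x 32 columns of fp8 -> 32-byte rows, no swizzling.
        __shared__ __align__(128) __nv_fp8_e4m3 tile[64][32];
        // dimKByteOffset / dimMNByteOffset are byte strides between adjacent core
        // matrices along gemm-K and gemm-M/N; the values below are hypothetical
        // for this particular layout.
        gmma::MatDesc const desc = gmma::makeMatDesc(
            &tile[0][0], /*dimKByteOffset=*/128, /*dimMNByteOffset=*/256, gmma::SwizzleMode::kNONE);
        gmma::MatDesc::Raw const raw = desc.raw(); // 64-bit form consumed by the wgmma instructions
        (void) raw;
    }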
+
+inline constexpr uint32_t instM = 64;
+
+template <typename MathElem>
+inline constexpr uint32_t instK = 32 / sizeof(MathElem);
+
+inline constexpr uint32_t instNBase = 8;
+
+// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N
+// acc is used as both input and output.
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_shmA(
+    float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal);
+template <typename MathElem, uint32_t n, bool transA, bool transB>
+__device__ void mma_async_regA(
+    float (&acc)[exactDiv(n, instNBase)][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal);
+
+__device__ inline void fence()
+{
+    asm volatile("wgmma.fence.sync.aligned;\n");
+}
+
+__device__ inline void commit_group()
+{
+    asm volatile("wgmma.commit_group.sync.aligned;\n");
+}
+
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void wait_group()
+{
+    asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(targetNbInFlightGroups));
+}
+
+template <bool swizzle, typename Elem, uint32_t rows, uint32_t cols>
+constexpr SwizzleMode getSwizzleMode(Array2D<Elem, rows, cols, swizzle> const&)
+{
+    constexpr auto rowBytes = Array2D<Elem, rows, cols, swizzle>::rowBytes;
+    if constexpr (!swizzle)
+    {
+        return SwizzleMode::kNONE;
+    }
+    if constexpr (rowBytes % 128 == 0)
+    {
+        return SwizzleMode::k128;
+    }
+    else if constexpr (rowBytes == 64)
+    {
+        return SwizzleMode::k64;
+    }
+    else
+    {
+        static_assert(rowBytes == 32);
+        return SwizzleMode::k32;
+    }
+}
+} // namespace gmma
+
+#include "gmma_impl.cuh"
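A minimal sketch of the intended call sequence for the API declared in gmma.cuh,
assuming descA and descB are MatDesc::Raw descriptors built with makeMatDesc(...).raw()
over shared-memory tiles; with n = 64 the accumulator has 64 / instNBase = 8 column
blocks. All names are illustrative:

    float acc[8][2][2]; // n = 64 -> 8 blocks of 8 columns, 2x2 floats per thread each
    gmma::fence(); // order prior register/shared-memory writes before wgmma reads them
    // accHasVal = false: overwrite acc instead of accumulating into it.
    gmma::mma_async_shmA<__nv_fp8_e4m3, 64, false, false>(acc, descA, descB, false);
    gmma::commit_group();  // close the current batch of async MMAs
    gmma::wait_group<0>(); // block until no committed group is still in flight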
diff --git a/csrc/xqa/gmma_impl.cuh b/csrc/xqa/gmma_impl.cuh
new file mode 100644
index 0000000000..96bc721ca3
--- /dev/null
+++ b/csrc/xqa/gmma_impl.cuh
@@ -0,0 +1,4586 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#pragma once
+#include "cuda_hint.cuh"
+#include "mha_stdheaders.cuh"
+#ifndef __CUDACC__
+#include <cuda_runtime.h>
+#endif
+#include <cuda_fp16.h>
+#include <cuda_fp8.h>
+
+namespace gmma
+{
+// cog template. Do code generation with: pip install cogapp; cog -r $filename
+
+// clang-format off
+/*[[[cog
+import cog
+reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)])
+acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2)
+acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)])
+ptx_eol = "\\n"
+n_list = [8, 16, 24, 32, 64, 128, 256]
+for n in n_list:
+    cog.outl(f'''
+template<>
+__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
+MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "%{n//2},{ptx_eol}" //a-desc
+        "%{n//2+1},{ptx_eol}" //b-desc
+        "%{n//2+2}, 1, 1;{ptx_eol}"
+        : {acc_registers(n)}
+        : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "%{n//2},{ptx_eol}" //a-desc
+        "%{n//2+1},{ptx_eol}" //b-desc
+        "%{n//2+2}, 1, 1;{ptx_eol}"
+        : {acc_registers(n)}
+        : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }}
+}}
+
+template<>
+__device__ inline void mma_async_regA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], uint32_t
+const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+        "%{n//2+4},{ptx_eol}" //b-desc
+        "%{n//2+5}, 1, 1;{ptx_eol}"
+        : {acc_registers(n)}
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+        "%{n//2+4},{ptx_eol}" //b-desc
+        "%{n//2+5}, 1, 1;{ptx_eol}"
+        : {acc_registers(n)}
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }}
+}}
+''')
+
+for n in n_list:
+    for transA in [0, 1]:
+        for transB in [0, 1]:
+            for t,s in [('half', 'f16'), ('__nv_bfloat16', 'bf16')]:
+                cog.outl(f'''
+template<>
+__device__ inline void mma_async_shmA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA,
+MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal) {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "%{n//2},{ptx_eol}" //a-desc
+        "%{n//2+1},{ptx_eol}" //b-desc
+        "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}"
+        : {acc_registers(n)}
+        : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "%{n//2},{ptx_eol}" //a-desc
+        "%{n//2+1},{ptx_eol}" //b-desc
+        "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}"
+        : {acc_registers(n)}
+        : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }}
+}}
+''')
+                if transA == 0:
+                    cog.outl(f'''
+template<>
+__device__ inline void mma_async_regA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], uint32_t
+const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{{
+    if (accHasVal)
+    {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+        "%{n//2+4},{ptx_eol}" //b-desc
+        "%{n//2+5}, 1, 1, {transB};{ptx_eol}"
+        : {acc_registers(n)}
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }}
+    else {{
+        asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}"
+        "{acc_placeholder(n)},{ptx_eol}" // d
+        "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a
+        "%{n//2+4},{ptx_eol}" //b-desc
+        "%{n//2+5}, 1, 1, {transB};{ptx_eol}"
+        : {acc_registers(n)}
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }}
+}}
+''')
+]]]*/
+// clang-format on
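Everything from here to the end of gmma_impl.cuh is generated from the cog template
above, so fixes belong in the template followed by a re-run of cog, not in the
expansions. Once wait_group<0>() has returned, each thread of the warpgroup owns a
2x2 float patch per 8-column block, so post-processing walks the acc[n/8][2][2]
layout; a sketch for n = 64, with a hypothetical scale factor:

    #pragma unroll
    for (uint32_t i = 0; i < 8; i++)         // 8-column blocks (n / instNBase)
        for (uint32_t j = 0; j < 2; j++)     // row pair owned by this thread
            for (uint32_t k = 0; k < 2; k++) // two adjacent columns
                acc[i][j][k] *= scale;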
+
+template <>
+__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(
+    float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3},\n" // d
+            "%4,\n" // a-desc
+            "%5,\n" // b-desc
+            "%6, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1])
+            : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3},\n" // d
+            "%4,\n" // a-desc
+            "%5,\n" // b-desc
+            "%6, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1])
+            : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>(
+    float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3},\n" // d
+            "{%4, %5, %6, %7},\n" // a
+            "%8,\n" // b-desc
+            "%9, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1])
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+            "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3},\n" // d
+            "{%4, %5, %6, %7},\n" // a
+            "%8,\n" // b-desc
+            "%9, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1])
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+            "l"(reinterpret_cast<uint64_t>(descB)), "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>(
+    float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
+            "%8,\n" // a-desc
+            "%9,\n" // b-desc
+            "%10, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
+            : "l"(reinterpret_cast<uint64_t>(descA)), "l"(reinterpret_cast<uint64_t>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
+            "%8,\n" // a-desc
+            "%9,\n" // b-desc
+            "%10, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]),
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, 
bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, 
%10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( 
+ "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + 
"+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + 
"+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), 
"+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), 
"l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : 
"l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), 
"+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 0>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float 
(&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : 
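+            // every accumulator register is a read-write ("+f") operand, so the
+            // instruction both consumes and produces the running sum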
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, 
%6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" 
// b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
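+            // (n = 24: twelve f32 accumulators, i.e. three 8-column slices of 2x2)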
"+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), 
"+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + 
"%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 
32, 0, 0>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + 
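+            // register-A form: the four A fragments were bound above as %16-%19
+            // ("r" operands) instead of a shared-memory descriptor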
"+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), 
"+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + 
"+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), 
"+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + 
"l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm 
+
+template <>
+__device__ inline void mma_async_shmA<half, 64, 1, 1>(
+    float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
+            "%32,\n" // a-desc
+            "%33,\n" // b-desc
+            "%34, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
+            "%32,\n" // a-desc
+            "%33,\n" // b-desc
+            "%34, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }
+}
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + 
"+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + 
"+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, 
+
+template <>
+__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>(
+    float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "{%64, %65, %66, %67},\n" // a
+            "%68,\n" // b-desc
+            "%69, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+            "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "{%64, %65, %66, %67},\n" // a
+            "%68,\n" // b-desc
+            "%69, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+            "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_shmA<half, 128, 0, 1>(
+    float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "%64,\n" // a-desc
+            "%65,\n" // b-desc
+            "%66, 1, 1, 0, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "%64,\n" // a-desc
+            "%65,\n" // b-desc
+            "%66, 1, 1, 0, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }
+}
"+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), 
"+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, 
%22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), 
"+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), 
"+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), 
"+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, 
+
+template <>
+__device__ inline void mma_async_shmA<half, 128, 1, 1>(
+    float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "%64,\n" // a-desc
+            "%65,\n" // b-desc
+            "%66, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d
+            "%64,\n" // a-desc
+            "%65,\n" // b-desc
+            "%66, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+    }
+}
(&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), 
"+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + 
"+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), 
"+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), 
"+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), 
"+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), 
"+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), 
"+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), 
"+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), 
"+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), 
"+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), 
"+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), 
"+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, 
" + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, 
%90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{ + if (accHasVal) + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, 
%38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } + else + { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, 
%12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + 
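+// Note on the operand lists above and below, summarized from the PTX ISA description of
+// wgmma.mma_async (a sketch of the mapping, not authoritative): the "{%0 ... %127}" block is
+// the per-thread fragment of the f32 accumulator D for one m64n256k16 tile, bound to
+// acc[32][2][2] in declaration order (the leading index steps gemm-N in units of instNBase = 8
+// columns). The immediates after the descriptors are scale-d, imm-scale-a, imm-scale-b, then
+// imm-trans-a and imm-trans-b for the shared-memory-A form; the register-A form has no
+// imm-trans-a, which is why its flag list is one entry shorter. accHasVal drives scale-d via
+// "n"(true) vs "n"(false): with scale-d = 0 the accumulator input is ignored and D = A*B.
+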
+template <>
+__device__ inline void mma_async_shmA<half, 256, 1, 0>(
+    float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>(
+    float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 0;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_shmA<half, 256, 1, 1>(
+    float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(false));
+    }
+}
+
+template <>
+__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>(
+    float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
+{
+    if (accHasVal)
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(true));
+    }
+    else
+    {
+        asm volatile(
+            "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n"
+            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
+            "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, "
+            "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, "
+            "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, "
+            "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, "
+            "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, "
+            "%123, %124, %125, %126, %127},\n" // d
+            "%128,\n" // a-desc
+            "%129,\n" // b-desc
+            "%130, 1, 1, 1, 1;\n"
+            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
+            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
+            "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]),
+            "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+            "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]),
+            "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]),
+            "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]),
+            "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+            "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]),
+            "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]),
+            "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]),
+            "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]),
+            "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]),
+            "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]),
+            "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]),
+            "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]),
+            "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]),
+            "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]),
+            "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]),
+            "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]),
+            "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]),
+            "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1])
+            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)),
+            "n"(false));
+    }
+}
+
+//[[[end]]]
+} // namespace gmma
diff --git a/csrc/xqa/mha.cu b/csrc/xqa/mha.cu
index c896017780..a951e27dca 100644
--- a/csrc/xqa/mha.cu
+++ b/csrc/xqa/mha.cu
@@ -476,7 +476,7 @@ __device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskTy
                 col + actualQSeqLen < nbValidCols
                 ? true
                 : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart));
-            acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : -INFINITY;
+            acc(m, n)(i, j) = maskFlag && col < nbValidCols ? acc(m, n)(i, j) : safeInitRowMax;
           }
         }
       }
@@ -2709,11 +2709,7 @@ void launchMHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32
 #if SPEC_DEC
       mask,
 #endif
-      attentionSinks, cacheList,
-#if BEAM_WIDTH > 1
-      beamSearchParams,
-#endif
-      batchSize, kvCacheScale, semaphores, scratch);
+      attentionSinks, cacheList, batchSize, kvCacheScale, semaphores, scratch);
 #else
   KVCacheList<false> const cacheList{kvCacheData, seqLen, maxSeqLen};
 #ifndef NDEBUG
diff --git a/csrc/xqa/mha.h b/csrc/xqa/mha.h
index 77d8a2fd2f..171524f0b1 100644
--- a/csrc/xqa/mha.h
+++ b/csrc/xqa/mha.h
@@ -186,6 +186,20 @@ void launchHopperF8MHA(
 #endif
     uint32_t* semaphores, void* scratch, cudaStream_t stream);
 
+void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads,
+                                 uint32_t slidingWinSize, float qScale, OutputHead* output,
+#if LOW_PREC_OUTPUT
+                                 float const* rcpOutScale,
+#endif
+                                 InputHead const* q, float const* attentionSinks,
+                                 GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList,
+                                 uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize,
+                                 float const* __restrict__ kvCacheScale,
+#if SPEC_DEC
+                                 uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask,
+#endif
+                                 uint32_t* semaphores, void* scratch, cudaStream_t stream);
+
 void launchMLA(
     cudaDeviceProp const& prop,
     uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed
diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu
new file mode 100644
index 0000000000..7fa5234874
--- /dev/null
+++ b/csrc/xqa/mha_sm90.cu
@@ -0,0 +1,3504 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
+ *
+ * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
+ * property and proprietary rights in and to this material, related
+ * documentation and any modifications thereto. Any use, reproduction,
+ * disclosure or distribution of this material and related documentation
+ * without an express license agreement from NVIDIA CORPORATION or
+ * its affiliates is strictly prohibited.
+ */
+
+#include "cuda_hint.cuh"
+#include "defines.h"
+#if !(IS_MLA)
+#include "barriers.cuh"
+#include "utils.cuh"
+#include "utils.h"
+
+#if SPEC_DEC
+#define Q_HEADS_PER_CTA 64
+#include "specDec.h"
+#endif
+
+#ifndef GENERATE_CUBIN
+#include "hostUtils.h"
+#include "tensorMap.h"
+#include <cuda.h>
+#endif
+#include "gmma.cuh"
+#include "mha.h"
+#include "mhaUtils.cuh"
+#include "mha_stdheaders.cuh"
+#include "tma.h"
+
+#define DBG_PRINT 0
+
+#ifdef SPEC_Q_SEQ_LEN
+static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN is only supported for SPEC_DEC");
+constexpr uint32_t specDecQLen = SPEC_Q_SEQ_LEN;
+static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is too large");
+#define SWAP_AB 1
+#else
+#define SWAP_AB (!SPEC_DEC)
+#endif
+
+#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT)
+
+inline constexpr bool swapAB = SWAP_AB;
+
+#pragma region Config
+
+static_assert(
+    (inputElemSize == cacheElemSize && mha::is_same_v<InputElem, CacheElem>) || inputElemSize > cacheElemSize);
+using MathElem
+    = mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v<CacheElem, int8_t>), InputElem, CacheElem>;
+
+constexpr uint32_t gmmaWarpsPerGrp = 4;
+constexpr uint32_t gmmaWarpGrpSize = warp_size * gmmaWarpsPerGrp;
+constexpr uint32_t gemm0NbGmmaGrps = 1;
+constexpr uint32_t gemm0NbThrds = gmmaWarpGrpSize * gemm0NbGmmaGrps;
+constexpr uint32_t gemm0NbWarps = gmmaWarpsPerGrp * gemm0NbGmmaGrps;
+#if SPEC_DEC && !SWAP_AB
+inline constexpr uint32_t ctaNbQHeads = Q_HEADS_PER_CTA;
+inline constexpr uint32_t inputTokensPerCta = ctaNbQHeads / headGrpSize;
+constexpr uint32_t ctaNbValidQHeads = ctaNbQHeads;
+#elif SPEC_DEC && SWAP_AB
+inline constexpr uint32_t inputTokensPerCta = specDecQLen;
+inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * inputTokensPerCta;
+inline constexpr uint32_t ctaNbQHeads = []()
+{
+    static_assert(ctaNbValidQHeads <= 32, "ctaNbValidQHeads cannot exceed 32");
+    if constexpr (ctaNbValidQHeads <= 8)
+    {
+        return 8;
+    }
+    if constexpr (ctaNbValidQHeads <= 16)
+    {
+        return 16;
+    }
+    return 32;
+}();
+#else
+inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * beamWidth;
+inline constexpr uint32_t ctaNbQHeads = roundUp(ctaNbValidQHeads, swapAB ? 8U : 64U);
+inline constexpr uint32_t inputTokensPerCta = 1;
+#endif
+constexpr uint32_t gemm0WarpGrpTileNbTokens = 64;
+inline constexpr uint32_t gemm0CtaTileNbTokens = gemm0WarpGrpTileNbTokens * gemm0NbGmmaGrps;
+constexpr uint32_t gemm1NbGmmaGrps = 1;
+constexpr uint32_t gemm1NbThrds = gmmaWarpGrpSize * gemm1NbGmmaGrps;
+constexpr uint32_t gemm1NbWarps = gmmaWarpsPerGrp * gemm1NbGmmaGrps;
+constexpr uint32_t gemm1CtaTileNbTokens = gemm0CtaTileNbTokens;
+constexpr uint32_t mathHeadBytes = sizeof(Vec<MathElem, headElems>);
+constexpr uint32_t nbIOWarps = 4;
+constexpr uint32_t nbIOThrds = warp_size * nbIOWarps;
+constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning
+constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2;
+constexpr uint32_t nbWarps = gemm0NbWarps + gemm1NbWarps + nbIOWarps;
+
+constexpr uint32_t cacheHeadPartBytes = mha::min(paddedCacheHeadBytes, 128U);
+constexpr uint32_t cacheHeadNbParts
+    = exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future
+constexpr uint32_t cacheHeadPartElems = exactDiv(headElems, cacheHeadNbParts);
+constexpr uint32_t swizzleBytes = cacheHeadPartBytes;
+static_assert(swizzleBytes == 128 || swizzleBytes == 64 || swizzleBytes == 32);
+
+constexpr bool needInputCvt = inputElemSize > cacheElemSize&& mha::is_same_v<MathElem, CacheElem>;
+constexpr bool needCacheCvt = inputElemSize > cacheElemSize&& mha::is_same_v<MathElem, InputElem>;
+static_assert(needInputCvt || needCacheCvt || mha::is_same_v<InputElem, CacheElem>);
+
+using ShmQWiseVec = Vec<float, ctaNbQHeads>;
+
+constexpr uint32_t qPartBytes = mha::min(mathHeadBytes, 128U);
+constexpr uint32_t nbQParts = exactDiv(mathHeadBytes, qPartBytes);
+constexpr uint32_t grainsPerQPart = exactDiv(qPartBytes, grainBytes);
+
+constexpr uint32_t xPartBytes = mha::min(cacheElemSize * gemm0CtaTileNbTokens, 128U);
+constexpr uint32_t nbXParts = exactDiv(cacheElemSize * gemm0CtaTileNbTokens, xPartBytes);
+constexpr uint32_t grainsPerXPart = exactDiv(xPartBytes, grainBytes);
+constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize);
+
+constexpr uint32_t grainsPerIOHead = exactDiv(ioHeadBytes, grainBytes);
+constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes);
+
+#if USE_BEAM_SEARCH
+constexpr uint32_t beamSearchGemm0CtaTileNbTokens = exactDiv(gemm0CtaTileNbTokens, beamWidth);
+#endif
+
+using PaddedOutHead = PaddedInputHead;
+
+#pragma endregion Config
+
+struct alignas(128) SharedMem
+{
+    using KBuffer = Array2D<LdGrain, gemm0CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>;
+    static constexpr uint32_t nbKBuf = 2;
+    KBuffer k[nbKBuf]; // as is loaded from global mem.
+    using XBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerXPart>, nbXParts>;
+    static constexpr uint32_t nbXBuf
+        = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens));
+    using VBuffer = Vec<Array2D<LdGrain, gemm1CtaTileNbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>,
+        cacheHeadNbParts>;
+#if !SWAP_AB
+    using VTBuffer = Array2D<LdGrain, headElems, exactDiv(gemm1CtaTileNbTokens, cacheElemsPerGrain)>;
+#endif
+    static constexpr uint32_t nbVBuf = 2;
+#if CACHE_ELEM_ENUM == 0
+    using OutSwizzleBuf = Array2D<LdGrain, ctaNbQHeads, grainsPerPaddedInputHead>;
+#elif CACHE_ELEM_ENUM == 2
+    using OutSwizzleBuf = Array2D<Vec<Vec<InputElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
+#endif
+    static_assert(nbXBuf == nbVBuf);
+
+    union ReusedXVOutSwizzleBuf
+    {
+        struct XV
+        {
+            XBuffer x;
+            VBuffer v;
+#if !SWAP_AB
+            VTBuffer vt;
+#endif
+            // @fixme: also put xColMax and xColSum here
+        } xv;
+
+        OutSwizzleBuf outSwizzle;
+    } reusedXVOutSwizzleBuf[nbXBuf];
+
+    static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV),
+        "need to use split output to avoid excessive shared memory usage");
+
+    __device__ inline XBuffer& xBuf(uint32_t i)
+    {
+        return reusedXVOutSwizzleBuf[i].xv.x;
+    }
+
+    __device__ inline VBuffer& vBuf(uint32_t i)
+    {
+        return reusedXVOutSwizzleBuf[i].xv.v;
+    }
+#if !SWAP_AB
+    __device__ inline VTBuffer& vtBuf(uint32_t i)
+    {
+        return reusedXVOutSwizzleBuf[i].xv.vt;
+    }
+#endif
+    __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i)
+    {
+        return reusedXVOutSwizzleBuf[i].outSwizzle;
+    }
+
+    using QBuffer = Vec<Array2D<LdGrain, ctaNbQHeads, grainsPerQPart>, nbQParts>;
+    QBuffer q; // For gmma math. Conversion done if needed.
+
+    // @fixme: move these into reusedXVOutSwizzleBuf
+#if SWAP_AB
+    ShmQWiseVec xColMax[nbXBuf];
+    ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps];
+#else
+    ShmQWiseVec xRowMax[nbXBuf];
+    ShmQWiseVec xRowSum[nbXBuf];
+#endif
+
+    ShmQWiseVec gemm0CurrentSeqMax;
+    // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register storage will be 8x
+    // duplicate for swapAB and 4x duplicate for non-swapAB.
+    ShmQWiseVec gemm1AccColMax;
+    ShmQWiseVec gemm1AccColSum;
+
+#if USE_PAGED_KV_CACHE
+    static constexpr uint32_t nbPagesPerTile
+        = gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1;
+    Vec<KVCachePageIndex, nbPagesPerTile> pages[2]; // one for K and one for V
+#endif
+
+    // mem barriers
+
+    CtaBarrierPair qBar;
+    CtaBarrierPair kBar[nbKBuf];
+    CtaBarrierPair vBar[nbVBuf];
+#if !SWAP_AB
+    CtaBarrierPair vtBar[nbVBuf];
+#endif
+    CtaBarrierPair xBar[nbXBuf];
+
+    // used internally in the gemm0 warp group
+    // @fixme: use separate arrive and wait for all usage
+    CtaBarrier gemm0WarpGrpBar;
+
+    // used internally in the gemm1 warp group
+    // @fixme: use separate arrive and wait for all usage
+    CtaBarrier gemm1WarpGrpBar;
+
+    bool isLastCta;
+};
+
+CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem);
+#ifdef __CUDA_ARCH__
+static_assert(smemSize < kMAX_SMEM_SIZE);
+#endif
+
+constexpr uint32_t nbQLdWarps = needInputCvt ? nbIOWarps - 2 : 1;
+constexpr uint32_t nbQLdThrds = warp_size * nbQLdWarps;
+
+#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2
+template <uint32_t nbThrds>
+struct F16QToF8Converter
+{
+    static_assert(inputElemSize == 2);
+    using F16Vec = Vec<InputElem, exactDiv(grainBytes, inputElemSize)>;
+#if CACHE_ELEM_ENUM == 0
+    using ShmVec = F16Vec;
+#elif CACHE_ELEM_ENUM == 2
+    using F8Vec = Vec<CacheElem, F16Vec::size>;
+    using ShmVec = F8Vec;
+#endif
+
+    static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes);
+    static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize;
+#if !(SPEC_DEC)
+    static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * beamWidth;
+#else
+    static_assert(beamWidth == 1);
+    static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta;
+#endif
+    static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds);
+
+    using RegData = Vec<F16Vec, nbIters>;
+
+    static __device__ RegData load(uint32_t tid, TinyPtr<IOHead const> const& src,
+        uint32_t const nbKHeads /*for beam search and spec dec*/, uint32_t nbTokens);
+    static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data);
+};
+#endif // CACHE_ELEM_ENUM
+
+struct KVTilePartLoader
+{
+    static constexpr uint32_t nbParts = cacheHeadNbParts;
+    static constexpr uint32_t partElems = exactDiv(headElems, nbParts);
+
+#if USE_PAGED_KV_CACHE
+    static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || tokensPerPage % gemm0CtaTileNbTokens == 0);
+    static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile;
+#endif
+
+    uint32_t const nbKHeads;
+    KVCacheList<usePagedKVCache> const& cacheList;
+    uint32_t const idxReq;
+    uint32_t const idxHeadGrp;
+
+    CUtensorMap const& tensorMap;
+#if USE_PAGED_KV_CACHE
+    uint32_t const nbPages; // for bound check
+    Vec<KVCachePageIndex, nbPagesPerTile>& pages;
+    uint32_t idxTileRef; // idxTile used to load the pages
+#endif
+    uint32_t const baseOffset;
+
+    __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList<usePagedKVCache> const& cacheList,
+        uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap
+#if USE_PAGED_KV_CACHE
+        ,
+        uint32_t nbPages, Vec<KVCachePageIndex, nbPagesPerTile>& pageBuf
+#endif
+    );
+    // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache
+    template <uint32_t nbTokens>
+    __device__ void loadData(
+        Array2D<LdGrain, nbTokens, exactDiv(cacheHeadPartBytes, grainBytes)>& dst, uint32_t idxTile,
+        uint32_t idxPart, CtaBarrier& bar);
+
+    __device__ void loadPages(uint32_t idxTile);
+    __device__ GMemKVCacheHead& getHead(uint32_t pos);
+};
+
+using GmmaAccCoreMat = Array2D<float, 2, 2>;
+template <uint32_t m, uint32_t n>
+using GmmaAcc = Array2D<GmmaAccCoreMat, exactDiv(m, gmma::instM), exactDiv(n, gmma::instNBase)>;
+
+inline constexpr uint32_t gemm0M = (swapAB ? gemm0CtaTileNbTokens : ctaNbQHeads);
+inline constexpr uint32_t gemm0N = (swapAB ? ctaNbQHeads : gemm0CtaTileNbTokens);
+
+using Gemm0Acc = GmmaAcc<gemm0M, gemm0N>;
+
+#if SWAP_AB
+using RegColWiseVec = Vec<Vec<float, GmmaAccCoreMat::cols>, Gemm0Acc::cols>;
+using UniformNeedRescaleMask = Vec;
+using RegSeqWiseVec = RegColWiseVec;
+#else
+using RegRowWiseVec = Vec<Vec<float, GmmaAccCoreMat::rows>, Gemm0Acc::rows>;
+using UniformNeedRescaleMask
+    = Vec;
+using RegSeqWiseVec = RegRowWiseVec;
+#endif
+
+#if SPEC_DEC
+
+__device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq)
+{
+    return (params.qCuSeqLens == nullptr) ? params.qSeqLen : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq];
+}
+
+__device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq)
+{
+    return (params.qCuSeqLens == nullptr) ? params.qSeqLen * idxReq : params.qCuSeqLens[idxReq];
+}
+
+struct SpecDec
+{
+    static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens;
+    static inline constexpr uint32_t ctaMaxQSeqLen = (ctaNbQHeads / headGrpSize);
+    using TileMaskRow = Vec<uint32_t, exactDiv(tileSize, 32)>;
+
+    __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t idxInputSubSeq, uint32_t seqLen)
+        : params(params)
+        , idxInputSubSeq(idxInputSubSeq)
+        , seqLen(seqLen)
+    {
+        inputSeqLen = getInputSeqLen(params, idxReq);
+        baseOffset = divUp(params.qSeqLen, 32U) * (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * idxInputSubSeq);
+    }
+
+    __device__ inline uint32_t unmaskedSeqLen() const
+    {
+        return seqLen - inputSeqLen;
+    }
+
+    __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const
+    {
+        return tileSize * (idxTile + 1) > unmaskedSeqLen()
+            && ctaMaxQSeqLen * idxInputSubSeq + idxQTokInCta < inputSeqLen && params.mask != nullptr;
+    }
+
+    __device__ inline int32_t maskColBeg(uint32_t idxTile) const
+    {
+        int32_t const convergedSeqLen = int32_t(unmaskedSeqLen());
+        return static_cast<int32_t>(exactDiv(tileSize, 32) * idxTile)
+            - static_cast<int32_t>(divUp(convergedSeqLen, 32));
+    }
+
+    __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const
+    {
+        assert(needMask(idxTile, idxQTokInCta));
+        constexpr uint32_t nbOrigElems = TileMaskRow::size + 1;
+        Vec<uint32_t, nbOrigElems> orig;
+
+        int32_t const cols = divUp(params.qSeqLen, 32);
+        uint32_t const rowOffset = baseOffset + idxQTokInCta * cols;
+        int32_t const colBeg = maskColBeg(idxTile);
+#pragma unroll
+        for (int32_t i = 0; i < int32_t(nbOrigElems); i++)
+        {
+            int32_t const idx = colBeg + i;
+            orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? ~0U : 0U);
+        }
+        TileMaskRow mask;
+        uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32;
+#pragma unroll
+        for (uint32_t i = 0; i < TileMaskRow::size; i++)
+        {
+            asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" : "=r"(mask[i]) : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift));
+        }
+        return mask;
+    }
+
+    SpecDecParams const& params;
+    uint32_t const idxInputSubSeq;
+    uint32_t const seqLen;
+    uint32_t inputSeqLen;
+    uint32_t baseOffset;
+};
+
+__device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec,
+#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE
+    int32_t tok0WinBeg,
+#endif
+    uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank);
+#endif
+
+#if SWAP_AB
+__device__ RegColWiseVec computeWarpGrpColMax_sync(
+    CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
+__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd);
+__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax);
+__device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src);
+__device__ void storeGemm0AccToShm(
+    uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc);
+__device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec);
+__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound);
+#else
+__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src);
+__device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd);
+__device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& colMax);
+__device__ RegRowWiseVec computeWarpRowSum(Gemm0Acc& src);
+__device__ void storeGemm0AccToShm(
+    uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc);
+__device__ RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec);
+__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec);
+#endif
+
+using RegMatAFrag = Array2D<Vec<uint32_t, 2>, 1, 2>;
+constexpr uint32_t gemm1NbGmmaInstK = exactDiv(gemm1CtaTileNbTokens, gmma::instK<MathElem>);
+
+#if SWAP_AB
+constexpr uint32_t gemm1NbGmmaInstM = exactDiv(headElems, gmma::instM);
+__device__ Vec<RegMatAFrag, gemm1NbGmmaInstM> loadVTileTransposed(
+    uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK);
+using Gemm1Acc = GmmaAcc<headElems, ctaNbQHeads>;
+__device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax,
+    ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum,
+    CtaBarrier& gemm1WarpGrpBar);
+template <typename DstHead>
+__device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
+    SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
+    ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec,
+    uint32_t nbKHeads = 0 /* only for final result in spec dec. */);
+#else
+__device__ void transposeVTile(
+    uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src);
+using Gemm1Acc = GmmaAcc<ctaNbQHeads, headElems>;
+__device__ void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax,
+    ShmQWiseVec const(&shmXRowSum), ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum);
+template <typename DstHead>
+__device__ void finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf,
+    Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accColSum,
+    uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, uint32_t ctaNbValidTokens);
+#endif
+
+inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds)
+{
+    auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds);
+    assert(val <= 32);
+    return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 16 : 32)));
16 : 32))); +} + +template +inline constexpr uint32_t ropeNbPairsPerThrd = ropeNbPairsPerThrdImpl(nbThrds); + +template +__device__ Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid); +template +__device__ mha::conditional_t, 2>, Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin); +template +__device__ void storeRotatedPairsForKV(GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid); +template +__device__ void storeRotatedPairsForQ(SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid); + +class ScratchMem +{ +public: + struct alignas(8) SumMax + { + float sum; + float max; + }; + + using ColWiseVec = Vec; + + HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) + : mScratch{static_cast(scratch)} + { + uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; + Segmenter segmenter; + constexpr uint32_t alignment = sizeof(Vec); + mRowSumMax = segmenter.template newSeg(nbChunks, alignment); + mTokens = segmenter.template newSeg>(nbChunks, alignment); + } + + HOST_DEVICE_FUNC TinyPtr rowSumMax() const + { + return makePtr(mRowSumMax); + } + + HOST_DEVICE_FUNC TinyPtr> tokens() const + { + return makePtr>(mTokens); + } + +private: + template + HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const + { + return TinyPtr{mScratch, offset}.template cast(); + } + +private: + mha::byte* mScratch; + // offsets + uint32_t mRowSumMax; + uint32_t mTokens; +}; + +struct MultiBlockSMem +{ + using ColWiseVec = ScratchMem::ColWiseVec; + static constexpr uint32_t nbBuf = useSpecDec ? 2 : 4; + static constexpr uint32_t nbIOWarps = nbBuf; + using Elem = InputElem; + using Head = Vec; + Vec, nbBuf> tokens; + Vec rowSumMax; + Vec barriers; +}; + +#ifndef NDEBUG +namespace dbg +{ +template +__device__ void printAcc( + CtaBarrier& warpGrpBar, uint32_t warpRank, Array2D const& acc) +{ + for (int m = 0; m < nbGmmaInstM; m++) + { + for (int w = 0; w < 4; w++) + { + if (warpRank == w) + { + for (int a = 0; a < 2; a++) + { + for (int b = 0; b < 8; b++) + { + for (int n = 0; n < nbGmmaInstNBase; n++) + { + for (uint32_t i = 0; i < 4; i++) + { + if (laneId() == b * 4 + i) + { + printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); + } + __syncwarp(); + } + } + if (laneId() == 0) + { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0) + { + printf("\n"); + } + __syncwarp(); + } + } + warpGrpBar.arrive_and_wait(); + } + } +} + +__device__ void printShmColWiseVec(ShmQWiseVec const& vec) +{ + for (uint32_t i = 0; i < vec.size; i++) + { + printf("%f, ", vec[i]); + } + printf("\n"); +} + +template +__device__ void printArray2D(Array2D const& src) +{ + for (uint32_t i = 0; i < rows; i++) + { + for (uint32_t j = 0; j < cols; j++) + { + T const val = src.template at(i, j); + for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) + { + printf("%f, ", float(reinterpret_cast(&val)[k])); + } + } + printf("\n"); + } +} +} // namespace dbg +#endif + +CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = XQAKernelType::kHOPPER_WARP_SPECIALIZED; + +CUBIN_EXPORT __global__ +#ifdef NDEBUG +#if !OPTIMIZE_FOR_LATENCY + __launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 
+#else
+    __launch_bounds__(128 * 3)
+#endif
+#else
+    __launch_bounds__(128 * 3, 1)
+#endif
+    void kernel_mha(uint32_t const nbKHeads,
+#if SLIDING_WINDOW
+        uint32_t const slidingWinSize,
+#endif
+        float const qScale,
+        OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads]
+#if LOW_PREC_OUTPUT
+        float const* const rcpOutScale,
+#endif
+#if USE_INPUT_KV
+        IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads],
+#if ROPE_STYLE != 0
+        Vec<float, validElemsPerHead> const* __restrict__ const ropeCosSin, // [maxNbPosEmb]
+#endif
+#else
+        IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads],
+#endif
+        float const* attentionSinks, // [headGrpSize]
+        KVCacheList<usePagedKVCache> const cacheList,
+#if USE_BEAM_SEARCH
+        BeamSearchParams const beamSearchParams,
+#endif
+        uint32_t const batchSize,
+        float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used
+                                                      // only for int8/fp8 KV cache.
+#if PAGED_KV_CACHE_LAYOUT == 1
+        __grid_constant__ CUtensorMap const tensorMapVLLMK, __grid_constant__ CUtensorMap const tensorMapVLLMV,
+#else
+        __grid_constant__ CUtensorMap const tensorMap,
+#endif
+#if SPEC_DEC
+        SpecDecParams const specDecParams,
+#endif
+        uint32_t* __restrict__ const semaphores
+        = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)]
+        void* __restrict__ const scratch = nullptr)
+{
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) \
+    && (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1
+    uint32_t const idxReq = blockIdx.z / nbKHeads;
+#if SPEC_DEC
+    uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq);
+    uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1);
+    uint32_t const nbInputSeqSplit = gridDim.x;
+    assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta));
+#else
+    uint32_t const reqInputTokBeg = idxReq;
+    uint32_t const reqInputTokEnd = idxReq + 1;
+    constexpr uint32_t nbInputSeqSplit = 1;
+    assert(gridDim.x == nbInputSeqSplit);
+#endif
+    uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request
+    assert(gridDim.z == nbKHeads * batchSize);
+    uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq);
+    static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens);
+    constexpr uint32_t tileSize = gemm0CtaTileNbTokens;
+#if SPEC_DEC
+    uint32_t const idxInputSubSeq = blockIdx.x;
+    uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg;
+    uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq;
+    uint32_t const ctaNbValidTokens = mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset);
+
+    if (ctaTokOffset >= inputSeqLen)
+    {
+        return;
+    }
+#else
+    uint32_t const idxInputSubSeq = 0;
+    uint32_t const inputSeqLen = 1;
+    uint32_t const ctaTokOffset = 0;
+    uint32_t const ctaNbValidTokens = 1;
+#endif
+#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE
+    // get the actual start position depending on ctaTokOffset, which is the draft token position per CTA
+    uint32_t const tok0SeqLen = cacheSeqLen - inputSeqLen + 1 + ctaTokOffset;
+    int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize);
+    uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg);
+#elif SLIDING_WINDOW
+    bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize);
+    // if SPEC_DEC && SLIDING_WINDOW && IS_SPEC_DEC_TREE, it should not do sliding
+    assert(!SPEC_DEC || !rtIsReallySliding);
+    uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0;
+#else
+    constexpr bool rtIsReallySliding = false;
+    constexpr uint32_t nbTotalSkipTokens = 0;
+#endif
+    uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize;
+    uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize;
+
+#if USE_BEAM_SEARCH
+    uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq);
+    uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0;
+    uint32_t const nbDivergentKTiles
+        = useKVCache ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) : 0;
+    uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles;
+    uint32_t const nbVTiles = nbKTiles;
+#else
+    uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0;
+    // uint32_t const nbKTiles = nbTiles;
+    // uint32_t const nbVTiles = nbTiles;
+    uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles;
+#endif
+    uint32_t const maxNbSubSeq = gridDim.y;
+    uint32_t const idxSubSeq = blockIdx.y;
+    bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles);
+    uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq;
+    uint32_t const idxVTileInit = idxKTileInit;
+    uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1;
+    static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2);
+    assert(isMultiBlockMode == (nbSubSeq > 1));
+    if (idxSubSeq >= nbSubSeq)
+    {
+        return;
+    }
+    uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset;
+    auto const warpIdx = getWarpIdx(uint3{128, 1, 3});
+    auto const wid = warpIdx.z * 4 + warpIdx.x;
+#if PAGED_KV_CACHE_LAYOUT == 1
+    if (wid == 0 && warpElectSync())
+    {
+        tma::prefetchTensorMap(tensorMapVLLMK);
+        tma::prefetchTensorMap(tensorMapVLLMV);
+    }
+#else
+    if (wid == 0 && warpElectSync())
+    {
+        tma::prefetchTensorMap(tensorMap);
+    }
+#endif
+    extern __shared__ char smemByteBuf[];
+    assert(dynamicSmemSize() >= sizeof(SharedMem));
+    SharedMem& smem = *reinterpret_cast<SharedMem*>(&smemByteBuf[0]);
+
+    constexpr uint32_t nbBuffers = 2;
+    static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf);
+    if (wid < nbBuffers)
+    {
+        if (warpElectSync())
+        {
+            smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size);
+            smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size);
+#if !SWAP_AB
+            smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2);
+#endif
+            smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds);
+        }
+    }
+    else if (wid == nbBuffers)
+    {
+        if (warpElectSync())
+        {
+            smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds);
+            init(&smem.gemm0WarpGrpBar, gemm0NbThrds);
+            init(&smem.gemm1WarpGrpBar, gemm1NbThrds);
+        }
+    }
+    __syncthreads();
+
+#if USE_PAGED_KV_CACHE
+    uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage);
+#endif
+
+    constexpr bool isKVCacheQuantized = (cacheElemSize < 2);
+    assert(idxKTileInit < nbTiles);
+    uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq);
+    assert(nbIters >= 1);
+
+    constexpr uint32_t gmmaInstK = gmma::instK<MathElem>;
+    constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes);
+
+    if (warpIdx.z == 0)
+    {
+#if SPEC_DEC
+        SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen};
+#endif
+
+        // QK gemm
+        constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM);
+        using Acc = GmmaAcc<gemm0M, gemm0N>;
+
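+        // Producer/consumer sketch (illustrative): each CtaBarrierPair is used
+        // as a two-phase handshake. The math warp group first signals
+        // "consumed" so the IO warps may (re)fill a buffer, then blocks on
+        // "produced" before reading it:
+        //   unused(bar.consumed.arrive());   // buffer is free; IO warps may fill it
+        //   bar.produced.arrive_and_wait();  // wait until the data has landed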
unused(smem.qBar.consumed.arrive()); + for (auto& b : smem.kBar) + { + unused(b.consumed.arrive()); + } + + float const qkScale = qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) + * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. + uint32_t const warpRank = warpIdx.x; + + // init once per sequence. It also works as global colMax across iterations. + if (threadIdx.x < ctaNbQHeads) + { + smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; + } + smem.gemm0WarpGrpBar.arrive_and_wait(); + + smem.qBar.produced.arrive_and_wait(); +#if DBG_PRINT + if (threadIdx.x == 0) + { + printf("q:\n"); + dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); + } +#endif + + auto const matDescQBase = gmma::makeMatDesc( + nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) + .raw(); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) + { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + assert(idxKTile < nbTiles); + Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. + gmma::fence(); + static_assert(cacheHeadNbParts == nbQParts); +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) + { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBuf = smem.k[idxKBuf]; + auto& kBar = smem.kBar[idxKBuf]; + static_assert(SharedMem::KBuffer::rows % 8 == 0); + auto const matDescKBase = gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw(); + assert(matDescKBase + == gmma::makeMatDesc( + nullptr, 0, SharedMem::KBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw()); + arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); + // if (threadIdx.x == 0) { + // printf("************* part %u *******\n", idxPart); + // printf("q:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); + // printf("k:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); + // } + constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); +#pragma unroll + for (uint32_t k = 0; k < nbGmmaInstK; k++) + { + bool const accHasVal = (idxPart != 0 || k != 0); + auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); +#pragma unroll + for (uint32_t m = 0; m < nbGmmaInstM; m++) + { + auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); +#if SWAP_AB + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescK, matDescQ, accHasVal); +#else + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescQ, matDescK, accHasVal); +#endif + } + } + gmma::commit_group(); + //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let tile0_softmax + // wait for + // k loading of tile1 and may harm perf for short-seq cases. 
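+ // wgmma.wait_group<0> below drains every committed GMMA group, after which the K buffer
+ // can be released back to the loader warp via kBar.consumed. A minimal sketch of the
+ // double-buffered variant suggested by the @fixme above (hypothetical pseudocode, not what
+ // this kernel currently does):
+ //   computeTileInto(accCur); gmma::commit_group();
+ //   gmma::wait_group<1>();   // keep only the newest group in flight
+ //   softmax(accPrev);        // overlaps with the in-flight GMMA
+ //   gmma::wait_group<0>(); swap(accCur, accPrev);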
+ gmma::wait_group<0>(); + unused(kBar.consumed.arrive()); + } +#if !defined(NDEBUG) && DBG_PRINT + dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); +#endif + // apply qkScale + acc = acc * qkScale; + // apply mask +#if SPEC_DEC + warpGrpApplyMask(acc, specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + tok0WinBeg, +#endif + cacheSeqLen, idxKTile, warpRank); +#else + bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); + bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); + bool const isLastTile = (idxKTile + 1 == nbTiles); + bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; + if (needMaskLeading || needMaskTrailing) + { + uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; + uint32_t const validTokenEnd = (needMaskTrailing ? cacheSeqLen % tileSize : tileSize); + if (validTokenBeg > 0 || validTokenEnd < tileSize) + { +#if SWAP_AB + warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); +#else + warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); +#endif + } + } +#endif + // update colMax in shared mem and get a register copy +#if SWAP_AB + RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, colMax); +#else + RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, rowMax); +#endif + + // @fixme: may need fp32->fp8->fp32 before doing sum. +#if SWAP_AB + RegColWiseVec const warpColSum = computeWarpColSum(acc); +#else + RegRowWiseVec const rowSum = computeWarpRowSum(acc); +#endif + + // map 1 to fp8_max before conversion to fp8 + acc = acc * kE4M3_MAX; + + uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; + auto& xBar = smem.xBar[idxXBuf]; + // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. +#if SWAP_AB + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + // store colMax and warpColSum + auto const lane = laneId(); + if (lane < 4) + { + auto& xColMax = smem.xColMax[idxXBuf]; + auto& xColSum = smem.xColSum[idxXBuf][warpRank]; +#pragma unroll + for (uint32_t n = 0; n < colMax.size; n++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) + { + if (warpRank == 0) + { + xColMax[8 * n + 2 * lane + j] = colMax[n][j]; + } + xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; + } + } + } +#else + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); + storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); +#endif + + __syncwarp(); + // the release semantics of arrive does not work for async consumers like gmma. additional fence is + // needed. 
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(xBar.produced.arrive()); + } + unused(smem.qBar.consumed.arrive()); + } + else if (warpIdx.z == 1) + { + // XV GEMM + for (auto& b : smem.vBar) + { + unused(b.consumed.arrive()); + } +#if !SWAP_AB + for (auto& b : smem.vtBar) + { + unused(b.consumed.arrive()); + } +#endif + for (auto& b : smem.xBar) + { + unused(b.consumed.arrive()); + } + + if (threadIdx.x < smem.gemm1AccColMax.size) + { + auto const idx = threadIdx.x; + smem.gemm1AccColMax[idx] = safeInitRowMax; + smem.gemm1AccColSum[idx] = 0; + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + + uint32_t const warpRank = warpIdx.x; + + constexpr float xScale = 1.f / kE4M3_MAX; +#if LOW_PREC_OUTPUT + float const oScale = rcpOutScale[0]; +#else + constexpr float oScale = 1.F; +#endif + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; + + Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. + gmma::fence(); + + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) + { + uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; + auto const idxVBuf = idxIter % SharedMem::nbVBuf; + auto const idxXBuf = idxVBuf; + auto& vBar = smem.vBar[idxVBuf]; + arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); + auto const& vBuf = smem.vBuf(idxVBuf); +#if !SWAP_AB + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); + vtBar.consumed.arrive_and_wait(); + transposeVTile(warpRank, laneId(), vtBuf, vBuf); + vBar.consumed.arrive(); + vtBar.produced.arrive(); +#endif + auto& xBar = smem.xBar[idxXBuf]; + xBar.produced.arrive_and_wait(); +#if !defined(NDEBUG) && DBG_PRINT +#if SWAP_AB + if (threadIdx.x == 0) + { + printf("colMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xColMax[idxXBuf][i]); + } + printf("\n"); + printf("colSum:\n"); + for (int n = 0; n < 4; n++) + { + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xColSum[idxXBuf][n][i]); + } + printf("\n"); + } + printf("\n"); + printf("X:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + for (int j = 0; j < gemm0CtaTileNbTokens; j++) + { + auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); + auto const e = reinterpret_cast&>( + smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( + i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; + printf("%.2f, ", float(e)); + if (j % 16 == 15) + { + printf("| "); + } + } + printf("\n\n"); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#else + if (blockIdx.y == 1 && threadIdx.x == 0) + { + printf("rowMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xRowMax[idxXBuf][i]); + } + printf("\n"); + printf("rowSum:\n"); + for (int i = 0; i < ctaNbQHeads; i++) + { + printf("%f, ", smem.xRowSum[idxXBuf][i]); + } + printf("\n"); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); +#endif +#endif + +#if SWAP_AB + // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead. 
+ rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar); +#else + rescaleGemm1AccForNewRowMax_sync( + warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum); +#endif + auto& xBuf = smem.xBuf(idxXBuf); + + auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) + .raw(); +#if CACHE_ELEM_ENUM == 0 + auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) + .raw(); +#endif +#if SWAP_AB +//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed. +#pragma unroll + for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) + { +#if CACHE_ELEM_ENUM == 2 + Vec const fragA + = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); +#if !defined(NDEBUG) && DBG_PRINT + if (threadIdx.x == 0) + { + printf("fragA:\nidxInstK == %u\n", idxInstK); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + for (int m = 0; m < 2; m++) + { + for (int w = 0; w < 4; w++) + { + if (warpRank == w) + { + if (laneId() == 0) + { + printf(" warpRank = %u\n", warpRank); + } + __syncwarp(); + for (int a = 0; a < 2; a++) + { + for (int b = 0; b < 8; b++) + { + for (int c = 0; c < 2; c++) + { + for (int d = 0; d < 4; d++) + { + if (laneId() == b * 4 + d) + { + for (int e = 0; e < 4; e++) + { + auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>( + fragA[m](0, c)(a, 0)); + printf("%.2f, ", float(elem4[e])); + } + } + __syncwarp(); + } + } + if (laneId() == 0) + { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0 && a == 0) + { + printf("----------------------\n"); + } + __syncwarp(); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + } + } +#endif +#endif + BoundedVal const kOffsetInGrains{grainsPerInstK * idxInstK}; + auto const descX = addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + 0, kOffsetInGrains.template mod().get())); +#if CACHE_ELEM_ENUM == 2 + gmma::fence(); +#endif +#pragma unroll + for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) + { +#if CACHE_ELEM_ENUM == 0 + auto const descV + = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); + gmma::mma_async_shmA( + reinterpret_cast(acc(idxInstM, 0)), + descV, descX, true); +#elif CACHE_ELEM_ENUM == 2 + gmma::mma_async_regA( + reinterpret_cast(acc(idxInstM, 0)), + reinterpret_cast(fragA[idxInstM]), descX, true); +#endif + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of + // gmma. 
+ gmma::wait_group<0>(); + } +#else + auto const descVTBase = gmma::makeMatDesc( + nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VTBuffer{})) + .raw(); + vtBar.produced.arrive_and_wait(); +// if (idxIter == 1 && threadIdx.x == 0) { +// printf("vtBuf:\n"); +// dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); +// } +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) + { + BoundedVal const kOffsetInGrains{grainsPerInstK * k}; + auto const descX = addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + gmma::instM * m, kOffsetInGrains.template mod().get())); + auto const descVT = addAddr( + descVTBase, &vtBuf(0, kOffsetInGrains.template mod().get())); + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), descX, + descVT, true); + } + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma. + gmma::wait_group<0>(); +#endif + if (idxIter == nbIters - 1) + { + // gmma::wait_group should have already synchronized threads, so this may be unnecessary. + smem.gemm1WarpGrpBar.arrive_and_wait(); + assert(idxXBuf == idxVBuf); + if (isMultiBlockMode) + { + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + // save row max/sum + static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size); + if (threadIdx.x < ctaNbValidQHeads) + { + float const colMax = smem.gemm1AccColMax[threadIdx.x]; + float const colSum = smem.gemm1AccColSum[threadIdx.x]; + ScratchMem::SumMax sumMax; + sumMax.sum = colSum; + sumMax.max = colMax; + (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = sumMax; + } + // compute scratch ptr for output writing + IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, 1, ctaNbValidTokens); +#endif + } + else + { + uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) + { + attentionSinksVec + = reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } +#if SWAP_AB + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec, + nbKHeads); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); +#endif + } + } + unused(xBar.consumed.arrive()); +#if SWAP_AB + unused(vBar.consumed.arrive()); +#else + unused(vtBar.consumed.arrive()); +#endif + } + } + else + { + // IO warps + static_assert(beamWidth == 1); +#if ENABLE_PDL + preExit(); +#endif +#if ENABLE_PDL == 1 + acqBulk(); +#endif + assert(warpIdx.z == 2); + uint32_t const newTokenPos = cacheSeqLen - 1; + if (warpIdx.x < nbQLdWarps) + { + // load Q. 
Use register to load fp16 data and store fp8 to shared mem. + // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead. + using QCvt = F16QToF8Converter; + static_assert(beamWidth == 1); +#if USE_INPUT_KV + TinyPtr const qData{qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq}; + constexpr bool isNeox = (ROPE_STYLE == 1); + constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U)); + uint32_t const lane = laneId(); + uint32_t const idxThrd = warpIdx.x * warp_size + lane; + uint32_t const idxThrdGrp = (thrdsPerHead % 32 == 0 ? makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) + : idxThrd / thrdsPerHead); + constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); + uint32_t const tid = idxThrd % thrdsPerHead; + smem.qBar.consumed.arrive_and_wait(); +#if ROPE_STYLE != 0 + auto const& ropeCosSinHead + = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, tid); +#endif +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#pragma unroll + for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) + { + uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; + if (idxHead >= headGrpSize) + { + break; + } +#if ROPE_STYLE == 0 + auto const rotatedPairs = loadHead(qData[idxHead], tid); +#else + auto const pairs = loadHead(qData[idxHead], tid); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); +#endif + storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); + } +#else + TinyPtr const qData{q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; +#if ENABLE_PDL == 2 + acqBulk(); +#endif + auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); + + smem.qBar.consumed.arrive_and_wait(); + QCvt::store(threadIdx.x, smem.q, f16QData); +#endif + // the release semantics of arrive does not work for async consumers like gmma. additional fence is + // needed. 
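+ // Same generic-to-async proxy ordering requirement as for the X buffer above: the fp8 Q
+ // written to shared memory here is consumed by wgmma, so it must be fenced before the
+ // qBar.produced arrive.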
+ asm volatile("fence.proxy.async.shared::cta;\n"); + unused(smem.qBar.produced.arrive()); + } + else if (warpIdx.x == nbQLdWarps) + { // load k + KVTilePartLoader kTilePartLoader + { + true, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, +#else + tensorMap, +#endif + nbPages, smem.pages[0] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) + { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + kTilePartLoader.loadPages(idxKTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) + { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputKHeadOffset + = headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inKHead = qkv[inputKHeadOffset]; + uint32_t const lane = laneId(); + float const rcpKScale = 1.F / kvCacheScale[0]; +#if ROPE_STYLE == 0 + constexpr bool isNeox = false; + auto const pairs = loadHead(inKHead, lane) * rcpKScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) + = convert(reinterpret_cast const&>(pairs)); + storeRotatedPairsForKV( + kTilePartLoader.getHead(newTokenPos), convertedPairs, lane); +#else + constexpr bool isNeox = (ROPE_STYLE == 1); + auto const pairs = loadHead(inKHead, lane) * rcpKScale; + auto const& ropeCosSinHead + = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, lane); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), rotatedPairs, lane); +#endif + static_assert(inputSeqLen == 1); + __syncwarp(); +#endif + } +#endif + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) + { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBar = smem.kBar[idxKBuf]; + kBar.consumed.arrive_and_wait(); + if (warpElectSync()) + { + kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); + } + __syncwarp(); + } + } + } + else if (warpIdx.x == nbQLdWarps + 1) + { // load v + KVTilePartLoader vTileLoader + { + false, nbKHeads, cacheList, idxReq, idxHeadGrp, +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMV, +#else + tensorMap, +#endif + nbPages, smem.pages[1] +#else + tensorMap +#endif + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) + { + uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; + vTileLoader.loadPages(idxVTile); +#if USE_INPUT_KV || ENABLE_PDL == 2 +#if SPEC_DEC + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); +#else + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); +#endif + if (anyNewTokens) + { +#if ENABLE_PDL == 2 + acqBulk(); +#endif +#if USE_INPUT_KV + static_assert(beamWidth == 1); + uint32_t const inputVHeadOffset + = (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inVHead = qkv[inputVHeadOffset]; + uint32_t const lane = laneId(); + float const rcpVScale = 1.F / kvCacheScale[0]; + constexpr bool isNeox = false; + auto const pairs = 
loadHead(inVHead, lane) * rcpVScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) + = convert(reinterpret_cast const&>(pairs)); + static_assert(SPEC_DEC == 0); + storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), convertedPairs, lane); + __syncwarp(); +#endif + } +#endif + + uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; + auto& vBar = smem.vBar[idxVBuf]; + vBar.consumed.arrive_and_wait(); + if (warpElectSync()) + { +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) + { + vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); + } + } + __syncwarp(); + } + } + } + __syncthreads(); + uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; + uint32_t const tid = threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; + assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); + if (tid < nbBarriers) + { + (&smem.qBar.produced)[tid].~CtaBarrier(); + } + if (!isMultiBlockMode) + { + return; + } + bool& smemIsLastCta = smem.isLastCta; + if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) + { + uint32_t const lastOld = nbSubSeq - 1; + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t old; + uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; + auto const pSemaphore = &semaphores[idxSemaphore]; + asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" : "=r"(old) : "l"(pSemaphore), "r"(lastOld)); + smemIsLastCta = (old == lastOld); + } + { + assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); +#ifndef __CUDACC_RTC__ + assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); +#endif + auto& smem = *reinterpret_cast(&smemByteBuf[0]); + assert(blockDim.x >= MultiBlockSMem::nbBuf); + constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; + + static_assert(nbWarps >= MultiBlockSMem::nbBuf); + if (wid < MultiBlockSMem::nbBuf) + { + if (warpElectSync()) + { + smem.barriers[wid].initialize(isHeadPadded ? 
warp_size : 1U, nbMathWarps * warp_size); + smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); + } + } + __syncthreads(); + + if (!smemIsLastCta) + { + return; + } + if (wid < nbMathWarps) + { + constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); + using Acc = Vec; + + struct HeadState + { + Acc acc; + float sum; + float max; + }; + + Vec states{}; + for (auto& s : states.data) + { + s.max = safeInitRowMax; + } + uint32_t const lane = laneId(); + for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) + { + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) + { + break; + } + HeadState& state = states[i]; + auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; + auto const data = convert( + reinterpret_cast&>(smem.tokens[idxBuf][idxHead][Acc::size * lane])); + if (sumMax.max > state.max) + { + float const scale = expf(state.max - sumMax.max); + state.max = sumMax.max; + state.sum = state.sum * scale + sumMax.sum; + state.acc = state.acc * scale + data * sumMax.sum; + } + else + { + float const scale = expf(sumMax.max - state.max); + state.sum = state.sum + sumMax.sum * scale; + state.acc = state.acc + data * (sumMax.sum * scale); + } + } + unused(bar.consumed.arrive()); + } + // Add the attention sinks. + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = expf( + attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max); + states[i].sum += sink; + } + } + __syncthreads(); + uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + auto const dst = &output[outOffset]; + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) + { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken >= ctaNbValidTokens) + { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const idxDstHead = idxHead + idxToken * tokenPad; +#else + uint32_t const idxDstHead = idxHead; +#endif + auto const& s = states[i]; + auto const outData = convert(s.acc * (1.f / s.sum)); + if (Acc::size * lane < validElemsPerHead) + { + reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = outData; + } + } + } + else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) + { + static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const initIdxBlock = wid - nbMathWarps; + // each warp loads data for a block + for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; idxBlock += MultiBlockSMem::nbIOWarps) + { + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + auto const lane = laneId(); +#pragma unroll + for (uint32_t iter = 0; 
iter < divUp(ctaNbValidQHeads, warp_size); iter++) + { + uint32_t const i = iter * warp_size + lane; + if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) + { + break; + } + ldgsts::copyAsync( + &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); + } + ldgsts::barArrive(bar.produced, false); + if constexpr (isHeadPadded) + { + static_assert(grainsPerPaddedInputHead <= warp_size); + constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) + { + uint32_t const idxHead = headsPerIter * i + + BoundedVal{lane}.template divBy().get(); + uint32_t const idxGrain + = BoundedVal{lane}.template mod().get(); + if (i < nbWholeIters || idxHead < ctaNbValidQHeads) + { + constexpr uint32_t nbElemsPerGrain = exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); + auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; + auto const src = idxGrain < grainsPerIOHead + ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] + : nullptr; + ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? grainBytes : 0U); + } + } + ldgsts::barArrive(bar.produced, true); + } + else + { + if (warpElectSync()) + { + tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], + sizeof(smem.tokens[idxBuf]), bar.produced); + arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); + } + } + } + __syncthreads(); + uint32_t const idxBar = tid - warp_size * nbMathWarps; + if (idxBar < MultiBlockSMem::nbBuf * 2) + { + reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); + } + } + } +#else +#if GENERATE_CUBIN + static_assert("This kernel is for Hopper only"); +#else + asm volatile("trap;\n"); +#endif +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 +} + +#if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 +template +__device__ inline typename F16QToF8Converter::RegData F16QToF8Converter::load( + uint32_t tid, TinyPtr const& src, uint32_t const nbKHeads /*for beam search only*/, uint32_t nbTokens) +{ +#if !(SPEC_DEC) + assert(nbTokens == 1); + nbTokens = 1; +#endif + typename F16QToF8Converter::RegData dst; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) + { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) + { + break; + } +#if SPEC_DEC + uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; + static_assert(beamWidth == 1); +#else + uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; +#endif + bool isGrainInBound = true; + if constexpr (isHeadPadded) + { + uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; + offsetInGrains = offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; + isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); + } +#if SPEC_DEC + isGrainInBound = isGrainInBound && (idxToken < nbTokens); +#endif + LdGrain const srcGrain = isGrainInBound ? 
src.template cast()[offsetInGrains] : LdGrain{}; + static_assert(inputElemSize == 2); + auto const& fp16Data = reinterpret_cast const&>(srcGrain); + dst[iter] + = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead ? fp16Data : mha::decay_t{}; + } + return dst; +} + +template +__device__ inline void F16QToF8Converter::store( + uint32_t tid, SharedMem::QBuffer& dst, F16QToF8Converter::RegData const& data) +{ +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) + { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) + { + break; + } +#if CACHE_ELEM_ENUM == 0 + static_assert(inputElemSize == cacheElemSize); + ShmVec const& shmData = data[iter]; + uint32_t const r = idxGrain / grainsPerPaddedInputHead; + BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; + + dst[c.template divBy().get()].template at(r, c.template mod().get()) + = reinterpret_cast(shmData); +#else + auto const& fp16Data = data[iter]; + ShmVec shmData; +#pragma unroll + for (uint32_t i = 0; i < fp16Data.size; i++) + { + shmData[i] = CacheElem{fp16Data[i]}; + } + uint32_t const dstIdxGrain = idxGrain / 2; + uint32_t const dstIdxHalfGrain = idxGrain % 2; + constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); + uint32_t const r = dstIdxGrain / grainsPerCacheHead; + BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; + reinterpret_cast&>(dst[c.template divBy().get()].template at( + r, c.template mod().get()))[dstIdxHalfGrain] + = shmData; +#endif + } +} +#endif + +__device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap +#if USE_PAGED_KV_CACHE + , + uint32_t nbPages, Vec& pageBuf +#endif + ) + : nbKHeads{nbKHeads} + , cacheList{cacheList} + , idxReq{idxReq} + , idxHeadGrp{idxHeadGrp} + , tensorMap{tensorMap} +#if USE_PAGED_KV_CACHE + , nbPages{nbPages} + , pages{pageBuf} +#if PAGED_KV_CACHE_LAYOUT == 1 + , baseOffset{idxReq * cacheList.maxNbPagesPerSeq} +#else + , baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} +#endif +#else + , baseOffset{(idxReq * beamWidth) * 2 + (isK ? 
0 : 1)} +#endif +{ +} + +// tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache +template +__device__ inline void KVTilePartLoader::loadData( + Array2D& dst, uint32_t idxTile, + uint32_t idxPart, CtaBarrier& bar) +{ + static_assert(nbTokens == gemm0CtaTileNbTokens); +#if USE_PAGED_KV_CACHE + assert(idxTile == idxTileRef); + if constexpr (nbTokens < tokensPerPage) + { + assert(nbPagesPerTile == 1); + uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[0]}, bar); +#else + tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}, bar); +#endif + } + else + { +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) + { +#if PAGED_KV_CACHE_LAYOUT == 1 + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t) pages[i]}, bar); +#else + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t) pages[i]}, bar); +#endif + } + } +#else + tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); +#endif +} + +__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) +{ +#if USE_PAGED_KV_CACHE + uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage + ? nbPagesPerTile * idxTile + : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) + { + uint32_t const idxPage = idxPageBeg + i; + auto const page = idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; + if (warpElectSync()) + { + pages[i] = page; + } + } + idxTileRef = idxTile; + __syncwarp(); +#endif +} + +__device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) +{ + constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + // Raise a runtime error indicating not implemented + assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); + __trap(); +#else + uint32_t const idxTile = pos / nbTokens; + assert(idxTile == idxTileRef); + uint32_t const offset = pos % tokensPerPage; + return cacheList.pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + offset]; +#endif +#else + // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] + return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; +#endif +} + +#if SWAP_AB +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) +{ + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static_assert(SPEC_Q_SEQ_LEN <= sizeof(MaskType) * 8, "not implemented"); + + assert(cacheSeqLen >= SPEC_Q_SEQ_LEN); + uint32_t const maskStartRow = cacheSeqLen - SPEC_Q_SEQ_LEN; + uint32_t const tileStartRow = tileSize * idxTile; + if (tileStartRow + tileSize < maskStartRow) + { + return; + } + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; + +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + uint32_t const col = GmmaAccCoreMat::cols 
* (4 * n + idxInQuad) + j; + uint32_t const maskCol = col / headGrpSize; + MaskType const bit_mask = (1ULL << (maskCol + 1)) - 1; + +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const globalRow = tileStartRow + row; + if (globalRow >= cacheSeqLen) + { + acc(m, n)(i, j) = safeInitRowMax; + continue; + } + if (globalRow >= maskStartRow) + { + uint32_t const maskRow = globalRow - maskStartRow; + if ((bit_mask >> maskRow) == 0) + { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } + } +} +#endif // SPEC_DEC + +// smemColMax is persistent across multiple iterations +__device__ inline RegColWiseVec computeWarpGrpColMax_sync( + CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src) +{ + auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + colMax[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) + { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) + { + auto& x = colMax[n][j]; + x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + + uint32_t const lane = laneId(); + if (lane < 4) + { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) + { + atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); + } + } + } + warpGrpBar.arrive_and_wait(); + uint32_t const idxInQuad = lane % 4; + +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); + colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; + } + } + warpGrpBar.arrive_and_wait(); + return colMax; +} + +__device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) +{ + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) + { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast< + Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + smemVec)[i * nbThrdsPerInstNBase + idx]; + } + return ret; +} + +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound) +{ + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) + { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast< + Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + +__device__ 
inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) +{ + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; + if (row >= validRowBeg && row < validRowEnd) + { + continue; + } +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) +{ +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + float const maxVal = colMax[n][j]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) +{ + auto colSum = RegColWiseVec::filled(Vec::filled(0)); +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) + { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + auto& x = colSum[n][j]; + x += __shfl_xor_sync(~0U, x, xorMask); + } + } + } + return colSum; +} + +__device__ inline void storeGemm0AccToShm( + uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) +{ +#if CACHE_ELEM_ENUM == 0 + using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; + F16Acc f16Acc; + reinterpret_cast&>(f16Acc) + = convert(reinterpret_cast const&>(acc)); + static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); + uint32_t const idxHalf = lane / 16; + uint32_t const idxInHalf = lane % 16; + uint32_t const idxOctInsideHalf = idxInHalf / 8; + uint32_t const idxRowInsideOct = lane % 8; + uint32_t const warpBaseC = 16 * warpRank; + auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair + { + uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; + uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; + return {accR, accC}; + }; + auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* + { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); + uint32_t const idxPart = 0; + uint32_t const dstR = accC * 8 + idxRowInsideOct; + uint32_t const dstC = exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); + assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); + return &smemX[idxPart].template at(dstR, dstC); + }; + auto const getAccData = [&](uint32_t idxAccCoreMat) + { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + return f16Acc(accR, accC); + }; + + barConsumed.arrive_and_wait(); +#pragma 
unroll + for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) + { + auto const dstAddr = getDstAddr(iter * 2 + idxHalf); + Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; + stmatrix(dstAddr, reinterpret_cast(data)); + } + if constexpr (Gemm0Acc::size % 2 != 0) + { + auto const dstAddr = lane < 16 ? getDstAddr(Gemm0Acc::size - 1) : nullptr; + stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); + } +#elif CACHE_ELEM_ENUM == 2 + using F8Acc = Array2D; + F8Acc f8Acc; +#pragma unroll + for (uint32_t i = 0; i < acc.rows; i++) + { +#pragma unroll + for (uint32_t j = 0; j < acc.cols; j++) + { + auto const& core = acc(i, j); + static_assert(mha::is_same_v); + Vec const f8Data + = {__nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), + __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; + f8Acc(i, j) = reinterpret_cast(f8Data); + } + } + + if constexpr (F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) + { + LdGrain* dst = nullptr; + if (F8Acc::size == 4 || lane < 8 * F8Acc::size) + { + uint32_t const idxCore = lane / 8; + uint32_t const srcRow = idxCore / F8Acc::cols; + uint32_t const srcCol = idxCore % F8Acc::cols; + uint32_t const dstCoreRow = lane % 8; + uint32_t const dstRow = srcCol * 8 + dstCoreRow; + BoundedVal const dstCol{srcRow * 4 + warpRank}; + dst = &smemX[dstCol.template divBy().get()].template at( + dstRow, dstCol.template mod().get()); + } + barConsumed.arrive_and_wait(); + stmatrix(dst, reinterpret_cast const&>(f8Acc)); + } + else + { + // we need to use loops + assert(false); + trap(); + } +#endif +} + +#else + +__device__ inline RegRowWiseVec warpRowWiseReduce( + RegRowWiseVec const& init, Gemm0Acc const& src, float (*op)(float, float)) +{ + RegRowWiseVec vec = init; +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + // @fixme: check if compiler is reordering these op to hide latency. 
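+ // Two-phase reduction: this loop folds every (n, j) element owned by the calling
+ // thread into vec; the __shfl_xor butterfly below (masks 2 then 1) then combines the
+ // four lanes of each quad that hold the remaining columns of the same accumulator row.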
+ vec[m][i] = op(vec[m][i], src(m, n)(i, j)); + } + } + } + } + +#pragma unroll + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) + { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + auto& x = vec[m][i]; + x = op(x, __shfl_xor_sync(~0U, x, xorMask)); + } + } + } + return vec; +} + +__device__ inline RegRowWiseVec computeWarpGrpRowMax_sync( + uint32_t warpRank, ShmQWiseVec& smemRowMax, Gemm0Acc const& src) +{ + assert(warpRank < 4); + RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); + RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); + + storeShmRowWiseVec(warpRank, smemRowMax, rowMax); + __syncwarp(); + return rowMax; +} + +#if SPEC_DEC +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t tok0WinBeg, +#endif + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) +{ + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + auto const inputSeqLen = specDec.inputSeqLen; + auto const idxInputSubSeq = specDec.idxInputSubSeq; + constexpr uint64_t fullMask = ~uint64_t{0}; + static_assert(tileSize == sizeof(fullMask) * 8); +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + Range const tileRange = {tileSize * idxTile, tileSize * idxTile + tileSize}; + Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (inputTokensPerCta - 1)}; + bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; + assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); + int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(tileSize * idxTile); +#else + constexpr bool ctaNeedBegMask = false; + uint64_t const begMask = fullMask; + int32_t const tok0NbMaskOut = -2147483648; +#endif + uint32_t const offset = tileSize * idxTile; + uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? cacheSeqLen - offset : 0U, tileSize); + bool const ctaNeedEndMask = (nbValidCols < tileSize); + bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); + bool const needMask = ctaNeedBegMask || ctaNeedEndMask || ctaNeedSpecDecMask; + if (!needMask) + { + return; + } + static_assert(tileSize == 64, "not implemented"); + auto const endMask = fullMask >> (tileSize - nbValidCols); + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const idxQTokInCta = row / headGrpSize; + bool const isQTokValid + = (headGrpSize * inputTokensPerCta == ctaNbQHeads) || (idxQTokInCta < inputTokensPerCta); + auto const specDecMask = (isQTokValid && specDec.needMask(idxTile, idxQTokInCta)) + ? specDec.loadTileMaskRow(idxTile, idxQTokInCta) + : SpecDec::TileMaskRow{~0U, ~0U}; +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); + uint64_t const begMask = (begNbMaskOut > 0 ? 
fullMask << begNbMaskOut : fullMask); +#else + uint64_t const begMask = fullMask; +#endif + auto const mask = begMask & endMask & reinterpret_cast(specDecMask); + if (mask == ~uint64_t{0}) + { + continue; + } +#if DBG_PRINT + if (idxInQuad == 0) + { + printf("mask at row %d: %lx\n", row, mask); + } +#endif +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + assert((col < nbValidCols) == bool(endMask & (1ULL << col))); + if ((mask & (1ULL << col)) == 0) + { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } + } +} +#else +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) +{ + uint32_t const idxInQuad = laneId() % 4; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + if (col >= validColBeg && col < validColEnd) + { + continue; + } +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + acc(m, n)(i, j) = safeInitRowMax; + } + } + } + } +} +#endif + +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) +{ +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + float const maxVal = rowMax[m][i]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } +} + +__device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) +{ + return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); +} + +__device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec) +{ + RegRowWiseVec vec; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) + { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) + { + vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; + } + } + return vec; +} + +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec) +{ + uint32_t const lane = laneId(); + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; + bool const enable = (idxInQuad == 0); +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) + { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) + { + assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); + if (enable) + { + smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; + } + } + } +} + +// for X +// order: 0,8,1,9, 2,10,3,11, 4,12,5,13, 6,14,7,15, ... 
+__device__ inline void storeGemm0AccToShm( + uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) +{ + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t m = 0; m < Gemm0Acc::rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + Vec fp8Data; +#pragma unroll + for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) + { + reinterpret_cast&>(fp8Data[n]) + = {__nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; + } + static_assert(decltype(fp8Data)::size == 4); + stmatrix_4x( + this_warp(), &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), fp8Data); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadVTileTransposed( + uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) +{ + Vec fragA; + constexpr uint32_t instK = gmma::instK; +#pragma unroll + for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) + { + static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); + constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); +#if CACHE_ELEM_ENUM == 0 + uint32_t idxRow = lane % 8; + uint32_t idxMat = lane / 8; + uint32_t c = idxMat % 2; + uint32_t r = idxMat / 2; + auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{2 * (gmmaWarpsPerGrp * i + warpRank) + c}; + auto const src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i] = reinterpret_cast(data); +#elif CACHE_ELEM_ENUM == 2 + auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; + LdGrain const* src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + lane, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); + fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); + fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); + fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); +#endif + } + return fragA; +} +#else +__device__ inline void transposeVTile( + uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src) +{ + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) + { + static_assert(cacheHeadPartElems >= gmma::instM); + uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; + constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) + { + LdGrain const a = ldmatrix_4x(this_warp(), + &src[idxPart].template at(32 * n + lane, + exactDiv(gmma::instM, cacheElemsPerGrain) * m - grainsPerCacheHeadPart * idxPart + warpRank)); + LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), + prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; + uint32_t const i = idxMat % 2; + uint32_t const j = idxMat / 2; + stmatrix_4x( + this_warp(), &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); + } + } +} +#endif + +#if SWAP_AB +__device__ inline Vec loadShmColWiseVecNoDup(ShmQWiseVec const& shmVec) +{ + Vec ret; +#pragma 
unroll + for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) + { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? shmVec[idx] : 0); + } + return ret; +} + +__device__ inline void storeShmColWiseVecNoDup( + ShmQWiseVec& shmVec, Vec const& src) +{ +#pragma unroll + for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) + { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + if (inBound) + { + shmVec[idx] = src[i]; + } + } +} +#else +__device__ inline Vec +loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) +{ + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) + { + uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound + = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? shmVec[idx] : 0); + } + return ret; +} + +__device__ inline void storeShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec& shmVec, + Vec const& src) +{ + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) + { + uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound + = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); + if (inBound) + { + shmVec[idx] = src[i]; + } + } +} +#endif + +#if SWAP_AB +__device__ inline void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, + ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar) +{ + auto accColSum = loadShmColWiseVecNoDup(shmAccColSum); + + auto const xColMax = loadShmColWiseVecNoDup(shmXColMax); + auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax); + auto token = gemm1WarpGrpBar.arrive(); + auto const needRescaleVec = (accColMax < xColMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) + { + assert(accColMax[i] <= xColMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + if (anyNeedRescale) + { + auto const scaleVec = expf(accColMax - xColMax); + auto const lane = laneId(); +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) + { + uint32_t const vecIdx = gmma::instNBase * n / warp_size; + uint32_t const offset = gmma::instNBase * n % warp_size; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U); + auto getScale = [&] { + return __shfl_sync( 
+ ~0U, scaleVec[vecIdx], offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j); + }; + assert((getScale() != 1) == ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U)); + bool const needRescale = (mask != 0); + if (!needRescale) + { // this branch is warp-uniform + continue; + } + float const scale = getScale(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + acc(m, n)(i, j) *= scale; + } + } + } + } + accColSum = accColSum * scaleVec; + } + gemm1WarpGrpBar.wait(mha::move(token)); + + // @fixme: with atomic, we can let the first warp reaching here to do the update, instead of always warp 3. + uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1; + if (warpRank == warpRankForUpdate) + { + if (anyNeedRescale) + { + storeShmColWiseVecNoDup(shmAccColMax, xColMax); + } +#pragma unroll + for (uint32_t i = 0; i < gemm0NbWarps; i++) + { + accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); + } + storeShmColWiseVecNoDup(shmAccColSum, accColSum); + } + gemm1WarpGrpBar.arrive_and_wait(); +} +#else +__device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const& shmXRowSum, ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum) +{ + auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); + auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); + auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); + assert(all(xRowMax >= accRowMax)); + auto const needRescaleVec = (accRowMax < xRowMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) + { + assert(accRowMax[i] <= xRowMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + + if (anyNeedRescale) + { + auto const scaleVec = expf(accRowMax - xRowMax); + auto const lane = laneId(); +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; + bool const needRescale = (mask != 0); + if (needRescale) + { // this branch is warp-uniform + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + acc(m, n)(i, j) *= scale; + } + } + } + } + } + accRowSum = accRowSum * scaleVec; + } + __syncwarp(); + auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); + __syncwarp(); +} +#endif + +#if SWAP_AB +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) +{ +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + acc(m, n)(i, j) *= scale[n][j]; + } + } + } + } +} +#else +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) +{ +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t i = 
0; i < GmmaAccCoreMat::rows; i++) + { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) + { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) + { + acc(m, n)(i, j) *= scale[m][i]; + } + } + } + } +} +#endif + +#if SWAP_AB +// @fixme: consider make this noinline +template +__device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRank, DstHead* dst, + SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc const& acc, CtaBarrier& warpGrpBar, uint32_t nbKHeads) +{ + uint32_t const lane = laneId(); +#if CACHE_ELEM_ENUM == 0 + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#elif CACHE_ELEM_ENUM == 2 + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; +#endif +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { +#pragma unroll + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) + { + auto const& core = acc(m, n); +#if CACHE_ELEM_ENUM == 0 + Vec f16Core; + reinterpret_cast&>(f16Core) + = convert(reinterpret_cast const&>(core)); + auto const dst = idxMat < 2 + ? &swizzleBuf.template at(8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) + : nullptr; + stmatrix(dst, f16Core); +#elif CACHE_ELEM_ENUM == 2 + // each row is part of a b16 8x8 matrix and is transposed + Array2D coreTrans; + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) + { + static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); + InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); + auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); + reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; + } + // expect compiler to generate two PRMT instructions + Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), coreTrans(1, 1)}; + swizzleBuf.template at(gmma::instNBase * n + idxQuad, + (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] + = data; +#endif + } + } + warpGrpBar.arrive_and_wait(); + + constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; + constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); + uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; + uint32_t const idxGrain = threadRank % nbGrainsPerHead; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) + { + uint32_t const idxHead = idxHeadBase + iter * headsPerIter; + if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && (!isHeadPadded || idxGrain < grainsPerIOHead)) + { +#if CACHE_ELEM_ENUM == 0 + auto const data = swizzleBuf.template at(idxHead, idxGrain); +#elif CACHE_ELEM_ENUM == 2 + auto const data + = reinterpret_cast&>(swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 2]; +#endif + constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); + auto const outVec + = convert(reinterpret_cast const&>(data)); + uint32_t dstHeadIdx = idxHead; +#ifdef SPEC_Q_SEQ_LEN + if constexpr (dstIsStrided) + { + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken < SPEC_Q_SEQ_LEN) + { + uint32_t const strideBetweenTokens = nbKHeads * headGrpSize; + dstHeadIdx = idxToken * strideBetweenTokens + (idxHead % headGrpSize); + } + } +#endif + reinterpret_cast, nbGrainsPerHead>&>(dst[dstHeadIdx])[idxGrain] = outVec; + } + } +} + +template +__device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, 
DstHead* dst,
+ SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
+ ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads)
+{
+ // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp
+ // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of
+ // mufu.rcp");
+ auto regColSum = loadShmColWiseVecWithDup(accColSum);
+ if (attentionSinksVec != nullptr)
+ {
+ auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
+ auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
+ auto regColSinks = expf(regAttentionSinks - regAccColMax);
+ regColSum = regColSum + regColSinks;
+ }
+ auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
+ rescaleAcc(acc, regOutScale);
+
+ saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc, warpGrpBar, nbKHeads);
+ warpGrpBar.arrive_and_wait();
+}
+#else
+template <typename DstHead>
+__device__ inline void finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf,
+ Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accRowSum,
+ uint32_t nbKHeads /* for spec dec; set to 1 for workspace */, uint32_t ctaNbValidTokens)
+{
+ auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum);
+ auto const regOutScale = __frcp_rn(regRowSum) * xvoScale;
+ rescaleAcc(acc, regOutScale);
+
+ using DstElem = typename DstHead::Elem;
+ auto const lane = laneId();
+ uint32_t const idxQuad = lane / 4;
+ uint32_t const idxInQuad = lane % 4;
+ using Atom = Vec<Vec<DstElem, 4>, 4>;
+ using SwizzleBuf = Array2D<Vec<Vec<DstElem, 4>, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
+ static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf));
+ auto& buf = reinterpret_cast<SwizzleBuf&>(swizzleBuf);
+#pragma unroll
+ for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
+ {
+#pragma unroll
+ for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
+ {
+ uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad;
+ static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2));
+#pragma unroll
+ for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++)
+ {
+ Vec<DstElem, 4> const v = convert<DstElem>(Vec<float, 4>{
+ acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0), acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)});
+ //@fixme: without the reinterpret_cast to V, the compiler generates wrong code and requires a __syncwarp()
+ // after rescaleAcc() as a workaround. Likely a compiler bug.
+ //@todo: report a compiler bug.
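+ // Note on the packing below (an inferred explanation, not the original author's comment):
+ // each lane converts the four fp32 accumulator values it owns for output row r into
+ // DstElem and commits them with a single 4- or 8-byte store into the swizzle buffer;
+ // the copy loop that follows then streams whole grains from this buffer to dst.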
+ using V = Vec; + reinterpret_cast(buf.template at(r, n)[idxInQuad]) = reinterpret_cast(v); + // buf.template at(r, n)[idxInQuad] = v; + } + } + } + __syncwarp(); + +#pragma unroll + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) + { + constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems; + constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes); + constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize); + uint32_t const idxGrp = lane / grpSize; + constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes); + uint32_t const rowBase = gmma::instM * m + 16 * warpRank; + constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16; + uint32_t const nbIters = divUp(totalNbGrains, nbGrps); + constexpr bool wholeIters = (totalNbGrains % nbGrps == 0); + constexpr bool wholeHeads = (validElemsPerHead == headElems); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) + { + uint32_t const idxGrain = nbGrps * iter + idxGrp; + constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes); + uint32_t const r = idxGrain / grainsPerSrcHead; + if (!wholeIters && r >= 16) + { + break; + } + uint32_t const cGrain = idxGrain % grainsPerSrcHead; + uint32_t const cAtom = cGrain / grainsPerAtom; + constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes); + uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r; + if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) + { + break; + } + if (wholeHeads || cGrain < grainsPerDstHead) + { + uint32_t const srcRow = rowBase + r; + auto const data = reinterpret_cast( + buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom]; +#if SPEC_DEC + static_assert(beamWidth == 1); + uint32_t const idxToken = srcRow / headGrpSize; // inside CTA + if (idxToken >= ctaNbValidTokens) + { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const dstRow = srcRow + idxToken * tokenPad; +#else + uint32_t const dstRow = srcRow; +#endif + reinterpret_cast(dst[dstRow])[cGrain] = data; + } + } + } +} +#endif + +template +__device__ inline Vec, ropeNbPairsPerThrd> loadHead( + Vec const& head, uint32_t tid) +{ + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + Vec, nbPairsPerThrd> ret; + if constexpr (forNeox) + { + auto const& pairs = reinterpret_cast, nbWorkingThrds>, 2> const&>(head); + auto const data = isWorkingThrd ? Vec, 2>{pairs[0][tid], pairs[1][tid]} + : Vec, 2>{}; + Vec, 2> const tmp = {convert(data[0]), convert(data[1])}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) + { + ret[i][0] = tmp[0][i]; + ret[i][1] = tmp[1][i]; + } + } + else + { + auto const data = isWorkingThrd ? 
reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] + : Vec, nbPairsPerThrd>{}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) + { + ret[i] = convert(data[i]); + } + } + return ret; +} + +template +__device__ inline mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin) +{ + Vec, nbPairsPerThrd> r; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) + { + float const x = data[i][0]; + float const y = data[i][1]; + float const c = ropeCosSin[i][0]; + float const s = ropeCosSin[i][1]; + r[i] = Vec{c * x - s * y, s * x + c * y}; + } + if constexpr (forNeox) + { + Vec, 2> tmp; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) + { + tmp[0][i] = r[i][0]; + tmp[1][i] = r[i][1]; + } + return Vec, 2>{convert(tmp[0]), convert(tmp[1])}; + } + else + { + Vec, nbPairsPerThrd> ret; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) + { + ret[i] = convert(r[i]); + } + return ret; + } +} + +template +__device__ inline void storeRotatedPairsForKV(GMemCacheHead& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid) +{ + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (!isWorkingThrd) + { + return; + } + if constexpr (forNeox) + { + auto& pairs = reinterpret_cast, nbWorkingThrds>, 2>&>(dst); + pairs[0][tid] = src[0]; + pairs[1][tid] = src[1]; + } + else + { + reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; + } +} + +template +__device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst, + mha::conditional_t>, 2>, + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid) +{ + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (isWorkingThrd) + { + if constexpr (forNeox) + { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) + { + auto const byteOffset + = BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); + static_assert( + cacheElemSize * nbPairsPerThrd <= grainBytes && grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); + reinterpret_cast&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) + = src[i]; + } + } + else + { + auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes + && 
grainBytes % (cacheElemSize * 2 * nbPairsPerThrd) == 0); + reinterpret_cast, nbPairsPerThrd>&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) + = src; + } + } + static_assert(validElemsPerHead % 16 == 0); + __syncwarp(); + if constexpr (validElemsPerHead < headElems) + { + static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); + constexpr uint32_t nbPadGrainsPerHead = exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); + constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; + uint32_t const nbIters = divUp(nbPadGrains, nbThrds); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) + { + uint32_t idx = tid + nbThrds * iter; + if (idx >= nbPadGrains) + { + break; + } + uint32_t const r = idx / nbPadGrainsPerHead; + uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; + dst[dst.size - 1].template at(r, c) = LdGrain{}; + } + } +} + +#ifndef GENERATE_CUBIN +void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif +#if USE_INPUT_KV + InputHead const* qkv, +#if ROPE_STYLE != 0 + Vec const* ropeCosSin, +#endif +#else + InputHead const* q, +#endif + float const* attentionSinks, // [headGrpSize] +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] +#else + GMemKVCacheHead* kvCacheData, +#endif + uint32_t maxSeqLen, uint32_t const* seqLen, +#if USE_BEAM_SEARCH + BeamSearchParams const& beamSearchParams, +#endif + uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for + // int8/fp8 KV cache. 
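+ // Illustrative semantics (an assumption, not a contract stated in this file): a cached
+ // fp8/int8 value x dequantizes as float(x) * kvCacheScale[0], with the same factor
+ // shared by K and V; for fp16/bf16 caches this pointer is unused.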
+#if SPEC_DEC + SpecDecParams const& specDecParams, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) +{ + if (beamWidth != 1) + { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() + { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + // printf("smemSize = %u\n", hostSmemSize); + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t + { + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) + { + int32_t const val = std::stoi(env); + if (val > 0) + { + return val; + } + } + float const factor = 0.25f; + return mha::min( + mha::max(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + uint32_t const qSeqLen = specDecParams.qSeqLen; +#else + uint32_t const qSeqLen = 1; +#endif + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit + dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] + { + if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } + else if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } + else if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = makeTensorMapForPagedKVCache( + pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, + validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, cacheHeadPartElems, 
gemm0CtaTileNbTokens); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif +#if USE_INPUT_KV + qkv, +#if ROPE_STYLE != 0 + ropeCosSin, +#endif +#else + q, +#endif + attentionSinks, cacheList, +#if USE_BEAM_SEARCH + beamSearchParams, +#endif + batchSize, kvCacheScale, tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} +#endif + +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, + float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + InputHead const* q, float const* attentionSinks, GMemCacheHead* pool, + KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, + uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, +#if SPEC_DEC + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) + { + static uint32_t const hostSmemSize = [&]() + { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t + { + float const factor = 0.25f; + return mha::min( + mha::max(1U, (uint32_t) round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); + #if SPEC_DEC + auto specDecParams = SpecDecParams{qSeqLen, qCuSeqLens, mask}; + uint32_t const qLen = qSeqLen; + #else + uint32_t const qLen = 1; + #endif + dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); + #if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] + { + if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } + else if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } + else if (std::is_same_v) + { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + + #if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( + kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( + vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + #else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = makeTensorMapForPagedKVCache( + pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + #endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, + #if SLIDING_WINDOW + slidingWinSize, + #endif + qScale, output, + #if LOW_PREC_OUTPUT + rcpOutScale, + #endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, + #if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, + #else + tensorMap, + #endif + #if SPEC_DEC + specDecParams, + #endif + semaphores, scratch); + #else + KVCacheList const 
cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, + validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, + #if SLIDING_WINDOW + slidingWinSize, + #endif + qScale, output, + #if LOW_PREC_OUTPUT + rcpOutScale, + #endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, tensorMap, semaphores, scratch); + #endif + checkCuda(err); + } +#endif diff --git a/csrc/xqa/tensorMap.cpp b/csrc/xqa/tensorMap.cpp new file mode 100644 index 0000000000..58a608aada --- /dev/null +++ b/csrc/xqa/tensorMap.cpp @@ -0,0 +1,93 @@ +#include "tensorMap.h" +#include "utils.h" +#include +#include +#include + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType) +{ + switch (dataType) + { + case CU_TENSOR_MAP_DATA_TYPE_UINT8: return 1; + case CU_TENSOR_MAP_DATA_TYPE_UINT16: return 2; + case CU_TENSOR_MAP_DATA_TYPE_UINT32: return 4; + case CU_TENSOR_MAP_DATA_TYPE_INT32: return 4; + case CU_TENSOR_MAP_DATA_TYPE_UINT64: return 8; + case CU_TENSOR_MAP_DATA_TYPE_INT64: return 8; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: return 4; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: return 8; + case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: return 4; + default: throw std::runtime_error("unsupported data type"); + } +} + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, + uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens) +{ + CUtensorMap tensorMap{}; + uint64_t const globalDims[] = {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; + uint32_t elemBytes = getElemBytes(dataType); + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, headBytes * maxCacheLen * nbKHeads}; + uint32_t const boxDims[] = {partElems, nbTokens, 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] + { + switch (partElems) + { + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + default: throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, globalStrides, boxDims, + elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, + uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, uint32_t nbTokensPerTile) +{ + CUtensorMap tensorMap{}; + uint32_t elemBytes = getElemBytes(dataType); +// VLLM Layout +#if PAGED_KV_CACHE_LAYOUT == 1 + uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, headBytes * nbKHeads * tokensPerPage}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = 
{partElems, 1, mha::min(tokensPerPage, nbTokensPerTile), 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + // XQA Original Layout +#else + uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, headBytes * tokensPerPage * nbKHeads}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; +#endif + + auto const swizzle = [&] + { + switch (partBytes) + { + case 128: return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: return CU_TENSOR_MAP_SWIZZLE_64B; + default: throw std::runtime_error("unsupported cache head size"); + } + }(); + + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, globalStrides, boxDims, + elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; +} diff --git a/csrc/xqa/tensorMap.h b/csrc/xqa/tensorMap.h new file mode 100644 index 0000000000..83ecb54252 --- /dev/null +++ b/csrc/xqa/tensorMap.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType); + +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, + uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens); + +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, + uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, uint32_t nbTokensPerTile); diff --git a/csrc/xqa/tma.h b/csrc/xqa/tma.h new file mode 100644 index 0000000000..38d7e43928 --- /dev/null +++ b/csrc/xqa/tma.h @@ -0,0 +1,319 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement + * + * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual + * property and proprietary rights in and to this material, related + * documentation and any modifications thereto. Any use, reproduction, + * disclosure or distribution of this material and related documentation + * without an express license agreement from NVIDIA CORPORATION or + * its affiliates is strictly prohibited. 
+ */ + +#pragma once + +#include "cuda_hint.cuh" +#include "utils.h" +#ifndef GENERATE_CUBIN +#include +#include +#include +#endif +#include "barriers.cuh" + +enum class StateSpace +{ + kCONSTANT, + kPARAMETER, + kGENERIC +}; + +#ifdef GENERATE_CUBIN +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +typedef struct CUtensorMap_st +{ +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; +#endif + +namespace tma +{ + +__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, CtaBarrier& bar) +{ + asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) +{ + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), "r"(nbBytes) + : "memory"); +} + +// dsr and &bar must be remote address generated by mapa and src must be local address +__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes, CgaBarrier& bar) +{ + asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +template +__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, CtaBarrier& bar) +{ + if constexpr (nbDims == 1) + { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2}], [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } + else if constexpr (nbDims == 2) + { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3}], [%4];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } + else if constexpr (nbDims == 3) + { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4}], " + "[%5];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } + else if constexpr (nbDims == 4) + { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, " + "%5}], [%6];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } + else if constexpr (nbDims == 5) + { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, %5, " + "%6}], [%7];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } + else + { + static_assert(nbDims >= 1 && 
nbDims <= 5);
+ }
+}
+
+template <uint32_t nbDims>
+__device__ inline void loadAsync(
+ void* dst, CUtensorMap const& tensorMap, DimsLE<nbDims> offset, CtaBarrier& bar, uint64_t cacheHint)
+{
+ if constexpr (nbDims == 1)
+ {
+ // nbDims==1 does not need tensormap and should just use cp.async.bulk
+ asm volatile(
+ "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, "
+ "{%2}], [%3], %4;\n"
+ :
+ : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+ "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+ : "memory");
+ }
+ else if constexpr (nbDims == 2)
+ {
+ asm volatile(
+ "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, "
+ "{%2, %3}], [%4], %5;\n"
+ :
+ : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+ "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+ : "memory");
+ }
+ else if constexpr (nbDims == 3)
+ {
+ asm volatile(
+ "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, "
+ "{%2, %3, %4}], [%5], %6;\n"
+ :
+ : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+ "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+ : "memory");
+ }
+ else if constexpr (nbDims == 4)
+ {
+ asm volatile(
+ "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, "
+ "{%2, %3, %4, %5}], [%6], %7;\n"
+ :
+ : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+ "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint)
+ : "memory");
+ }
+ else if constexpr (nbDims == 5)
+ {
+ asm volatile(
+ "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, "
+ "{%2, %3, %4, %5, %6}], [%7], %8;\n"
+ :
+ : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]),
+ "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)),
+ "l"(cacheHint)
+ : "memory");
+ }
+ else
+ {
+ static_assert(nbDims >= 1 && nbDims <= 5);
+ }
+}
+
+// shared::cta -> global
+__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes)
+{
+ asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(dst)), "l"(__cvta_generic_to_shared(src)), "r"(nbBytes));
+}
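+
+// Usage sketch (illustrative only; the tile size, the 4D coordinate values, and the
+// barrier tx-count handshake are assumptions, not part of this header's contract):
+//   __shared__ alignas(128) uint8_t tile[tileBytes];
+//   tma::loadAsync(&tile[0], tensorMap, DimsLE<4>{0, 0, 0, blockIdx.x}, bar); // one producer thread
+//   // ... wait on bar until sizeof(tile) bytes have arrived, consume/update tile ...
+//   tma::storeAsync(tensorMap, DimsLE<4>{0, 0, 0, blockIdx.x}, &tile[0]); // write results back
+//   tma::commitGroup();
+//   tma::waitGroup<0>(); // 0 bulk groups left in flight: the store has completed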
+
+template <uint32_t nbDims>
+__device__ inline void storeAsync(CUtensorMap const& tensorMap, DimsLE<nbDims> const& offset, void* src)
+{
+ if constexpr (nbDims == 1)
+ {
+ // nbDims==1 does not need tensormap and should just use cp.async.bulk
+ asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "l"(__cvta_generic_to_shared(src))
+ : "memory");
+ }
+ else if constexpr (nbDims == 2)
+ {
+ asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+ "l"(__cvta_generic_to_shared(src))
+ : "memory");
+ }
+ else if constexpr (nbDims == 3)
+ {
+ asm volatile("cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
+ "l"(__cvta_generic_to_shared(src))
+ : "memory");
+ }
+ else if constexpr (nbDims == 4)
+ {
+ asm volatile("cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
+ "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
+ : "memory");
+ }
+ else if constexpr (nbDims == 5)
+ {
+ asm volatile("cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
+ :
+ : "l"(reinterpret_cast<uint64_t>(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
+ "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
+ : "memory");
+ }
+ else
+ {
+ static_assert(nbDims >= 1 && nbDims <= 5);
+ }
+}
+
+__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr)
+{
+ asm volatile("tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap), "l"(ptr)
+ : "memory");
+}
+
+__device__ inline void commitGroup()
+{
+ asm volatile("cp.async.bulk.commit_group;\n" : : : "memory");
+}
+
+// wait until only targetNbInFlightGroups groups are still in-flight.
+template <uint32_t targetNbInFlightGroups>
+__device__ inline void waitGroup()
+{
+ asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory");
+}
+
+__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap, StateSpace loc = StateSpace::kGENERIC)
+{
+ assert(reinterpret_cast<uintptr_t>(&tensorMap) % alignof(CUtensorMap) == 0);
+ switch (loc)
+ {
+ case StateSpace::kCONSTANT:
+ asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap)) : "memory");
+ break;
+ case StateSpace::kPARAMETER:
+ asm volatile("prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap)) : "memory");
+ break;
+ case StateSpace::kGENERIC:
+ asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast<uint64_t>(&tensorMap)) : "memory");
+ break;
+ default: asm volatile("trap;\n");
+ }
+}
+
+template <typename T>
+__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar)
+{
+ constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t));
+ Vec<uint32_t, nbWords> const& srcVec = reinterpret_cast<Vec<uint32_t, nbWords> const&>(src);
+ if constexpr (nbWords == 1)
+ {
+ asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"(
+ __cvta_generic_to_shared(dst)),
+ "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar))
+ : "memory");
+ }
+ else if constexpr (nbWords == 2)
+ {
+ asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, [%3];\n" ::"l"(
+ __cvta_generic_to_shared(dst)),
+ "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar))
+ : "memory");
+ }
+ else if constexpr (nbWords == 4)
+ {
+ asm volatile(
+ "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, [%5];\n" ::"l"(
+ __cvta_generic_to_shared(dst)),
+ "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), "l"(__cvta_generic_to_shared(&bar))
+ : "memory");
+ }
+ else
+ {
+ static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4, "src size must be 4, 8 or 16 bytes");
+ }
+}
+
+} // namespace tma
diff --git a/csrc/xqa/utils.cuh b/csrc/xqa/utils.cuh
index 5883e5b834..2804d2b322 100644
--- a/csrc/xqa/utils.cuh
+++ b/csrc/xqa/utils.cuh
@@ -31,7 +31,13 @@
 #include "barriers.cuh"
 
 inline constexpr float log2e = 1.4426950408889634; // std::log2(M_E)
-inline constexpr float safeInitRowMax = -1e+30F;
+// we use an optimization where exp(x-rowMax) is computed as:
+/* bias = rowMax * log2e // shared for the whole row
+   exp(x-rowMax) = exp2f(x * log2e - 
bias) +*/ +// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and +// x is too large. For this reason, don't set safeInitRowMax with a huge absolute value. +inline constexpr float safeInitRowMax = -1e+5F; inline constexpr int32_t kBAD_PAGE_INDEX = -1; __constant__ constexpr float kE4M3_MAX = 448.F; diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index a7bbbfaf0c..a71967da3d 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -17,8 +17,8 @@ #include "../pytorch_extension_utils.h" #include "mha.h" -void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingWinSize, - double qScale, at::Tensor output, +void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads, + int64_t slidingWinSize, double qScale, at::Tensor output, #if LOW_PREC_OUTPUT at::Tensor rcpOutScale, #endif @@ -34,20 +34,40 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW ? reinterpret_cast(attentionSinks.data_ptr()) : nullptr; - launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output.data_ptr()), + if (run_fp8_mha) { + launchHopperF8MHAFlashInfer( + multiProcessorCount, nbKHeads, slidingWinSize, qScale, + reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale.data_ptr()), + reinterpret_cast(rcpOutScale.data_ptr()), #endif - reinterpret_cast(q.data_ptr()), attentionSinksPtr, - reinterpret_cast(pool.data_ptr()), - reinterpret_cast(kvCachePageList.data_ptr()), - maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), + reinterpret_cast(q.data_ptr()), attentionSinksPtr, + reinterpret_cast(pool.data_ptr()), + reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, + reinterpret_cast(seqLen.data_ptr()), batchSize, + reinterpret_cast(kvCacheScale.data_ptr()), #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), - reinterpret_cast(mask.data_ptr()), + qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), + reinterpret_cast(mask.data_ptr()), #endif - reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); + reinterpret_cast(semaphores.data_ptr()), + reinterpret_cast(scratch.data_ptr()), stream); + } else { + launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, + reinterpret_cast(output.data_ptr()), +#if LOW_PREC_OUTPUT + reinterpret_cast(rcpOutScale.data_ptr()), +#endif + reinterpret_cast(q.data_ptr()), attentionSinksPtr, + reinterpret_cast(pool.data_ptr()), + reinterpret_cast(kvCachePageList.data_ptr()), + maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, + reinterpret_cast(kvCacheScale.data_ptr()), +#if SPEC_DEC + qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), + reinterpret_cast(mask.data_ptr()), +#endif + reinterpret_cast(semaphores.data_ptr()), + reinterpret_cast(scratch.data_ptr()), stream); + } } diff --git a/flashinfer/aot.py b/flashinfer/aot.py index 54eb4f6c3d..630e9eee67 100644 --- a/flashinfer/aot.py +++ b/flashinfer/aot.py @@ -353,7 +353,8 @@ def gen_attention( def gen_xqa( - use_fp16_: List[bool], + fp16_input_: List[bool], + fp8_kv_cache_: List[bool], token_per_page_: List[int], head_size_: List[int], head_grp_size_: List[int], @@ -365,13 +366,15 @@ def gen_xqa( return # XQA requires SM90+ for ( - use_fp16, + fp16_input, + fp8_kv_cache, token_per_page, head_size, head_grp_size, use_sliding_window, ) in product( - use_fp16_, + fp16_input_, + fp8_kv_cache_, token_per_page_, 
head_size_, head_grp_size_, @@ -384,7 +387,8 @@ def gen_xqa( continue yield gen_xqa_module( - use_fp16=use_fp16, + fp16_input=fp16_input, + fp8_kv_cache=fp8_kv_cache, token_per_page=token_per_page, head_size=head_size, head_grp_size=head_grp_size, @@ -491,14 +495,16 @@ def gen_all_modules( if add_xqa: # Define XQA configurations to iterate over - xqa_use_fp16_ = [True, False] # fp16 and bf16 + xqa_fp16_input_ = [True, False] # fp16 and bf16 + xqa_fp8_kv_cache_ = [True, False] xqa_token_per_page_ = [16, 32, 64, 128] xqa_head_size_ = [64, 128, 256] xqa_head_grp_size_ = [1, 2, 4, 8] # Different group sizes for MQA/GQA jit_specs += list( gen_xqa( - xqa_use_fp16_, + xqa_fp16_input_, + xqa_fp8_kv_cache_, xqa_token_per_page_, xqa_head_size_, xqa_head_grp_size_, diff --git a/flashinfer/xqa.py b/flashinfer/xqa.py index fc34ef6c0b..9581a0f6b9 100644 --- a/flashinfer/xqa.py +++ b/flashinfer/xqa.py @@ -31,7 +31,7 @@ xqa_nvcc_flags = [ "-DNDEBUG=1", "-DBEAM_WIDTH=1", - "-DCACHE_ELEM_ENUM=0", + "-DUSE_INPUT_KV=0", "-DUSE_CUSTOM_BARRIER=1", "-DLOW_PREC_OUTPUT=0", "-DSPEC_DEC=0", @@ -39,16 +39,22 @@ def gen_xqa_module( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, token_per_page: int, head_size: int, head_grp_size: int, use_sliding_window: bool, ) -> JitSpec: - if use_fp16: - flag_use_fp16 = ["-DINPUT_FP16=1", "-DDTYPE=__half"] + if fp16_input: + flag_data_type = ["-DINPUT_FP16=1", "-DDTYPE=__half"] else: - flag_use_fp16 = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] + flag_data_type = ["-DINPUT_FP16=0", "-DDTYPE=__nv_bfloat16"] + + if fp8_kv_cache: + flag_data_type.append("-DCACHE_ELEM_ENUM=2") + else: + flag_data_type.append("-DCACHE_ELEM_ENUM=0") if token_per_page not in [16, 32, 64, 128]: raise ValueError( @@ -70,9 +76,11 @@ def gen_xqa_module( flag_sliding_window = ["-DSLIDING_WINDOW=0"] return gen_jit_spec( - f"xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", + f"xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", [ jit_env.FLASHINFER_CSRC_DIR / "xqa/mha.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/mha_sm90.cu", + jit_env.FLASHINFER_CSRC_DIR / "xqa/tensorMap.cpp", jit_env.FLASHINFER_CSRC_DIR / "xqa/xqa_wrapper.cu", jit_env.FLASHINFER_CSRC_DIR / "flashinfer_xqa_ops.cu", ], @@ -80,29 +88,37 @@ def gen_xqa_module( + sm90a_nvcc_flags + flag_tokens_per_page + flag_head_size - + flag_use_fp16 + + flag_data_type + flag_head_grp_size + flag_sliding_window, + extra_ldflags=["-lcuda"], # Add CUDA Driver API library ) @functools.cache def get_xqa_module( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, token_per_page: int, head_size: int, head_grp_size: int, use_sliding_window: bool, ): module = gen_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + fp16_input, + fp8_kv_cache, + token_per_page, + head_size, + head_grp_size, + use_sliding_window, ).build_and_load() @register_custom_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", + f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}", mutates_args=("output", "scratch"), ) def xqa( + run_fp8_mha: bool, 
multiProcessorCount: int, nbKHeads: int, slidingWinSize: int, @@ -120,6 +136,7 @@ def xqa( scratch: torch.Tensor, ) -> None: module.xqa_wrapper.default( + run_fp8_mha, multiProcessorCount, nbKHeads, slidingWinSize, @@ -138,9 +155,10 @@ def xqa( ) @register_fake_op( - f"flashinfer::xqa_use_fp16_{use_fp16}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}" + f"flashinfer::xqa_fp16_input_{fp16_input}_fp8_kv_cache_{fp8_kv_cache}_token_per_page_{token_per_page}_head_size_{head_size}_head_grp_size_{head_grp_size}_use_sliding_window_{use_sliding_window}" ) def _fake_xqa( + run_fp8_mha: bool, multiProcessorCount: int, nbKHeads: int, slidingWinSize: int, @@ -165,7 +183,9 @@ def _fake_xqa( def xqa( - use_fp16: bool, + fp16_input: bool, + fp8_kv_cache: bool, + run_fp8_mha: bool, token_per_page: int, head_size: int, head_grp_size: int, @@ -189,9 +209,15 @@ def xqa( if get_compute_capability(torch.device(device="cuda"))[0] != 9: raise RuntimeError("XQA is only supported on SM90 GPUs") xqa_module = get_xqa_module( - use_fp16, token_per_page, head_size, head_grp_size, use_sliding_window + fp16_input, + fp8_kv_cache, + token_per_page, + head_size, + head_grp_size, + use_sliding_window, ) xqa_module.xqa( + run_fp8_mha, multiProcessorCount, nbKHeads, sliding_win_size if use_sliding_window else 0, diff --git a/tests/test_xqa.py b/tests/test_xqa.py index 2bdbb9e579..51d0d2159a 100644 --- a/tests/test_xqa.py +++ b/tests/test_xqa.py @@ -153,7 +153,9 @@ def ref_attention( reason="XQA is only supported on SM90 GPUs", ) @pytest.mark.parametrize("use_sliding_window", [True, False]) -@pytest.mark.parametrize("use_fp16", [True, False]) +@pytest.mark.parametrize("fp16_input", [True, False]) +@pytest.mark.parametrize("fp8_kv_cache", [True, False]) +@pytest.mark.parametrize("run_fp8_mha", [True, False]) @pytest.mark.parametrize("use_attention_sinks", [True, False]) @pytest.mark.parametrize("seq_len", [2, 15, 256, 514]) @pytest.mark.parametrize("batch_size", [1, 4]) @@ -166,7 +168,9 @@ def test_xqa( nb_k_heads, seq_len, tokens_per_page, - use_fp16, + fp16_input, + fp8_kv_cache, + run_fp8_mha, valid_elems_per_head, head_grp_size, use_attention_sinks, @@ -185,7 +189,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) output.fill_(float("nan")) @@ -194,7 +198,7 @@ def test_xqa( beam_width, nb_q_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) q_heads.normal_(0, 1) @@ -219,10 +223,12 @@ def test_xqa( cache_heads = torch.zeros( total_nb_cache_heads, valid_elems_per_head, - dtype=torch.bfloat16 if not use_fp16 else torch.float16, + dtype=torch.bfloat16 if not fp16_input else torch.float16, device="cuda", ) cache_heads.normal_(0, 1) + if fp8_kv_cache: + cache_heads /= 4.0 nb_pages_per_seq = div_up(max_seq_len, tokens_per_page) total_nb_pages = nb_pages_per_seq * 2 * beam_width * batch_size @@ -295,7 +301,9 @@ def cache_head_at( scratch_buf = torch.zeros(scratch_size, dtype=torch.uint8, device="cuda") xqa( - use_fp16, + fp16_input, + fp8_kv_cache, + run_fp8_mha, tokens_per_page, valid_elems_per_head, head_grp_size, @@ -307,7 +315,7 @@ def cache_head_at( output, q_heads, attention_sinks, - cache_heads, + cache_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_heads, page_list_arg, 
max_seq_len, seq_len_list, @@ -354,4 +362,21 @@ def cache_head_at( kernel_output = output[req][b][ idx_k_head * head_grp_size : (idx_k_head + 1) * head_grp_size ].to(torch.float32) - assert torch.allclose(ref_output, kernel_output, atol=0.01, rtol=0.01) + if fp8_kv_cache or run_fp8_mha: + atol = 0.05 + rtol = 0.05 + else: + atol = 0.01 + rtol = 0.01 + + diff_abs = torch.abs(ref_output - kernel_output) + diff_rel = diff_abs / (torch.abs(ref_output) + 1e-8) + + within_tolerance = (diff_abs <= atol) | (diff_rel <= rtol) + + pass_ratio = within_tolerance.float().mean().item() + + required_ratio = 0.99 + assert pass_ratio >= required_ratio, ( + f"Total {ref_output.numel()} elements, only {pass_ratio:.1%} meet tolerance criteria, require at least {required_ratio:.1%}" + ) From 966c5d567a85c96b5fc795ed1fff3f49cf026edd Mon Sep 17 00:00:00 2001 From: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 25 Sep 2025 01:02:46 -0700 Subject: [PATCH 2/3] fix format Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/xqa/gmma.cuh | 207 +- csrc/xqa/gmma_impl.cuh | 8675 +++++++++++++++++++++------------------- csrc/xqa/mha_sm90.cu | 5091 +++++++++++------------ csrc/xqa/tensorMap.cpp | 174 +- csrc/xqa/tensorMap.h | 14 +- csrc/xqa/tma.h | 485 ++- 6 files changed, 7388 insertions(+), 7258 deletions(-) diff --git a/csrc/xqa/gmma.cuh b/csrc/xqa/gmma.cuh index f5148d8b9b..d1b2547fcd 100644 --- a/csrc/xqa/gmma.cuh +++ b/csrc/xqa/gmma.cuh @@ -20,103 +20,84 @@ #include #include -namespace gmma -{ - -enum class SwizzleMode : uint64_t -{ - kNONE = 0, - k128 = 1, - k64 = 2, - k32 = 3 -}; +namespace gmma { -struct MatDesc -{ - uint64_t addr : 16; - uint64_t dimKOffset : 16; - uint64_t dimMNOffset : 16; - uint64_t pad0 : 1; - uint64_t baseOffset : 3; - uint64_t pad1 : 10; - SwizzleMode swizzle : 2; - - enum class Raw : uint64_t - { - }; - - [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const - { - MatDesc ret = *this; - ret.addr = encode(__cvta_generic_to_shared(data)); - return ret; - } +enum class SwizzleMode : uint64_t { kNONE = 0, k128 = 1, k64 = 2, k32 = 3 }; - static __device__ inline uint32_t encode(uint32_t val) - { - return (val & 0x3FFFFU) >> 4; - } +struct MatDesc { + uint64_t addr : 16; + uint64_t dimKOffset : 16; + uint64_t dimMNOffset : 16; + uint64_t pad0 : 1; + uint64_t baseOffset : 3; + uint64_t pad1 : 10; + SwizzleMode swizzle : 2; - __device__ inline bool operator==(MatDesc const& other) const - { - return raw() == other.raw(); - } + enum class Raw : uint64_t {}; - __device__ inline Raw const& raw() const - { - static_assert(sizeof(MatDesc) == 8); - return reinterpret_cast(*this); - } + [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const { + MatDesc ret = *this; + ret.addr = encode(__cvta_generic_to_shared(data)); + return ret; + } - static __device__ inline MatDesc fromRaw(Raw const& raw) - { - return reinterpret_cast(raw); - } + static __device__ inline uint32_t encode(uint32_t val) { return (val & 0x3FFFFU) >> 4; } + + __device__ inline bool operator==(MatDesc const& other) const { return raw() == other.raw(); } + + __device__ inline Raw const& raw() const { + static_assert(sizeof(MatDesc) == 8); + return reinterpret_cast(*this); + } + + static __device__ inline MatDesc fromRaw(Raw const& raw) { + return reinterpret_cast(raw); + } }; static_assert(sizeof(MatDesc) == 8); -[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) -{ - assert((uint32_t(__cvta_generic_to_shared(data)) & 
~0x3FFFFU) == 0); - MatDesc::Raw ret = base; - auto& u32x2 = reinterpret_cast(ret); - u32x2[0] += static_cast(__cvta_generic_to_shared(data)) >> 4; - return ret; +[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) { + assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0); + MatDesc::Raw ret = base; + auto& u32x2 = reinterpret_cast(ret); + u32x2[0] += static_cast(__cvta_generic_to_shared(data)) >> 4; + return ret; } -__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, - void const* patternStartAddr, SwizzleMode swizzleMode) -{ - uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr); - uint32_t const baseAlign = [&]() -> uint32_t - { - switch (swizzleMode) - { - case SwizzleMode::kNONE: return 1; - case SwizzleMode::k128: return 1024; - case SwizzleMode::k64: return 512; - case SwizzleMode::k32: return 256; - } - asm volatile("trap;\n"); - return 0; - }(); - uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7)); - return MatDesc{ - /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)), - /*dimKOffset=*/MatDesc::encode(dimKByteOffset), - /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset), - /*pad0=*/0, - /*baseOffset=*/baseOffset, - /*pad1=*/0, - /*swizzle=*/swizzleMode, - }; +__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, + uint32_t dimMNByteOffset, void const* patternStartAddr, + SwizzleMode swizzleMode) { + uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr); + uint32_t const baseAlign = [&]() -> uint32_t { + switch (swizzleMode) { + case SwizzleMode::kNONE: + return 1; + case SwizzleMode::k128: + return 1024; + case SwizzleMode::k64: + return 512; + case SwizzleMode::k32: + return 256; + } + asm volatile("trap;\n"); + return 0; + }(); + uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7)); + return MatDesc{ + /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)), + /*dimKOffset=*/MatDesc::encode(dimKByteOffset), + /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset), + /*pad0=*/0, + /*baseOffset=*/baseOffset, + /*pad1=*/0, + /*swizzle=*/swizzleMode, + }; } -__device__ inline MatDesc makeMatDesc( - void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, SwizzleMode swizzleMode) -{ - return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode); +__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, + uint32_t dimMNByteOffset, SwizzleMode swizzleMode) { + return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode); } inline constexpr uint32_t instM = 64; @@ -129,50 +110,36 @@ inline constexpr uint32_t instNBase = 8; // for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N // acc is used as both input and output. 
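 // Register layout note (inferred from the declarations below and the standard wgmma
 // m64nN accumulator mapping; stated here as an assumption): each of the 128 threads in
 // the warp group holds n/2 fp32 values, indexed as acc[n / instNBase][2][2], i.e. one
 // 2x2 core-matrix fragment per 64x8 instruction tile.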
template -__device__ void mma_async_shmA( - float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal); +__device__ void mma_async_shmA(float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal); template -__device__ void mma_async_regA( - float (&acc)[exactDiv(n, instNBase)][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal); +__device__ void mma_async_regA(float (&acc)[exactDiv(n, instNBase)][2][2], + uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal); -__device__ inline void fence() -{ - asm volatile("wgmma.fence.sync.aligned;\n"); -} +__device__ inline void fence() { asm volatile("wgmma.fence.sync.aligned;\n"); } -__device__ inline void commit_group() -{ - asm volatile("wgmma.commit_group.sync.aligned;\n"); -} +__device__ inline void commit_group() { asm volatile("wgmma.commit_group.sync.aligned;\n"); } template -__device__ inline void wait_group() -{ - asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups)); +__device__ inline void wait_group() { + asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups)); } template -constexpr SwizzleMode getSwizzleMode(Array2D const&) -{ - constexpr auto rowBytes = Array2D::rowBytes; - if constexpr (!swizzle) - { - return SwizzleMode::kNONE; - } - if constexpr (rowBytes % 128 == 0) - { - return SwizzleMode::k128; - } - else if constexpr (rowBytes == 64) - { - return SwizzleMode::k64; - } - else - { - static_assert(rowBytes == 32); - return SwizzleMode::k32; - } +constexpr SwizzleMode getSwizzleMode(Array2D const&) { + constexpr auto rowBytes = Array2D::rowBytes; + if constexpr (!swizzle) { + return SwizzleMode::kNONE; + } + if constexpr (rowBytes % 128 == 0) { + return SwizzleMode::k128; + } else if constexpr (rowBytes == 64) { + return SwizzleMode::k64; + } else { + static_assert(rowBytes == 32); + return SwizzleMode::k32; + } } -} // namespace gmma +} // namespace gmma #include "gmma_impl.cuh" diff --git a/csrc/xqa/gmma_impl.cuh b/csrc/xqa/gmma_impl.cuh index 96bc721ca3..b9515ddea9 100644 --- a/csrc/xqa/gmma_impl.cuh +++ b/csrc/xqa/gmma_impl.cuh @@ -19,8 +19,7 @@ #include #include -namespace gmma -{ +namespace gmma { // cog template. 
Do code generation with: pip install cogapp; cog -r $filename // clang-format off @@ -145,4442 +144,4828 @@ const&>(descB)), "n"(false)); // clang-format on template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>( - float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>( - float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + 
MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void 
mma_async_regA<__nv_fp8_e4m3, 24, false, false>( - float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - 
"wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>( - float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), 
"r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), 
"+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>( - float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - 
"+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), 
"+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), 
"+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + 
"+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> __device__ inline void mma_async_regA<__nv_fp8_e4m3, 128, false, false>( - float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, 
%33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + 
"+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), 
"+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), 
"+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 
1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, 
%120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> __device__ inline void mma_async_regA<__nv_fp8_e4m3, 256, false, false>( - float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, 
%68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, 
%43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), 
"+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), 
"+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + 
"%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>( - float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, 
%6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>( - float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "{%4, %5, %6, %7},\n" // a - "%8,\n" // b-desc - "%9, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>(float (&acc)[1][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 0>( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 0>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void 
mma_async_shmA( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[1][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>( - float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3},\n" // d - "%4,\n" // a-desc - "%5,\n" // b-desc - "%6, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>(float (&acc)[1][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - 
"%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>( - float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - 
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + 
"+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>( - float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "{%8, %9, %10, %11},\n" // a - "%12,\n" // b-desc - "%13, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>(float (&acc)[2][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - 
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[2][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>( - float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d - "%8,\n" // a-desc - "%9,\n" // b-desc - "%10, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>(float (&acc)[2][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void 
-__device__ inline void mma_async_shmA<half, 16, 1, 1>(
-    float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
-            "%8,\n" // a-desc
-            "%9,\n" // b-desc
-            "%10, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
-            "%8,\n" // a-desc
-            "%9,\n" // b-desc
-            "%10, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<half, 16, 1, 1>(float (&acc)[2][2][2], MatDesc::Raw descA,
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
+        "%8,\n" // a-desc
+        "%9,\n" // b-desc
+        "%10, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
+        "%8,\n" // a-desc
+        "%9,\n" // b-desc
+        "%10, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
 }
 template <>
-__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(
-    float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
-            "%8,\n" // a-desc
-            "%9,\n" // b-desc
-            "%10, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
-            "%8,\n" // a-desc
-            "%9,\n" // b-desc
-            "%10, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>(float (&acc)[2][2][2],
+                                                               MatDesc::Raw descA,
+                                                               MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d
+        "%8,\n" // a-desc
+        "%9,\n" // b-desc
+        "%10, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
"+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), 
"+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>( - float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
"+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[3][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - 
"{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + 
"+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>( - float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "{%12, %13, %14, %15},\n" // a - "%16,\n" // b-desc - "%17, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>(float (&acc)[3][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if 
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<half, 24, 1, 0>(float (&acc)[3][2][2], MatDesc::Raw descA,
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
 }
 template <>
-__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(
-    float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>(float (&acc)[3][2][2],
+                                                               MatDesc::Raw descA,
+                                                               MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
 }
 template <>
-__device__ inline void mma_async_shmA<half, 24, 1, 1>(
-    float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
-            "%12,\n" // a-desc
-            "%13,\n" // b-desc
-            "%14, 1, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<half, 24, 1, 1>(float (&acc)[3][2][2], MatDesc::Raw descA,
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d
+        "%12,\n" // a-desc
+        "%13,\n" // b-desc
+        "%14, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
"+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>( - float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d - "%12,\n" // a-desc - "%13,\n" // b-desc - "%14, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>(float (&acc)[3][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, 
%14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw 
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d
+        "{%16, %17, %18, %19},\n" // a
+        "%20,\n" // b-desc
+        "%21, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d
+        "{%16, %17, %18, %19},\n" // a
+        "%20,\n" // b-desc
+        "%21, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1])
+        : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
 }
 template <>
-__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>(
-    float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d
-            "%16,\n" // a-desc
-            "%17,\n" // b-desc
-            "%18, 1, 1, 0, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d
-            "%16,\n" // a-desc
-            "%17,\n" // b-desc
-            "%18, 1, 1, 0, 0;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2],
+                                                               MatDesc::Raw descA,
+                                                               MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d
+        "%16,\n" // a-desc
+        "%17,\n" // b-desc
+        "%18, 1, 1, 0, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
"+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 0>( - float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 0>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 
1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - 
"wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>( - float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "{%16, %17, %18, %19},\n" // a - "%20,\n" // b-desc - "%21, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), 
"r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>(float (&acc)[4][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[4][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>( - float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 1;\n" - : 
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d - "%16,\n" // a-desc - "%17,\n" // b-desc - "%18, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>(float (&acc)[4][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], 
MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
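//
// Note on accHasVal: it selects the "n"(true) / "n"(false) scale-d operand of
// wgmma.mma_async. With scale-d == 0 the instruction computes d = a*b instead
// of d = d + a*b, so the first wgmma of a tile can skip zero-initializing the
// accumulator. A minimal sketch, with descA/descB standing in for descriptors
// built elsewhere (placeholder names, not from this patch):
//
//   float acc[8][2][2];  // may start uninitialized
//   gmma::mma_async_shmA<half, 64, 0, 0>(acc, descA, descB, false);  // d = a*b
//   gmma::mma_async_shmA<half, 64, 0, 0>(acc, descA, descB, true);   // d += a*b
//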
"wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>( - float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, 
%19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, 
%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
-            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
-            "%32,\n" // a-desc
-            "%33,\n" // b-desc
-            "%34, 1, 1, 0, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
-            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
-            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
-            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
-            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
-            "%32,\n" // a-desc
-            "%33,\n" // b-desc
-            "%34, 1, 1, 0, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
-            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
-            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
-            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
-            : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_shmA<half, 64, 0, 1>(float (&acc)[8][2][2], MatDesc::Raw descA,
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n"  // d
+        "%32,\n"                                           // a-desc
+        "%33,\n"                                           // b-desc
+        "%34, 1, 1, 0, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n"  // d
+        "%32,\n"                                           // a-desc
+        "%33,\n"                                           // b-desc
+        "%34, 1, 1, 0, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+          "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+          "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+          "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+          "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+          "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+          "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+          "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)),
+          "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
 }

 template <>
-__device__ inline void mma_async_regA<half, 64, 0, 1>(
-    float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal)
-{
-    if (accHasVal)
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
-            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
-            "{%32, %33, %34, %35},\n" // a
-            "%36,\n" // b-desc
-            "%37, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
-            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
-            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
-            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
-            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
-            "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
-    }
-    else
-    {
-        asm volatile(
-            "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
-            "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, "
-            "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d
-            "{%32, %33, %34, %35},\n" // a
-            "%36,\n" // b-desc
-            "%37, 1, 1, 1;\n"
-            : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
-            "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
-            "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
-            "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
-            "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]),
-            "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]),
-            "+f"(acc[7][1][0]), "+f"(acc[7][1][1])
-            : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]),
-            "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
-    }
+__device__ inline void mma_async_regA<half, 64, 0, 1>(float (&acc)[8][2][2],
+                                                      uint32_t const (&a)[2][2][1],
+                                                      MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n"  // d
+        "{%32, %33, %34, %35},\n"                          // a
+        "%36,\n"                                           // b-desc
+        "%37, 1, 1, 1;\n"
+        :
"+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>( - float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
"+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "{%32, %33, %34, %35},\n" // a - "%36,\n" // b-desc - "%37, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>(float (&acc)[8][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), 
"+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), 
"+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), 
"+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), 
"+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[8][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void 
mma_async_shmA<__nv_bfloat16, 64, 1, 1>( - float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d - "%32,\n" // a-desc - "%33,\n" // b-desc - "%34, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 1>(float (&acc)[8][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, 
%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), 
"+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, 
%12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), 
"+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), 
"+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), 
"+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, 
%17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), 
"+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>( - float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), 
"+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + 
"+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, 
" - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), 
"+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), 
"+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), 
"+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, 
%52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + 
MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + 
"+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>( - float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "{%64, %65, %66, %67},\n" // a - "%68,\n" // b-desc - "%69, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), 
"+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>(float (&acc)[16][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), 
"+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - 
"wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[16][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), 
"+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - 
"+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), 
"+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, 
%35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void 
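+// Note: scale-d (%66 below) is bound to a compile-time immediate ("n"(true) or
+// "n"(false)) rather than a runtime predicate, which is presumably why each
+// specialization carries two otherwise identical asm statements selected by
+// accHasVal instead of passing the flag into a single statement.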
+__device__ inline void mma_async_shmA<half, 128, 1, 1>(float (&acc)[16][2][2],
+                                                        MatDesc::Raw descA,
+                                                        MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, "
+        "%41, %42, %43, "
+        "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, "
+        "%62, %63},\n" // d
+        "%64,\n" // a-desc
+        "%65,\n" // b-desc
+        "%66, 1, 1, 1, 1;\n"
+      : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+        "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+        "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+        "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+        "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+        "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+        "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+        "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]),
+        "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]),
+        "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+        "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]),
+        "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]),
+        "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]),
+        "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]),
+        "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]),
+        "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1])
+      : "l"(reinterpret_cast<uint64_t const&>(descA)),
+        "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, "
+        "%19, %20, %21, %22, "
+        "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, "
+        "%41, %42, %43, "
+        "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, "
+        "%62, %63},\n" // d
+        "%64,\n" // a-desc
+        "%65,\n" // b-desc
+        "%66, 1, 1, 1, 1;\n"
+      : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]),
+        "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]),
+        "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]),
+        "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]),
+        "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]),
+        "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]),
+        "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]),
+        "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]),
+        "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]),
+        "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]),
+        "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]),
+        "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]),
+        "+f"(acc[12][0][0]),
"+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 1>( - float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d - "%64,\n" // a-desc - "%65,\n" // b-desc - "%66, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 1>(float (&acc)[16][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
"+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), 
"+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), 
"+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), 
"+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), 
"+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), 
"+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), 
"+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), 
+ "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, 
%127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, 
%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, 
%55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, 
%7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), 
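+ // The final "n" operand is wgmma's scale-d immediate: 1 accumulates into d, 0 overwrites it, hence the accHasVal branches.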
"n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), 
"+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>( - float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), 
"+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), 
"+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), 
"+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), 
"+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), 
"+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), 
"+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), 
"+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + 
"+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA( - float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), 
"+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + 
"+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 
1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, 
%87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, 
%71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 0, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + 
"%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), 
"+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>( - float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), 
"+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "{%128, %129, %130, %131},\n" // a - "%132,\n" // b-desc - "%133, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), 
"+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), - "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>(float (&acc)[32][2][2], + uint32_t const (&a)[2][2][1], + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), 
"+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), 
"+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), 
"+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), 
"+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + 
"+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), 
"+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), 
"+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 0;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), 
"+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), 
"+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" 
// a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, 
%119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA(float (&acc)[32][2][2], MatDesc::Raw descA, + MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, 
%79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, 
%57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } template <> -__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>( - float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) -{ - if (accHasVal) - { - asm volatile( - 
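// Same recipe as every wgmma specialization in this file: each thread of the
// 128-thread warp group owns 64*256/128 = 128 f32 accumulators
// (float[32][2][2], bound to operands %0..%127); %128/%129 are the 64-bit
// shared-memory matrix descriptors; the trailing immediates are scale-d,
// scale-a, scale-b, trans-a, trans-b. The if (accHasVal) duplication exists
// because scale-d must be a compile-time immediate, so "n"(true) selects
// D = A*B + D and "n"(false) selects D = A*B. The <..., 1, 0> and
// <..., 1, 1> specializations differ only in the final trans-b immediate.
// A hypothetical call site (descriptor values are illustrative):
//   float acc[32][2][2];  // accumulate one K-slice at a time
//   mma_async_shmA<__nv_bfloat16, 256, 1, 1>(acc, descA0, descB0, /*accHasVal=*/false);
//   mma_async_shmA<__nv_bfloat16, 256, 1, 1>(acc, descA1, descB1, /*accHasVal=*/true);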
"wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : 
"l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); - } - else - { - asm volatile( - "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" - "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " - "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " - "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " - "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " - "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " - "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " - "%123, %124, %125, %126, %127},\n" // d - "%128,\n" // a-desc - "%129,\n" // b-desc - "%130, 1, 1, 1, 1;\n" - : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), - "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), - "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), - "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), - "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), - "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), - "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), - "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), - "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), - "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), - "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), - "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), - "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), - "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), - "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), - "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), - "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), - "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), - "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), - "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), - "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), - "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), - "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), - "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), - "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), 
"+f"(acc[31][0][0]), - "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) - : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); - } +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>(float (&acc)[32][2][2], + MatDesc::Raw descA, + MatDesc::Raw descB, + bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + "+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), 
"+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, " + "%19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, " + "%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, " + "%62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, " + "%83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, " + "%103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, " + "%121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), + "+f"(acc[1][0][0]), "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), + "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), + "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), "+f"(acc[3][1][1]), + "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), + "+f"(acc[6][0][0]), "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), + "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), + "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), "+f"(acc[8][1][1]), + "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), + "+f"(acc[11][0][0]), "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), + "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), + "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), "+f"(acc[13][1][1]), + "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), + "+f"(acc[16][0][0]), "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), + "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), + "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), "+f"(acc[18][1][1]), + "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), + "+f"(acc[21][0][0]), "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), + "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), + "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), "+f"(acc[23][1][1]), + "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), + 
"+f"(acc[26][0][0]), "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), + "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), + "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), "+f"(acc[28][1][1]), + "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), + "+f"(acc[31][0][0]), "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), + "l"(reinterpret_cast(descB)), "n"(false)); + } } //[[[end]]] -} // namespace gmma +} // namespace gmma diff --git a/csrc/xqa/mha_sm90.cu b/csrc/xqa/mha_sm90.cu index 7fa5234874..286ee08ec5 100644 --- a/csrc/xqa/mha_sm90.cu +++ b/csrc/xqa/mha_sm90.cu @@ -23,9 +23,10 @@ #endif #ifndef GENERATE_CUBIN +#include + #include "hostUtils.h" #include "tensorMap.h" -#include #endif #include "gmma.cuh" #include "mha.h" @@ -44,16 +45,18 @@ static_assert(specDecQLen * headGrpSize <= 32, "SPEC_Q_SEQ_LEN macro value is to #define SWAP_AB (!SPEC_DEC) #endif -#define IS_SUPPORTED_F16_CASE (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT) +#define IS_SUPPORTED_F16_CASE \ + (CACHE_ELEM_ENUM == 0 && !SPEC_DEC && SWAP_AB && !USE_INPUT_KV && !LOW_PREC_OUTPUT) inline constexpr bool swapAB = SWAP_AB; #pragma region Config -static_assert( - (inputElemSize == cacheElemSize && mha::is_same_v) || inputElemSize > cacheElemSize); -using MathElem - = mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v), InputElem, CacheElem>; +static_assert((inputElemSize == cacheElemSize && mha::is_same_v) || + inputElemSize > cacheElemSize); +using MathElem = + mha::conditional_t<(inputElemSize > cacheElemSize && mha::is_same_v), + InputElem, CacheElem>; constexpr uint32_t gmmaWarpsPerGrp = 4; constexpr uint32_t gmmaWarpGrpSize = warp_size * gmmaWarpsPerGrp; @@ -67,18 +70,15 @@ constexpr uint32_t ctaNbValidQHeads = ctaNbQHeads; #elif SPEC_DEC && SWAP_AB inline constexpr uint32_t inputTokensPerCta = specDecQLen; inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * inputTokensPerCta; -inline constexpr uint32_t ctaNbQHeads = []() -{ - static_assert(ctaNbValidQHeads <= 32, "ctaNbValidQHeads cannot exceed 32"); - if constexpr (ctaNbValidQHeads <= 8) - { - return 8; - } - if constexpr (ctaNbValidQHeads <= 16) - { - return 16; - } - return 32; +inline constexpr uint32_t ctaNbQHeads = []() { + static_assert(ctaNbValidQHeads <= 32, "ctaNbValidQHeads cannot exceed 32"); + if constexpr (ctaNbValidQHeads <= 8) { + return 8; + } + if constexpr (ctaNbValidQHeads <= 16) { + return 16; + } + return 32; }(); #else inline constexpr uint32_t ctaNbValidQHeads = headGrpSize * beamWidth; @@ -94,18 +94,19 @@ constexpr uint32_t gemm1CtaTileNbTokens = gemm0CtaTileNbTokens; constexpr uint32_t mathHeadBytes = sizeof(Vec); constexpr uint32_t nbIOWarps = 4; constexpr uint32_t nbIOThrds = warp_size * nbIOWarps; -constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning +constexpr uint32_t multiBlockMinNbTilesPerCta = 1; // 3; // @fixme: need tuning constexpr uint32_t multiBlockMinNbTiles = multiBlockMinNbTilesPerCta * 2; constexpr uint32_t nbWarps = gemm0NbWarps + gemm1NbWarps + nbIOWarps; constexpr uint32_t cacheHeadPartBytes = mha::min(paddedCacheHeadBytes, 128U); -constexpr uint32_t cacheHeadNbParts - = exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future +constexpr uint32_t cacheHeadNbParts = + 
exactDiv(paddedCacheHeadBytes, cacheHeadPartBytes); // @fixme: support divUp in the future constexpr uint32_t cacheHeadPartElems = exactDiv(headElems, cacheHeadNbParts); constexpr uint32_t swizzleBytes = cacheHeadPartBytes; static_assert(swizzleBytes == 128 || swizzleBytes == 64 || swizzleBytes == 32); -constexpr bool needInputCvt = inputElemSize > cacheElemSize&& mha::is_same_v; +constexpr bool needInputCvt = + inputElemSize > cacheElemSize&& mha::is_same_v; constexpr bool needCacheCvt = inputElemSize > cacheElemSize&& mha::is_same_v; static_assert(needInputCvt || needCacheCvt || mha::is_same_v); @@ -131,109 +132,100 @@ using PaddedOutHead = PaddedInputHead; #pragma endregion Config -struct alignas(128) SharedMem -{ - using KBuffer = Array2D; - static constexpr uint32_t nbKBuf = 2; - KBuffer k[nbKBuf]; // as is loaded from global mem. - using XBuffer = Vec, nbXParts>; - static constexpr uint32_t nbXBuf - = 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens ? 1 : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); - using VBuffer = Vec, - cacheHeadNbParts>; +struct alignas(128) SharedMem { + using KBuffer = Array2D; + static constexpr uint32_t nbKBuf = 2; + KBuffer k[nbKBuf]; // as is loaded from global mem. + using XBuffer = Vec, nbXParts>; + static constexpr uint32_t nbXBuf = + 2 * (gemm0CtaTileNbTokens >= gemm1CtaTileNbTokens + ? 1 + : exactDiv(gemm1CtaTileNbTokens, gemm0CtaTileNbTokens)); + using VBuffer = + Vec, + cacheHeadNbParts>; #if !SWAP_AB - using VTBuffer = Array2D; + using VTBuffer = + Array2D; #endif - static constexpr uint32_t nbVBuf = 2; + static constexpr uint32_t nbVBuf = 2; #if CACHE_ELEM_ENUM == 0 - using OutSwizzleBuf = Array2D; + using OutSwizzleBuf = Array2D; #elif CACHE_ELEM_ENUM == 2 - using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; + using OutSwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; #endif - static_assert(nbXBuf == nbVBuf); + static_assert(nbXBuf == nbVBuf); - union ReusedXVOutSwizzleBuf - { - struct XV - { - XBuffer x; - VBuffer v; + union ReusedXVOutSwizzleBuf { + struct XV { + XBuffer x; + VBuffer v; #if !SWAP_AB - VTBuffer vt; + VTBuffer vt; #endif - // @fixme: also put xColMax and xColSum here - } xv; + // @fixme: also put xColMax and xColSum here + } xv; - OutSwizzleBuf outSwizzle; - } reusedXVOutSwizzleBuf[nbXBuf]; + OutSwizzleBuf outSwizzle; + } reusedXVOutSwizzleBuf[nbXBuf]; - static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV), - "need to use split output to avoid excessive shared memory usage"); + static_assert(sizeof(OutSwizzleBuf) <= sizeof(SharedMem::ReusedXVOutSwizzleBuf::XV), + "need to use split output to avoid excessive shared memory usage"); - __device__ inline XBuffer& xBuf(uint32_t i) - { - return reusedXVOutSwizzleBuf[i].xv.x; - } + __device__ inline XBuffer& xBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.x; } - __device__ inline VBuffer& vBuf(uint32_t i) - { - return reusedXVOutSwizzleBuf[i].xv.v; - } + __device__ inline VBuffer& vBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.v; } #if !SWAP_AB - __device__ inline VTBuffer& vtBuf(uint32_t i) - { - return reusedXVOutSwizzleBuf[i].xv.vt; - } + __device__ inline VTBuffer& vtBuf(uint32_t i) { return reusedXVOutSwizzleBuf[i].xv.vt; } #endif - __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i) - { - return reusedXVOutSwizzleBuf[i].outSwizzle; - } + __device__ inline OutSwizzleBuf& outSwizzleBuf(uint32_t i) { + return reusedXVOutSwizzleBuf[i].outSwizzle; + } - using QBuffer = Vec, 
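// ReusedXVOutSwizzleBuf above is a shared-memory reuse trick: the X and V
// staging buffers are dead by the time outputs are swizzled for writeback,
// so both phases share one allocation, and the static_assert guarantees the
// reused region is large enough. A minimal sketch of the same pattern
// (illustrative names and sizes, not from this file):
//   union Reused {
//     struct { float x[256]; float v[256]; } xv;  // live during the gemms
//     float outSwizzle[512];                      // live during writeback
//   };
//   static_assert(sizeof(Reused::outSwizzle) <= sizeof(Reused::xv), "split output needed");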
nbQParts>; - QBuffer q; // For gmma math. Conversion done if needed. + using QBuffer = Vec, nbQParts>; + QBuffer q; // For gmma math. Conversion done if needed. - // @fixme: move these into reusedXVOutSwizzleBuf + // @fixme: move these into reusedXVOutSwizzleBuf #if SWAP_AB - ShmQWiseVec xColMax[nbXBuf]; - ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps]; + ShmQWiseVec xColMax[nbXBuf]; + ShmQWiseVec xColSum[nbXBuf][gemm0NbWarps]; #else - ShmQWiseVec xRowMax[nbXBuf]; - ShmQWiseVec xRowSum[nbXBuf]; + ShmQWiseVec xRowMax[nbXBuf]; + ShmQWiseVec xRowSum[nbXBuf]; #endif - ShmQWiseVec gemm0CurrentSeqMax; - // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register storage will be 8x - // duplicate for swapAB and 4x duplicate for non-swapAB. - ShmQWiseVec gemm1AccColMax; - ShmQWiseVec gemm1AccColSum; + ShmQWiseVec gemm0CurrentSeqMax; + // col sum and max for the current gemm1 acc. Use shared memory to save some registers. register + // storage will be 8x duplicate for swapAB and 4x duplicate for non-swapAB. + ShmQWiseVec gemm1AccColMax; + ShmQWiseVec gemm1AccColSum; #if USE_PAGED_KV_CACHE - static constexpr uint32_t nbPagesPerTile - = gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; - Vec pages[2]; // one for K and one for V + static constexpr uint32_t nbPagesPerTile = + gemm0CtaTileNbTokens >= tokensPerPage ? exactDiv(gemm0CtaTileNbTokens, tokensPerPage) : 1; + Vec pages[2]; // one for K and one for V #endif - // mem barriers + // mem barriers - CtaBarrierPair qBar; - CtaBarrierPair kBar[nbKBuf]; - CtaBarrierPair vBar[nbVBuf]; + CtaBarrierPair qBar; + CtaBarrierPair kBar[nbKBuf]; + CtaBarrierPair vBar[nbVBuf]; #if !SWAP_AB - CtaBarrierPair vtBar[nbVBuf]; + CtaBarrierPair vtBar[nbVBuf]; #endif - CtaBarrierPair xBar[nbXBuf]; + CtaBarrierPair xBar[nbXBuf]; - // used internally in the gemm0 warp group - // @fixme: use separate arrive and wait for all usage - CtaBarrier gemm0WarpGrpBar; + // used internally in the gemm0 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm0WarpGrpBar; - // used internally in the gemm1 warp group - // @fixme: use separate arrive and wait for all usage - CtaBarrier gemm1WarpGrpBar; + // used internally in the gemm1 warp group + // @fixme: use separate arrive and wait for all usage + CtaBarrier gemm1WarpGrpBar; - bool isLastCta; + bool isLastCta; }; CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem); @@ -246,78 +238,80 @@ constexpr uint32_t nbQLdThrds = warp_size * nbQLdWarps; #if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 template -struct F16QToF8Converter -{ - static_assert(inputElemSize == 2); - using F16Vec = Vec; +struct F16QToF8Converter { + static_assert(inputElemSize == 2); + using F16Vec = Vec; #if CACHE_ELEM_ENUM == 0 - using ShmVec = F16Vec; + using ShmVec = F16Vec; #elif CACHE_ELEM_ENUM == 2 - using F8Vec = Vec; - using ShmVec = F8Vec; + using F8Vec = Vec; + using ShmVec = F8Vec; #endif - static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); - static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize; + static constexpr uint32_t grainsPerPaddedInputHead = exactDiv(paddedInputHeadBytes, grainBytes); + static constexpr uint32_t grainsPerPaddedInputQHeadGrp = grainsPerPaddedInputHead * headGrpSize; #if !(SPEC_DEC) - static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * beamWidth; + static constexpr uint32_t totalGrains = 
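// F16QToF8Converter cooperatively loads the fp16 Q heads through registers
// and, on the fp8 cache path (CACHE_ELEM_ENUM == 2), narrows them before the
// shared-memory store so both gmma operands use a single math type. A
// standalone sketch of the narrowing step, assuming the fp8 element type is
// __nv_fp8_e4m3 (the usual fp8 KV-cache format; not spelled out here):
//   #include <cuda_fp8.h>
//   __device__ inline void narrowHead(__nv_fp8_e4m3* dst, __half const* src, uint32_t n) {
//     for (uint32_t i = 0; i < n; i++) dst[i] = __nv_fp8_e4m3(float(src[i]));
//   }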
grainsPerPaddedInputQHeadGrp * beamWidth; #else - static_assert(beamWidth == 1); - static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta; + static_assert(beamWidth == 1); + static constexpr uint32_t totalGrains = grainsPerPaddedInputQHeadGrp * inputTokensPerCta; #endif - static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds); + static constexpr uint32_t nbIters = divUp(totalGrains, nbThrds); - using RegData = Vec; + using RegData = Vec; - static __device__ RegData load(uint32_t tid, TinyPtr const& src, - uint32_t const nbKHeads /*for beam search and spec dec*/, uint32_t nbTokens); - static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data); + static __device__ RegData load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search and spec dec*/, + uint32_t nbTokens); + static __device__ void store(uint32_t tid, SharedMem::QBuffer& dst, RegData const& data); }; -#endif // CACHE_ELEM_ENUM +#endif // CACHE_ELEM_ENUM -struct KVTilePartLoader -{ - static constexpr uint32_t nbParts = cacheHeadNbParts; - static constexpr uint32_t partElems = exactDiv(headElems, nbParts); +struct KVTilePartLoader { + static constexpr uint32_t nbParts = cacheHeadNbParts; + static constexpr uint32_t partElems = exactDiv(headElems, nbParts); #if USE_PAGED_KV_CACHE - static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || tokensPerPage % gemm0CtaTileNbTokens == 0); - static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; + static_assert(gemm0CtaTileNbTokens % tokensPerPage == 0 || + tokensPerPage % gemm0CtaTileNbTokens == 0); + static constexpr uint32_t nbPagesPerTile = SharedMem::nbPagesPerTile; #endif - uint32_t const nbKHeads; - KVCacheList const& cacheList; - uint32_t const idxReq; - uint32_t const idxHeadGrp; + uint32_t const nbKHeads; + KVCacheList const& cacheList; + uint32_t const idxReq; + uint32_t const idxHeadGrp; - CUtensorMap const& tensorMap; + CUtensorMap const& tensorMap; #if USE_PAGED_KV_CACHE - uint32_t const nbPages; // for bound check - Vec& pages; - uint32_t idxTileRef; // idxTile used to load the pages + uint32_t const nbPages; // for bound check + Vec& pages; + uint32_t idxTileRef; // idxTile used to load the pages #endif - uint32_t const baseOffset; + uint32_t const baseOffset; - __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, KVCacheList const& cacheList, - uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap + __device__ KVTilePartLoader(bool isK, uint32_t nbKHeads, + KVCacheList const& cacheList, uint32_t idxReq, + uint32_t idxHeadGrp, CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE - , - uint32_t nbPages, Vec& pageBuf -#endif - ); - // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache - template - __device__ void loadData( - Array2D& dst, uint32_t idxTile, - uint32_t idxPart, CtaBarrier& bar); - - __device__ void loadPages(uint32_t idxTile); - __device__ GMemKVCacheHead& getHead(uint32_t pos); + , + uint32_t nbPages, Vec& pageBuf +#endif + ); + // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache + template + __device__ void loadData( + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar); + + __device__ void loadPages(uint32_t idxTile); + __device__ GMemKVCacheHead& getHead(uint32_t pos); }; using GmmaAccCoreMat = Array2D; template -using GmmaAcc = Array2D; +using GmmaAcc = + Array2D; inline constexpr uint32_t gemm0M = (swapAB ? 
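// gemm0M/gemm0N below encode the SWAP_AB operand swap: wgmma fixes gemm-M at
// 64 while gemm-N may shrink in steps of 8, so placing the small ctaNbQHeads
// dimension on N (and the KV-token tile on M) avoids padding a handful of q
// heads up to M = 64. E.g. with ctaNbQHeads = 16, the swapped layout can
// issue m64n16 instructions instead of wasting 48 of the 64 M-rows.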
gemm0CtaTileNbTokens : ctaNbQHeads); inline constexpr uint32_t gemm0N = (swapAB ? ctaNbQHeads : gemm0CtaTileNbTokens); @@ -330,114 +324,108 @@ using UniformNeedRescaleMask = Vec; using RegSeqWiseVec = RegColWiseVec; #else using RegRowWiseVec = Vec, Gemm0Acc::rows>; -using UniformNeedRescaleMask - = Vec; +using UniformNeedRescaleMask = + Vec; using RegSeqWiseVec = RegRowWiseVec; #endif #if SPEC_DEC -__device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq) -{ - return (params.qCuSeqLens == nullptr) ? params.qSeqLen : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq]; +__device__ inline uint32_t getInputSeqLen(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? params.qSeqLen + : params.qCuSeqLens[idxReq + 1] - params.qCuSeqLens[idxReq]; } -__device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq) -{ - return (params.qCuSeqLens == nullptr) ? params.qSeqLen * idxReq : params.qCuSeqLens[idxReq]; +__device__ inline uint32_t getInputTokOffset(SpecDecParams const& params, uint32_t idxReq) { + return (params.qCuSeqLens == nullptr) ? params.qSeqLen * idxReq : params.qCuSeqLens[idxReq]; } -struct SpecDec -{ - static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens; - static inline constexpr uint32_t ctaMaxQSeqLen = (ctaNbQHeads / headGrpSize); - using TileMaskRow = Vec; - - __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t idxInputSubSeq, uint32_t seqLen) - : params(params) - , idxInputSubSeq(idxInputSubSeq) - , seqLen(seqLen) - { - inputSeqLen = getInputSeqLen(params, idxReq); - baseOffset = divUp(params.qSeqLen, 32U) * (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * idxInputSubSeq); - } - - __device__ inline uint32_t unmaskedSeqLen() const - { - return seqLen - inputSeqLen; - } - - __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const - { - return tileSize * (idxTile + 1) > unmaskedSeqLen() - && ctaMaxQSeqLen * idxInputSubSeq + idxQTokInCta < inputSeqLen && params.mask != nullptr; - } - - __device__ inline int32_t maskColBeg(uint32_t idxTile) const - { - int32_t const convergedSeqLen = int32_t(unmaskedSeqLen()); - return static_cast(exactDiv(tileSize, 32) * idxTile) - - static_cast(divUp(convergedSeqLen, 32)); - } - - __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const - { - assert(needMask(idxTile, idxQTokInCta)); - constexpr uint32_t nbOrigElems = TileMaskRow::size + 1; - Vec orig; - - int32_t const cols = divUp(params.qSeqLen, 32); - uint32_t const rowOffset = baseOffset + idxQTokInCta * cols; - int32_t const colBeg = maskColBeg(idxTile); -#pragma unroll - for (int32_t i = 0; i < int32_t(nbOrigElems); i++) - { - int32_t const idx = colBeg + i; - orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? 
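// Worked example for needMask/maskColBeg: with tileSize = 64, seqLen = 100
// and inputSeqLen = 10, unmaskedSeqLen() = 90. Tile 0 covers tokens 0..63
// (64*1 <= 90) and needs no mask; tile 1 covers tokens 64..127
// (64*2 = 128 > 90), so its mask row is loaded starting at 32-bit word
// maskColBeg(1) = 64/32*1 - divUp(90, 32) = 2 - 3 = -1, where negative word
// indices read as all-ones, i.e. fully visible.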
~0U : 0U); - } - TileMaskRow mask; - uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32; -#pragma unroll - for (uint32_t i = 0; i < TileMaskRow::size; i++) - { - asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" : "=r"(mask[i]) : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift)); - } - return mask; - } - - SpecDecParams const& params; - uint32_t const idxInputSubSeq; - uint32_t const seqLen; - uint32_t inputSeqLen; - uint32_t baseOffset; +struct SpecDec { + static inline constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static inline constexpr uint32_t ctaMaxQSeqLen = (ctaNbQHeads / headGrpSize); + using TileMaskRow = Vec; + + __device__ inline SpecDec(SpecDecParams const& params, uint32_t idxReq, uint32_t idxInputSubSeq, + uint32_t seqLen) + : params(params), idxInputSubSeq(idxInputSubSeq), seqLen(seqLen) { + inputSeqLen = getInputSeqLen(params, idxReq); + baseOffset = divUp(params.qSeqLen, 32U) * + (getInputTokOffset(params, idxReq) + ctaMaxQSeqLen * idxInputSubSeq); + } + + __device__ inline uint32_t unmaskedSeqLen() const { return seqLen - inputSeqLen; } + + __device__ inline bool needMask(uint32_t idxTile, uint32_t idxQTokInCta) const { + return tileSize * (idxTile + 1) > unmaskedSeqLen() && + ctaMaxQSeqLen * idxInputSubSeq + idxQTokInCta < inputSeqLen && params.mask != nullptr; + } + + __device__ inline int32_t maskColBeg(uint32_t idxTile) const { + int32_t const convergedSeqLen = int32_t(unmaskedSeqLen()); + return static_cast(exactDiv(tileSize, 32) * idxTile) - + static_cast(divUp(convergedSeqLen, 32)); + } + + __device__ inline TileMaskRow loadTileMaskRow(uint32_t idxTile, uint32_t idxQTokInCta) const { + assert(needMask(idxTile, idxQTokInCta)); + constexpr uint32_t nbOrigElems = TileMaskRow::size + 1; + Vec orig; + + int32_t const cols = divUp(params.qSeqLen, 32); + uint32_t const rowOffset = baseOffset + idxQTokInCta * cols; + int32_t const colBeg = maskColBeg(idxTile); +#pragma unroll + for (int32_t i = 0; i < int32_t(nbOrigElems); i++) { + int32_t const idx = colBeg + i; + orig[i] = inRange(idx, 0, cols) ? params.mask[rowOffset + idx] : (idx < 0 ? 
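// The shf.r.clamp.b32 a few lines below realigns the packed mask when
// unmaskedSeqLen() is not a multiple of 32: each output word is the 64-bit
// concatenation orig[i+1]:orig[i] shifted right by `shift`. CUDA exposes the
// same PTX instruction as an intrinsic, so an equivalent without inline asm:
//   mask[i] = __funnelshift_rc(orig[i], orig[i + 1], shift);  // (hi:lo) >> shift, clamped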
~0U : 0U); + } + TileMaskRow mask; + uint32_t const shift = (32 - unmaskedSeqLen() % 32) % 32; +#pragma unroll + for (uint32_t i = 0; i < TileMaskRow::size; i++) { + asm("shf.r.clamp.b32 %0, %1, %2, %3;\n" + : "=r"(mask[i]) + : "r"(orig[i]), "r"(orig[i + 1]), "r"(shift)); + } + return mask; + } + + SpecDecParams const& params; + uint32_t const idxInputSubSeq; + uint32_t const seqLen; + uint32_t inputSeqLen; + uint32_t baseOffset; }; __device__ void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - int32_t tok0WinBeg, + int32_t tok0WinBeg, #endif - uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank); + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank); #endif #if SWAP_AB -__device__ RegColWiseVec computeWarpGrpColMax_sync( - CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src); -__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd); +__device__ RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); +__device__ void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, + uint32_t validRowEnd); __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax); __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); -__device__ void storeGemm0AccToShm( - uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); __device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); #else -__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src); +__device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, + Gemm0Acc const& src); __device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); __device__ void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& colMax); __device__ RegRowWiseVec computeWarpRowSum(Gemm0Acc& src); -__device__ void storeGemm0AccToShm( - uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); +__device__ void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, + CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec); -__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec); +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec); #endif using RegMatAFrag = Array2D, 1, 2>; @@ -445,34 +433,41 @@ constexpr uint32_t gemm1NbGmmaInstK = exactDiv(gemm1CtaTileNbTokens, gmma::instK #if SWAP_AB constexpr uint32_t gemm1NbGmmaInstM = exactDiv(headElems, gmma::instM); -__device__ Vec loadVTileTransposed( - uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK); +__device__ Vec loadVTileTransposed(uint32_t warpRank, uint32_t lane, + SharedMem::VBuffer const& smemV, + uint32_t idxGmmaInstK); using Gemm1Acc = GmmaAcc; __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, - ShmQWiseVec const 
(&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, - CtaBarrier& gemm1WarpGrpBar); + ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], + ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar); template -__device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, - SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, +__device__ void finalizeAndWriteOut_sync( + uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum, + ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads = 0 /* only for final result in spec dec. */); #else -__device__ void transposeVTile( - uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src); +__device__ void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src); using Gemm1Acc = GmmaAcc; __device__ void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, - ShmQWiseVec const(&shmXRowSum), ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum); + ShmQWiseVec const(&shmXRowSum), + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum); template -__device__ void finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, - Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accColSum, - uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, uint32_t ctaNbValidTokens); +__device__ void finalizeAndWriteOut_sync( + uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, + float xvoScale, ShmQWiseVec const& accColSum, + uint32_t nbKHeads /* only for final result in spec dec. set to 1 for workspace*/, + uint32_t ctaNbValidTokens); #endif -inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds) -{ - auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds); - assert(val <= 32); - return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 16 : 32))); +inline constexpr uint32_t ropeNbPairsPerThrdImpl(uint32_t nbThrds) { + auto const val = divUp(exactDiv(validElemsPerHead, 2), nbThrds); + assert(val <= 32); + return val <= 2 ? val : (val <= 4 ? 4 : (val <= 8 ? 8 : (val <= 16 ? 
16 : 32))); } template @@ -482,1367 +477,1285 @@ template , ropeNbPairsPerThrd> loadHead( Vec const& head, uint32_t tid); template -__device__ mha::conditional_t, 2>, Vec, nbPairsPerThrd>> -applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin); +__device__ mha::conditional_t, 2>, + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin); template -__device__ void storeRotatedPairsForKV(GMemCacheHead& dst, +__device__ void storeRotatedPairsForKV( + GMemCacheHead& dst, mha::conditional_t>, 2>, - Vec, ropeNbPairsPerThrd>> const& src, + Vec, ropeNbPairsPerThrd>> const& src, uint32_t tid); template -__device__ void storeRotatedPairsForQ(SharedMem::QBuffer& dst, +__device__ void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, mha::conditional_t>, 2>, - Vec, ropeNbPairsPerThrd>> const& src, + Vec, ropeNbPairsPerThrd>> const& src, uint32_t row, uint32_t tid); -class ScratchMem -{ -public: - struct alignas(8) SumMax - { - float sum; - float max; - }; - - using ColWiseVec = Vec; - - HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) - : mScratch{static_cast(scratch)} - { - uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; - Segmenter segmenter; - constexpr uint32_t alignment = sizeof(Vec); - mRowSumMax = segmenter.template newSeg(nbChunks, alignment); - mTokens = segmenter.template newSeg>(nbChunks, alignment); - } - - HOST_DEVICE_FUNC TinyPtr rowSumMax() const - { - return makePtr(mRowSumMax); - } - - HOST_DEVICE_FUNC TinyPtr> tokens() const - { - return makePtr>(mTokens); - } - -private: - template - HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const - { - return TinyPtr{mScratch, offset}.template cast(); - } - -private: - mha::byte* mScratch; - // offsets - uint32_t mRowSumMax; - uint32_t mTokens; +class ScratchMem { + public: + struct alignas(8) SumMax { + float sum; + float max; + }; + + using ColWiseVec = Vec; + + HOST_DEVICE_FUNC ScratchMem(void* scratch, uint32_t maxTotalNbSubSeq, uint32_t nbInputSeqSplit) + : mScratch{static_cast(scratch)} { + uint32_t const nbChunks = maxTotalNbSubSeq * nbInputSeqSplit; + Segmenter segmenter; + constexpr uint32_t alignment = sizeof(Vec); + mRowSumMax = segmenter.template newSeg(nbChunks, alignment); + mTokens = segmenter.template newSeg>(nbChunks, alignment); + } + + HOST_DEVICE_FUNC TinyPtr rowSumMax() const { return makePtr(mRowSumMax); } + + HOST_DEVICE_FUNC TinyPtr> tokens() const { + return makePtr>(mTokens); + } + + private: + template + HOST_DEVICE_FUNC TinyPtr makePtr(uint32_t offset) const { + return TinyPtr{mScratch, offset}.template cast(); + } + + private: + mha::byte* mScratch; + // offsets + uint32_t mRowSumMax; + uint32_t mTokens; }; -struct MultiBlockSMem -{ - using ColWiseVec = ScratchMem::ColWiseVec; - static constexpr uint32_t nbBuf = useSpecDec ? 2 : 4; - static constexpr uint32_t nbIOWarps = nbBuf; - using Elem = InputElem; - using Head = Vec; - Vec, nbBuf> tokens; - Vec rowSumMax; - Vec barriers; +struct MultiBlockSMem { + using ColWiseVec = ScratchMem::ColWiseVec; + static constexpr uint32_t nbBuf = useSpecDec ? 
2 : 4; + static constexpr uint32_t nbIOWarps = nbBuf; + using Elem = InputElem; + using Head = Vec; + Vec, nbBuf> tokens; + Vec rowSumMax; + Vec barriers; }; #ifndef NDEBUG -namespace dbg -{ +namespace dbg { template -__device__ void printAcc( - CtaBarrier& warpGrpBar, uint32_t warpRank, Array2D const& acc) -{ - for (int m = 0; m < nbGmmaInstM; m++) - { - for (int w = 0; w < 4; w++) - { - if (warpRank == w) - { - for (int a = 0; a < 2; a++) - { - for (int b = 0; b < 8; b++) - { - for (int n = 0; n < nbGmmaInstNBase; n++) - { - for (uint32_t i = 0; i < 4; i++) - { - if (laneId() == b * 4 + i) - { - printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); - } - __syncwarp(); - } - } - if (laneId() == 0) - { - printf("\n"); - } - __syncwarp(); - } - if (laneId() == 0) - { - printf("\n"); - } - __syncwarp(); +__device__ void printAcc(CtaBarrier& warpGrpBar, uint32_t warpRank, + Array2D const& acc) { + for (int m = 0; m < nbGmmaInstM; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int n = 0; n < nbGmmaInstNBase; n++) { + for (uint32_t i = 0; i < 4; i++) { + if (laneId() == b * 4 + i) { + printf("%f, %f, ", acc(m, n)(a, 0), acc(m, n)(a, 1)); } + __syncwarp(); + } + } + if (laneId() == 0) { + printf("\n"); } - warpGrpBar.arrive_and_wait(); + __syncwarp(); + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); } + } + warpGrpBar.arrive_and_wait(); } + } } -__device__ void printShmColWiseVec(ShmQWiseVec const& vec) -{ - for (uint32_t i = 0; i < vec.size; i++) - { - printf("%f, ", vec[i]); - } - printf("\n"); +__device__ void printShmColWiseVec(ShmQWiseVec const& vec) { + for (uint32_t i = 0; i < vec.size; i++) { + printf("%f, ", vec[i]); + } + printf("\n"); } -template -__device__ void printArray2D(Array2D const& src) -{ - for (uint32_t i = 0; i < rows; i++) - { - for (uint32_t j = 0; j < cols; j++) - { - T const val = src.template at(i, j); - for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) - { - printf("%f, ", float(reinterpret_cast(&val)[k])); - } - } - printf("\n"); +template +__device__ void printArray2D(Array2D const& src) { + for (uint32_t i = 0; i < rows; i++) { + for (uint32_t j = 0; j < cols; j++) { + T const val = src.template at(i, j); + for (uint32_t k = 0; k < exactDiv(sizeof(T), sizeof(Elem)); k++) { + printf("%f, ", float(reinterpret_cast(&val)[k])); + } } + printf("\n"); + } } -} // namespace dbg +} // namespace dbg #endif -CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = XQAKernelType::kHOPPER_WARP_SPECIALIZED; +CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = + XQAKernelType::kHOPPER_WARP_SPECIALIZED; CUBIN_EXPORT __global__ #ifdef NDEBUG #if !OPTIMIZE_FOR_LATENCY - __launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 3 : 2) +__launch_bounds__(128 * 3, headElems* ctaNbQHeads <= 128 * 16 ? 
3 : 2) #else - __launch_bounds__(128 * 3) +__launch_bounds__(128 * 3) #endif #else __launch_bounds__(128 * 3, 1) #endif - void kernel_mha(uint32_t const nbKHeads, + void kernel_mha( + uint32_t const nbKHeads, #if SLIDING_WINDOW - uint32_t const slidingWinSize, + uint32_t const slidingWinSize, #endif - float const qScale, - OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] + float const qScale, + OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] #if LOW_PREC_OUTPUT - float const* const rcpOutScale, + float const* const rcpOutScale, #endif #if USE_INPUT_KV - IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], + IOHead const* __restrict__ const qkv, // [nbReq][beamWidth][nbQHeads+nbKHeads+nbVHeads], #if ROPE_STYLE != 0 - Vec const* __restrict__ const ropeCosSin, // [maxNbPosEmb] + Vec const* __restrict__ const ropeCosSin, // [maxNbPosEmb] #endif #else IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #endif - float const* attentionSinks, // [headGrpSize] - KVCacheList const cacheList, + float const* attentionSinks, // [headGrpSize] + KVCacheList const cacheList, #if USE_BEAM_SEARCH - BeamSearchParams const beamSearchParams, + BeamSearchParams const beamSearchParams, #endif - uint32_t const batchSize, - float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used - // only for int8/fp8 KV cache. + uint32_t const batchSize, + float const* __restrict__ const kvCacheScale, // Device memory scalar. Same scale for K and + // V cache. Used only for int8/fp8 KV cache. #if PAGED_KV_CACHE_LAYOUT == 1 - __grid_constant__ CUtensorMap const tensorMapVLLMK, __grid_constant__ CUtensorMap const tensorMapVLLMV, + __grid_constant__ CUtensorMap const tensorMapVLLMK, + __grid_constant__ CUtensorMap const tensorMapVLLMV, #else __grid_constant__ CUtensorMap const tensorMap, #endif #if SPEC_DEC - SpecDecParams const specDecParams, -#endif - uint32_t* __restrict__ const semaphores - = nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] - void* __restrict__ const scratch = nullptr) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) \ - && (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 - uint32_t const idxReq = blockIdx.z / nbKHeads; + SpecDecParams const specDecParams, +#endif + uint32_t* __restrict__ const semaphores = + nullptr, // [nbReq][nbKHeads][divUp(specDecParams.qSeqLen, inputTokensPerCta)] + void* __restrict__ const scratch = nullptr) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL) && \ + (IS_SUPPORTED_F16_CASE || CACHE_ELEM_ENUM == 2) && BEAM_WIDTH == 1 + uint32_t const idxReq = blockIdx.z / nbKHeads; #if SPEC_DEC - uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq); - uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1); - uint32_t const nbInputSeqSplit = gridDim.x; - assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta)); -#else - uint32_t const reqInputTokBeg = idxReq; - uint32_t const reqInputTokEnd = idxReq + 1; - constexpr uint32_t nbInputSeqSplit = 1; - assert(gridDim.x == nbInputSeqSplit); -#endif - uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request - assert(gridDim.z == nbKHeads * batchSize); - uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); - static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - constexpr uint32_t 
tileSize = gemm0CtaTileNbTokens; + uint32_t const reqInputTokBeg = getInputTokOffset(specDecParams, idxReq); + uint32_t const reqInputTokEnd = getInputTokOffset(specDecParams, idxReq + 1); + uint32_t const nbInputSeqSplit = gridDim.x; + assert(nbInputSeqSplit == divUp(specDecParams.qSeqLen, inputTokensPerCta)); +#else + uint32_t const reqInputTokBeg = idxReq; + uint32_t const reqInputTokEnd = idxReq + 1; + constexpr uint32_t nbInputSeqSplit = 1; + assert(gridDim.x == nbInputSeqSplit); +#endif + uint32_t const idxHeadGrp = blockIdx.z % nbKHeads; // inside one request + assert(gridDim.z == nbKHeads * batchSize); + uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; #if SPEC_DEC - uint32_t const idxInputSubSeq = blockIdx.x; - uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg; - uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; - uint32_t const ctaNbValidTokens = mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset); + uint32_t const idxInputSubSeq = blockIdx.x; + uint32_t const inputSeqLen = reqInputTokEnd - reqInputTokBeg; + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + uint32_t const ctaNbValidTokens = + mha::min(uint32_t{inputTokensPerCta}, inputSeqLen - ctaTokOffset); - if (ctaTokOffset >= inputSeqLen) - { - return; - } + if (ctaTokOffset >= inputSeqLen) { + return; + } #else - uint32_t const idxInputSubSeq = 0; - uint32_t const inputSeqLen = 1; - uint32_t const ctaTokOffset = 0; - uint32_t const ctaNbValidTokens = 1; + uint32_t const idxInputSubSeq = 0; + uint32_t const inputSeqLen = 1; + uint32_t const ctaTokOffset = 0; + uint32_t const ctaNbValidTokens = 1; #endif #if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE - // get the actual start position depending on ctaTokOffset, which is the draft token position per CTA - uint32_t const tok0SeqLen = cacheSeqLen - inputSeqLen + 1 + ctaTokOffset; - int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize); - uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg); + // get the actual start position depending on ctaTokOffset, which is the draft token position per + // CTA + uint32_t const tok0SeqLen = cacheSeqLen - inputSeqLen + 1 + ctaTokOffset; + int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize); + uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg); #elif SLIDING_WINDOW - bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); - // if SPEC_DEC && SLIDING_WINDOW && IS_SPEC_DEC_TREE, it should not do sliding - assert(!SPEC_DEC || !rtIsReallySliding); - uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0; + bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); + // if SPEC_DEC && SLIDING_WINDOW && IS_SPEC_DEC_TREE, it should not do sliding + assert(!SPEC_DEC || !rtIsReallySliding); + uint32_t const nbTotalSkipTokens = rtIsReallySliding ? 
cacheSeqLen - slidingWinSize : 0; #else - constexpr bool rtIsReallySliding = false; - constexpr uint32_t nbTotalSkipTokens = 0; + constexpr bool rtIsReallySliding = false; + constexpr uint32_t nbTotalSkipTokens = 0; #endif - uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize; - uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize; + uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / tileSize; + uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % tileSize; #if USE_BEAM_SEARCH - uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq); - uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0; - uint32_t const nbDivergentKTiles - = useKVCache ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) : 0; - uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles; - uint32_t const nbVTiles = nbKTiles; -#else - uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0; - // uint32_t const nbKTiles = nbTiles; - // uint32_t const nbVTiles = nbTiles; - uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles; -#endif - uint32_t const maxNbSubSeq = gridDim.y; - uint32_t const idxSubSeq = blockIdx.y; - bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles); - uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq; - uint32_t const idxVTileInit = idxKTileInit; - uint32_t const nbSubSeq = isMultiBlockMode ? mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; - static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); - assert(isMultiBlockMode == (nbSubSeq > 1)); - if (idxSubSeq >= nbSubSeq) - { - return; - } - uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; - auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); - auto const wid = warpIdx.z * 4 + warpIdx.x; + uint32_t const ctxCacheSeqLen = getCtxCacheSeqLen(beamSearchParams, idxReq); + uint32_t const nbCtxKTiles = useKVCache ? ctxCacheSeqLen / gemm0CtaTileNbTokens : 0; + uint32_t const nbDivergentKTiles = + useKVCache + ? divUp(cacheSeqLen - gemm0CtaTileNbTokens * nbCtxKTiles, beamSearchGemm0CtaTileNbTokens) + : 0; + uint32_t const nbKTiles = nbCtxKTiles + nbDivergentKTiles; + uint32_t const nbVTiles = nbKTiles; +#else + uint32_t const nbTiles = useKVCache ? divUp(cacheSeqLen, tileSize) : 0; + // uint32_t const nbKTiles = nbTiles; + // uint32_t const nbVTiles = nbTiles; + uint32_t const nbTilesInUse = nbTiles - nbSkipLeadingTiles; +#endif + uint32_t const maxNbSubSeq = gridDim.y; + uint32_t const idxSubSeq = blockIdx.y; + bool const isMultiBlockMode = (maxNbSubSeq > 1 && nbTilesInUse >= multiBlockMinNbTiles); + uint32_t const idxKTileInit = nbSkipLeadingTiles + idxSubSeq; + uint32_t const idxVTileInit = idxKTileInit; + uint32_t const nbSubSeq = + isMultiBlockMode ? 
mha::min(nbTilesInUse / multiBlockMinNbTilesPerCta, maxNbSubSeq) : 1; + static_assert(multiBlockMinNbTiles >= multiBlockMinNbTilesPerCta * 2); + assert(isMultiBlockMode == (nbSubSeq > 1)); + if (idxSubSeq >= nbSubSeq) { + return; + } + uint32_t const ctaInputTokBeg = reqInputTokBeg + ctaTokOffset; + auto const warpIdx = getWarpIdx(uint3{128, 1, 3}); + auto const wid = warpIdx.z * 4 + warpIdx.x; #if PAGED_KV_CACHE_LAYOUT == 1 - if (wid == 0 && warpElectSync()) - { - tma::prefetchTensorMap(tensorMapVLLMK); - tma::prefetchTensorMap(tensorMapVLLMV); - } -#else - if (wid == 0 && warpElectSync()) - { - tma::prefetchTensorMap(tensorMap); - } -#endif - extern __shared__ char smemByteBuf[]; - assert(dynamicSmemSize() >= sizeof(SharedMem)); - SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); - - constexpr uint32_t nbBuffers = 2; - static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && nbBuffers == SharedMem::nbXBuf); - if (wid < nbBuffers) - { - if (warpElectSync()) - { - smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); - smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMapVLLMK); + tma::prefetchTensorMap(tensorMapVLLMV); + } +#else + if (wid == 0 && warpElectSync()) { + tma::prefetchTensorMap(tensorMap); + } +#endif + extern __shared__ char smemByteBuf[]; + assert(dynamicSmemSize() >= sizeof(SharedMem)); + SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); + + constexpr uint32_t nbBuffers = 2; + static_assert(nbBuffers == SharedMem::nbKBuf && nbBuffers == SharedMem::nbVBuf && + nbBuffers == SharedMem::nbXBuf); + if (wid < nbBuffers) { + if (warpElectSync()) { + smem.kBar[wid].initialize(gemm0NbThrds, gemm0NbThrds + warp_size); + smem.vBar[wid].initialize(gemm1NbThrds, gemm1NbThrds + warp_size); #if !SWAP_AB - smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); + smem.vtBar[wid].initialize(gemm1NbThrds * 2, gemm1NbThrds * 2); #endif - smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); - } + smem.xBar[wid].initialize(gemm0NbThrds + gemm1NbThrds, gemm0NbThrds + gemm1NbThrds); } - else if (wid == nbBuffers) - { - if (warpElectSync()) - { - smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds); - init(&smem.gemm0WarpGrpBar, gemm0NbThrds); - init(&smem.gemm1WarpGrpBar, gemm1NbThrds); - } + } else if (wid == nbBuffers) { + if (warpElectSync()) { + smem.qBar.initialize(gemm0NbThrds + nbQLdThrds, gemm0NbThrds + nbQLdThrds); + init(&smem.gemm0WarpGrpBar, gemm0NbThrds); + init(&smem.gemm1WarpGrpBar, gemm1NbThrds); } - __syncthreads(); + } + __syncthreads(); #if USE_PAGED_KV_CACHE - uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); + uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); #endif - constexpr bool isKVCacheQuantized = (cacheElemSize < 2); - assert(idxKTileInit < nbTiles); - uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq); - assert(nbIters >= 1); + constexpr bool isKVCacheQuantized = (cacheElemSize < 2); + assert(idxKTileInit < nbTiles); + uint32_t const nbIters = divUp(nbTiles - idxKTileInit, nbSubSeq); + assert(nbIters >= 1); - constexpr uint32_t gmmaInstK = gmma::instK; - constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes); + constexpr uint32_t gmmaInstK = gmma::instK; + constexpr uint32_t grainsPerInstK = exactDiv(sizeof(MathElem) * gmmaInstK, grainBytes); - if (warpIdx.z == 0) - { + if (warpIdx.z == 0) { #if SPEC_DEC - 
SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; + SpecDec const specDec{specDecParams, idxReq, idxInputSubSeq, cacheSeqLen}; #endif - // QK gemm - constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); - using Acc = GmmaAcc; + // QK gemm + constexpr uint32_t nbGmmaInstM = exactDiv(gemm0CtaTileNbTokens, gmma::instM); + using Acc = GmmaAcc; - unused(smem.qBar.consumed.arrive()); - for (auto& b : smem.kBar) - { - unused(b.consumed.arrive()); - } + unused(smem.qBar.consumed.arrive()); + for (auto& b : smem.kBar) { + unused(b.consumed.arrive()); + } - float const qkScale = qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) - * rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. - uint32_t const warpRank = warpIdx.x; + float const qkScale = + qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * + rsqrtf(validElemsPerHead); // qkScale is applied onto Q*K.T before softmax. + uint32_t const warpRank = warpIdx.x; - // init once per sequence. It also works as global colMax across iterations. - if (threadIdx.x < ctaNbQHeads) - { - smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; - } - smem.gemm0WarpGrpBar.arrive_and_wait(); + // init once per sequence. It also works as global colMax across iterations. + if (threadIdx.x < ctaNbQHeads) { + smem.gemm0CurrentSeqMax[threadIdx.x] = safeInitRowMax; + } + smem.gemm0WarpGrpBar.arrive_and_wait(); - smem.qBar.produced.arrive_and_wait(); + smem.qBar.produced.arrive_and_wait(); #if DBG_PRINT - if (threadIdx.x == 0) - { - printf("q:\n"); - dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); - } -#endif - - auto const matDescQBase = gmma::makeMatDesc( - nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) - .raw(); - for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) - { - uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; - assert(idxKTile < nbTiles); - Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. 
- gmma::fence(); - static_assert(cacheHeadNbParts == nbQParts); -#pragma unroll - for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) - { - auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; - auto& kBuf = smem.k[idxKBuf]; - auto& kBar = smem.kBar[idxKBuf]; - static_assert(SharedMem::KBuffer::rows % 8 == 0); - auto const matDescKBase = gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], - gmma::getSwizzleMode(SharedMem::KBuffer{})) - .raw(); - assert(matDescKBase - == gmma::makeMatDesc( - nullptr, 0, SharedMem::KBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::KBuffer{})) - .raw()); - arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); - // if (threadIdx.x == 0) { - // printf("************* part %u *******\n", idxPart); - // printf("q:\n"); - // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); - // printf("k:\n"); - // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); - // } - constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); -#pragma unroll - for (uint32_t k = 0; k < nbGmmaInstK; k++) - { - bool const accHasVal = (idxPart != 0 || k != 0); - auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); -#pragma unroll - for (uint32_t m = 0; m < nbGmmaInstM; m++) - { - auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); + if (threadIdx.x == 0) { + printf("q:\n"); + dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[0]); + } +#endif + + auto const matDescQBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::QBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::QBuffer::Elem{})) + .raw(); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq; + assert(idxKTile < nbTiles); + Acc acc; // no need to initialize. GMMA allows us to ignore acc initial values. 
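// A minimal sketch of the double-buffered produced/consumed handshake that the
// kBar/vBar/xBar pairs here implement (illustrative only: it uses a plain
// libcu++ cuda::barrier and a dummy payload instead of the patch's
// CtaBarrierPair wrapper and TMA copies). The consumer pre-arms `consumed`
// once per buffer before its loop, just like the unused(b.consumed.arrive())
// calls above, so the producer's first wait passes immediately.
#include <cuda/barrier>

namespace sketch {
using Barrier = cuda::barrier<cuda::thread_scope_block>;

struct BarrierPair {
  Barrier produced;  // producer -> consumer: "buffer is filled"
  Barrier consumed;  // consumer -> producer: "buffer is free again"
};

template <uint32_t nbBuf, uint32_t bufElems>
__device__ void producerLoop(BarrierPair (&bars)[nbBuf], float (&buf)[nbBuf][bufElems],
                             uint32_t nbIters) {
  for (uint32_t i = 0; i < nbIters; i++) {
    auto& b = bars[i % nbBuf];
    b.consumed.arrive_and_wait();                       // wait until the buffer is free
    buf[i % nbBuf][threadIdx.x % bufElems] = float(i);  // stand-in for the real TMA load
    (void)b.produced.arrive();                          // hand the buffer to the consumer
  }
}
}  // namespace sketch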
+ gmma::fence(); + static_assert(cacheHeadNbParts == nbQParts); +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBuf = smem.k[idxKBuf]; + auto& kBar = smem.kBar[idxKBuf]; + static_assert(SharedMem::KBuffer::rows % 8 == 0); + auto const matDescKBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, &smem.k[0], + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw(); + assert(matDescKBase == gmma::makeMatDesc(nullptr, 0, SharedMem::KBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::KBuffer{})) + .raw()); + arrive_tx_and_wait(kBar.produced, exactDiv(sizeof(SharedMem::KBuffer), gemm0NbThrds)); + // if (threadIdx.x == 0) { + // printf("************* part %u *******\n", idxPart); + // printf("q:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(smem.q[idxPart]); + // printf("k:\n"); + // dbg::printArray2D<__nv_fp8_e4m3, true>(kBuf); + // } + constexpr uint32_t nbGmmaInstK = exactDiv(cacheHeadPartElems, gmmaInstK); +#pragma unroll + for (uint32_t k = 0; k < nbGmmaInstK; k++) { + bool const accHasVal = (idxPart != 0 || k != 0); + auto const matDescQ = addAddr(matDescQBase, &smem.q[idxPart](0, grainsPerInstK * k)); +#pragma unroll + for (uint32_t m = 0; m < nbGmmaInstM; m++) { + auto const matDescK = addAddr(matDescKBase, &kBuf(64 * m, grainsPerInstK * k)); #if SWAP_AB - gmma::mma_async_shmA( - reinterpret_cast(acc(m, 0)), - matDescK, matDescQ, accHasVal); -#else - gmma::mma_async_shmA( - reinterpret_cast(acc(m, 0)), - matDescQ, matDescK, accHasVal); -#endif - } - } - gmma::commit_group(); - //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let tile0_softmax - // wait for - // k loading of tile1 and may harm perf for short-seq cases. - gmma::wait_group<0>(); - unused(kBar.consumed.arrive()); - } + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescK, matDescQ, accHasVal); +#else + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + matDescQ, matDescK, accHasVal); +#endif + } + } + gmma::commit_group(); + //@fixme: use two sets of acc and let gmma_async overlap with softmax. But this will let + // tile0_softmax + // wait for + // k loading of tile1 and may harm perf for short-seq cases. + gmma::wait_group<0>(); + unused(kBar.consumed.arrive()); + } #if !defined(NDEBUG) && DBG_PRINT - dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); + dbg::printAcc(smem.gemm0WarpGrpBar, warpRank, acc); #endif - // apply qkScale - acc = acc * qkScale; - // apply mask + // apply qkScale + acc = acc * qkScale; + // apply mask #if SPEC_DEC - warpGrpApplyMask(acc, specDec, + warpGrpApplyMask(acc, specDec, #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - tok0WinBeg, -#endif - cacheSeqLen, idxKTile, warpRank); -#else - bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); - bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); - bool const isLastTile = (idxKTile + 1 == nbTiles); - bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; - if (needMaskLeading || needMaskTrailing) - { - uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; - uint32_t const validTokenEnd = (needMaskTrailing ? 
cacheSeqLen % tileSize : tileSize); - if (validTokenBeg > 0 || validTokenEnd < tileSize) - { + tok0WinBeg, +#endif + cacheSeqLen, idxKTile, warpRank); +#else + bool const isFirstTile = (idxKTile == nbSkipLeadingTiles); + bool const needMaskLeading = (rtIsReallySliding && isFirstTile && tile0NbSkipTokens > 0); + bool const isLastTile = (idxKTile + 1 == nbTiles); + bool const needMaskTrailing = isLastTile && cacheSeqLen % tileSize != 0; + if (needMaskLeading || needMaskTrailing) { + uint32_t const validTokenBeg = needMaskLeading ? tile0NbSkipTokens : 0; + uint32_t const validTokenEnd = (needMaskTrailing ? cacheSeqLen % tileSize : tileSize); + if (validTokenBeg > 0 || validTokenEnd < tileSize) { #if SWAP_AB - warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); + warpGrpApplyMask(warpRank, acc, validTokenBeg, validTokenEnd); #else - warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); + warpGrpApplyMask(acc, validTokenBeg, validTokenEnd); #endif - } - } + } + } #endif - // update colMax in shared mem and get a register copy + // update colMax in shared mem and get a register copy #if SWAP_AB - RegColWiseVec const colMax = computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); - warpGrpOnlineSoftmax(acc, colMax); + RegColWiseVec const colMax = + computeWarpGrpColMax_sync(smem.gemm0WarpGrpBar, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, colMax); #else - RegRowWiseVec const rowMax = computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); - warpGrpOnlineSoftmax(acc, rowMax); + RegRowWiseVec const rowMax = + computeWarpGrpRowMax_sync(warpRank, smem.gemm0CurrentSeqMax, acc); + warpGrpOnlineSoftmax(acc, rowMax); #endif - // @fixme: may need fp32->fp8->fp32 before doing sum. + // @fixme: may need fp32->fp8->fp32 before doing sum. #if SWAP_AB - RegColWiseVec const warpColSum = computeWarpColSum(acc); + RegColWiseVec const warpColSum = computeWarpColSum(acc); #else - RegRowWiseVec const rowSum = computeWarpRowSum(acc); + RegRowWiseVec const rowSum = computeWarpRowSum(acc); #endif - // map 1 to fp8_max before conversion to fp8 - acc = acc * kE4M3_MAX; + // map 1 to fp8_max before conversion to fp8 + acc = acc * kE4M3_MAX; - uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; - auto& xBar = smem.xBar[idxXBuf]; - // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. + uint32_t const idxXBuf = idxIter % SharedMem::nbXBuf; + auto& xBar = smem.xBar[idxXBuf]; + // @fixme: for fp16/bf16, try not to transpose acc here, and leave it to the next GEMM. 
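// Scalar sketch of the streaming-softmax bookkeeping carried across tiles
// (names are illustrative, not from the patch). gemm0 publishes per-tile
// max/sum together with exp(s - max); the gemm1 side folds each tile into its
// running accumulator by rescaling history with exp(oldMax - newMax), which is
// the vector-wide job of rescaleGemm1AccForNew{Col,Row}Max_sync.
struct RunningSoftmaxSketch {
  float max = -INFINITY;  // running max (the kernel uses safeInitRowMax)
  float sum = 0.f;        // running sum of exp(s - max)
  float acc = 0.f;        // running weighted-V accumulator (1 lane for illustration)

  __device__ void foldTile(float tileMax, float tileSum, float tileAcc) {
    float const newMax = fmaxf(max, tileMax);
    float const oldScale = expf(max - newMax);  // rescales everything accumulated so far
    float const tileScale = expf(tileMax - newMax);
    sum = sum * oldScale + tileSum * tileScale;
    acc = acc * oldScale + tileAcc * tileScale;
    max = newMax;
  }

  __device__ float finalize() const { return acc / sum; }
};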
#if SWAP_AB - storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); - // store colMax and warpColSum - auto const lane = laneId(); - if (lane < 4) - { - auto& xColMax = smem.xColMax[idxXBuf]; - auto& xColSum = smem.xColSum[idxXBuf][warpRank]; -#pragma unroll - for (uint32_t n = 0; n < colMax.size; n++) - { -#pragma unroll - for (uint32_t j = 0; j < 2; j++) - { - if (warpRank == 0) - { - xColMax[8 * n + 2 * lane + j] = colMax[n][j]; - } - xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; - } - } + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + // store colMax and warpColSum + auto const lane = laneId(); + if (lane < 4) { + auto& xColMax = smem.xColMax[idxXBuf]; + auto& xColSum = smem.xColSum[idxXBuf][warpRank]; +#pragma unroll + for (uint32_t n = 0; n < colMax.size; n++) { +#pragma unroll + for (uint32_t j = 0; j < 2; j++) { + if (warpRank == 0) { + xColMax[8 * n + 2 * lane + j] = colMax[n][j]; } + xColSum[8 * n + 2 * lane + j] = warpColSum[n][j]; + } + } + } #else - storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); - storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); - storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); + storeGemm0AccToShm(warpRank, laneId(), smem.xBuf(idxXBuf), xBar.consumed, acc); + storeShmRowWiseVec(warpRank, smem.xRowMax[idxXBuf], rowMax); + storeShmRowWiseVec(warpRank, smem.xRowSum[idxXBuf], rowSum); #endif - __syncwarp(); - // the release semantics of arrive does not work for async consumers like gmma. additional fence is - // needed. - asm volatile("fence.proxy.async.shared::cta;\n"); - unused(xBar.produced.arrive()); - } - unused(smem.qBar.consumed.arrive()); - } - else if (warpIdx.z == 1) - { - // XV GEMM - for (auto& b : smem.vBar) - { - unused(b.consumed.arrive()); - } + __syncwarp(); + // the release semantics of arrive does not work for async consumers like gmma. additional + // fence is needed. + asm volatile("fence.proxy.async.shared::cta;\n"); + unused(xBar.produced.arrive()); + } + unused(smem.qBar.consumed.arrive()); + } else if (warpIdx.z == 1) { + // XV GEMM + for (auto& b : smem.vBar) { + unused(b.consumed.arrive()); + } #if !SWAP_AB - for (auto& b : smem.vtBar) - { - unused(b.consumed.arrive()); - } + for (auto& b : smem.vtBar) { + unused(b.consumed.arrive()); + } #endif - for (auto& b : smem.xBar) - { - unused(b.consumed.arrive()); - } + for (auto& b : smem.xBar) { + unused(b.consumed.arrive()); + } - if (threadIdx.x < smem.gemm1AccColMax.size) - { - auto const idx = threadIdx.x; - smem.gemm1AccColMax[idx] = safeInitRowMax; - smem.gemm1AccColSum[idx] = 0; - } - smem.gemm1WarpGrpBar.arrive_and_wait(); + if (threadIdx.x < smem.gemm1AccColMax.size) { + auto const idx = threadIdx.x; + smem.gemm1AccColMax[idx] = safeInitRowMax; + smem.gemm1AccColSum[idx] = 0; + } + smem.gemm1WarpGrpBar.arrive_and_wait(); - uint32_t const warpRank = warpIdx.x; + uint32_t const warpRank = warpIdx.x; - constexpr float xScale = 1.f / kE4M3_MAX; + constexpr float xScale = 1.f / kE4M3_MAX; #if LOW_PREC_OUTPUT - float const oScale = rcpOutScale[0]; + float const oScale = rcpOutScale[0]; #else - constexpr float oScale = 1.F; + constexpr float oScale = 1.F; #endif - float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; + float const xvoScale = xScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f) * oScale; - Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. 
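// Worked example of the fp8 dynamic-range trick used by this warp group
// (illustrative, not part of the patch): the softmax weights in X lie in
// [0, 1], so gemm0 scales them by kE4M3_MAX = 448 (the largest finite e4m3
// value) before the fp8 cast — "map 1 to fp8_max" above — and gemm1 folds the
// reciprocal into xvoScale, recovering the original magnitude after the X*V GEMM.
#include <cuda_fp8.h>

__device__ float fp8RoundTripSketch(float p /* softmax weight in [0, 1] */) {
  constexpr float kMax = 448.f;     // largest finite __nv_fp8_e4m3 value
  __nv_fp8_e4m3 const q{p * kMax};  // quantize the scaled weight
  return float(q) * (1.f / kMax);   // descale after the GEMM (xScale)
}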
- gmma::fence(); + Gemm1Acc acc{}; // init to zeros to avoid runtime checking for first gmma instruction. + gmma::fence(); - static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); - for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) - { - uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; - auto const idxVBuf = idxIter % SharedMem::nbVBuf; - auto const idxXBuf = idxVBuf; - auto& vBar = smem.vBar[idxVBuf]; - arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); - auto const& vBuf = smem.vBuf(idxVBuf); + static_assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens, "not implemented"); + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t idxVTile = idxVTileInit + idxIter * nbSubSeq; + auto const idxVBuf = idxIter % SharedMem::nbVBuf; + auto const idxXBuf = idxVBuf; + auto& vBar = smem.vBar[idxVBuf]; + arrive_tx_and_wait(vBar.produced, exactDiv(sizeof(SharedMem::VBuffer), gemm1NbThrds)); + auto const& vBuf = smem.vBuf(idxVBuf); #if !SWAP_AB - CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; - auto& vtBuf = smem.vtBuf(idxVBuf); - vtBar.consumed.arrive_and_wait(); - transposeVTile(warpRank, laneId(), vtBuf, vBuf); - vBar.consumed.arrive(); - vtBar.produced.arrive(); -#endif - auto& xBar = smem.xBar[idxXBuf]; - xBar.produced.arrive_and_wait(); + CtaBarrierPair& vtBar = smem.vtBar[idxVBuf]; + auto& vtBuf = smem.vtBuf(idxVBuf); + vtBar.consumed.arrive_and_wait(); + transposeVTile(warpRank, laneId(), vtBuf, vBuf); + vBar.consumed.arrive(); + vtBar.produced.arrive(); +#endif + auto& xBar = smem.xBar[idxXBuf]; + xBar.produced.arrive_and_wait(); #if !defined(NDEBUG) && DBG_PRINT #if SWAP_AB - if (threadIdx.x == 0) - { - printf("colMax:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xColMax[idxXBuf][i]); - } - printf("\n"); - printf("colSum:\n"); - for (int n = 0; n < 4; n++) - { - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xColSum[idxXBuf][n][i]); - } - printf("\n"); - } - printf("\n"); - printf("X:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - for (int j = 0; j < gemm0CtaTileNbTokens; j++) - { - auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); - auto const e = reinterpret_cast&>( - smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( - i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; - printf("%.2f, ", float(e)); - if (j % 16 == 15) - { - printf("| "); - } - } - printf("\n\n"); - } + if (threadIdx.x == 0) { + printf("colMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColMax[idxXBuf][i]); + } + printf("\n"); + printf("colSum:\n"); + for (int n = 0; n < 4; n++) { + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xColSum[idxXBuf][n][i]); + } + printf("\n"); + } + printf("\n"); + printf("X:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + for (int j = 0; j < gemm0CtaTileNbTokens; j++) { + auto const& elemsPerXPart = (cacheElemsPerGrain * grainsPerXPart); + auto const e = reinterpret_cast&>( + smem.xBuf(idxXBuf)[j / elemsPerXPart].template at( + i, j % elemsPerXPart / cacheElemsPerGrain))[j % cacheElemsPerGrain]; + printf("%.2f, ", float(e)); + if (j % 16 == 15) { + printf("| "); } - smem.gemm1WarpGrpBar.arrive_and_wait(); + } + printf("\n\n"); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); #else - if (blockIdx.y == 1 && threadIdx.x == 0) - { - printf("rowMax:\n"); - for (int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xRowMax[idxXBuf][i]); - } - printf("\n"); - printf("rowSum:\n"); - for 
(int i = 0; i < ctaNbQHeads; i++) - { - printf("%f, ", smem.xRowSum[idxXBuf][i]); - } - printf("\n"); - } - smem.gemm1WarpGrpBar.arrive_and_wait(); + if (blockIdx.y == 1 && threadIdx.x == 0) { + printf("rowMax:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowMax[idxXBuf][i]); + } + printf("\n"); + printf("rowSum:\n"); + for (int i = 0; i < ctaNbQHeads; i++) { + printf("%f, ", smem.xRowSum[idxXBuf][i]); + } + printf("\n"); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); #endif #endif #if SWAP_AB - // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc instead. - rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], - smem.gemm1AccColMax, acc, smem.gemm1AccColSum, smem.gemm1WarpGrpBar); + // @fixme: if first tile, no need to rescale acc. For persistent CTA, just re-initialize acc + // instead. + rescaleGemm1AccForNewColMax_sync(warpRank, smem.xColMax[idxXBuf], smem.xColSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum, + smem.gemm1WarpGrpBar); #else - rescaleGemm1AccForNewRowMax_sync( - warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], smem.gemm1AccColMax, acc, smem.gemm1AccColSum); + rescaleGemm1AccForNewRowMax_sync(warpRank, smem.xRowMax[idxXBuf], smem.xRowSum[idxXBuf], + smem.gemm1AccColMax, acc, smem.gemm1AccColSum); #endif - auto& xBuf = smem.xBuf(idxXBuf); + auto& xBuf = smem.xBuf(idxXBuf); - auto const descXBase = gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, - gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) - .raw(); + auto const descXBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::XBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::XBuffer::Elem{})) + .raw(); #if CACHE_ELEM_ENUM == 0 - auto const descVBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, - gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) - .raw(); + auto const descVBase = + gmma::makeMatDesc(nullptr, 0, SharedMem::VBuffer::Elem::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VBuffer::Elem{})) + .raw(); #endif #if SWAP_AB -//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in loadVTileTransposed. +//@fixme: to reduce code size, we can disable unroll and use double-buffer for LDSM in +// loadVTileTransposed. 
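// Plain-CUDA illustration of why V needs a transposed view at this point (a
// naive sketch, not the patch's implementation): the X*V GEMM consumes V with
// the token dimension as GEMM-K, so a token-major V tile in shared memory is
// either transposed in shmem (transposeVTile, !SWAP_AB) or loaded transposed
// into registers via ldmatrix (loadVTileTransposed, SWAP_AB).
template <uint32_t rows, uint32_t cols>
__device__ void transposeTileNaiveSketch(float (&dst)[cols][rows],
                                         float const (&src)[rows][cols]) {
  for (uint32_t i = threadIdx.x; i < rows * cols; i += blockDim.x) {
    uint32_t const r = i / cols;
    uint32_t const c = i % cols;
    dst[c][r] = src[r][c];
  }
}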
#pragma unroll - for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) - { + for (uint32_t idxInstK = 0; idxInstK < gemm1NbGmmaInstK; idxInstK++) { #if CACHE_ELEM_ENUM == 2 - Vec const fragA - = loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); + Vec const fragA = + loadVTileTransposed(warpRank, laneId(), vBuf, idxInstK); #if !defined(NDEBUG) && DBG_PRINT - if (threadIdx.x == 0) - { - printf("fragA:\nidxInstK == %u\n", idxInstK); - } - smem.gemm1WarpGrpBar.arrive_and_wait(); - for (int m = 0; m < 2; m++) - { - for (int w = 0; w < 4; w++) - { - if (warpRank == w) - { - if (laneId() == 0) - { - printf(" warpRank = %u\n", warpRank); - } - __syncwarp(); - for (int a = 0; a < 2; a++) - { - for (int b = 0; b < 8; b++) - { - for (int c = 0; c < 2; c++) - { - for (int d = 0; d < 4; d++) - { - if (laneId() == b * 4 + d) - { - for (int e = 0; e < 4; e++) - { - auto const& elem4 = reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>( - fragA[m](0, c)(a, 0)); - printf("%.2f, ", float(elem4[e])); - } - } - __syncwarp(); - } - } - if (laneId() == 0) - { - printf("\n"); - } - __syncwarp(); - } - if (laneId() == 0 && a == 0) - { - printf("----------------------\n"); - } - __syncwarp(); - } + if (threadIdx.x == 0) { + printf("fragA:\nidxInstK == %u\n", idxInstK); + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + for (int m = 0; m < 2; m++) { + for (int w = 0; w < 4; w++) { + if (warpRank == w) { + if (laneId() == 0) { + printf(" warpRank = %u\n", warpRank); + } + __syncwarp(); + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 8; b++) { + for (int c = 0; c < 2; c++) { + for (int d = 0; d < 4; d++) { + if (laneId() == b * 4 + d) { + for (int e = 0; e < 4; e++) { + auto const& elem4 = + reinterpret_cast<__nv_fp8_e4m3 const(&)[4]>(fragA[m](0, c)(a, 0)); + printf("%.2f, ", float(elem4[e])); } - smem.gemm1WarpGrpBar.arrive_and_wait(); + } + __syncwarp(); } + } + if (laneId() == 0) { + printf("\n"); + } + __syncwarp(); + } + if (laneId() == 0 && a == 0) { + printf("----------------------\n"); } + __syncwarp(); + } + } + smem.gemm1WarpGrpBar.arrive_and_wait(); + } + } #endif #endif - BoundedVal const kOffsetInGrains{grainsPerInstK * idxInstK}; - auto const descX = addAddr(descXBase, + BoundedVal const kOffsetInGrains{grainsPerInstK * + idxInstK}; + auto const descX = + addAddr(descXBase, &xBuf[kOffsetInGrains.template divBy().get()]( 0, kOffsetInGrains.template mod().get())); #if CACHE_ELEM_ENUM == 2 - gmma::fence(); + gmma::fence(); #endif #pragma unroll - for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) - { + for (uint32_t idxInstM = 0; idxInstM < gemm1NbGmmaInstM; idxInstM++) { #if CACHE_ELEM_ENUM == 0 - auto const descV - = addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); - gmma::mma_async_shmA( - reinterpret_cast(acc(idxInstM, 0)), - descV, descX, true); + auto const descV = + addAddr(descVBase, &vBuf[idxInstM](kOffsetInGrains.get() * cacheElemsPerGrain, 0)); + gmma::mma_async_shmA( + reinterpret_cast( + acc(idxInstM, 0)), + descV, descX, true); #elif CACHE_ELEM_ENUM == 2 - gmma::mma_async_regA( - reinterpret_cast(acc(idxInstM, 0)), - reinterpret_cast(fragA[idxInstM]), descX, true); -#endif - } - gmma::commit_group(); - //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of - // gmma. 
- gmma::wait_group<0>(); - } -#else - auto const descVTBase = gmma::makeMatDesc( - nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, gmma::getSwizzleMode(SharedMem::VTBuffer{})) - .raw(); - vtBar.produced.arrive_and_wait(); + gmma::mma_async_regA( + reinterpret_cast( + acc(idxInstM, 0)), + reinterpret_cast(fragA[idxInstM]), descX, true); +#endif + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of + // gmma. + gmma::wait_group<0>(); + } +#else + auto const descVTBase = gmma::makeMatDesc(nullptr, 0, SharedMem::VTBuffer::rowBytes * 8, + gmma::getSwizzleMode(SharedMem::VTBuffer{})) + .raw(); + vtBar.produced.arrive_and_wait(); // if (idxIter == 1 && threadIdx.x == 0) { // printf("vtBuf:\n"); // dbg::printArray2D<__nv_fp8_e4m3, true>(vtBuf); // } #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { -#pragma unroll - for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) - { - BoundedVal const kOffsetInGrains{grainsPerInstK * k}; - auto const descX = addAddr(descXBase, - &xBuf[kOffsetInGrains.template divBy().get()]( - gmma::instM * m, kOffsetInGrains.template mod().get())); - auto const descVT = addAddr( - descVTBase, &vtBuf(0, kOffsetInGrains.template mod().get())); - gmma::mma_async_shmA( - reinterpret_cast(acc(m, 0)), descX, - descVT, true); - } - } - gmma::commit_group(); - //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until finish of gmma. - gmma::wait_group<0>(); -#endif - if (idxIter == nbIters - 1) - { - // gmma::wait_group should have already synchronized threads, so this may be unnecessary. - smem.gemm1WarpGrpBar.arrive_and_wait(); - assert(idxXBuf == idxVBuf); - if (isMultiBlockMode) - { - ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; - uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; - uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq; - uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; - // save row max/sum - static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size); - if (threadIdx.x < ctaNbValidQHeads) - { - float const colMax = smem.gemm1AccColMax[threadIdx.x]; - float const colSum = smem.gemm1AccColSum[threadIdx.x]; - ScratchMem::SumMax sumMax; - sumMax.sum = colSum; - sumMax.max = colMax; - (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = sumMax; - } - // compute scratch ptr for output writing - IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { +#pragma unroll + for (uint32_t k = 0; k < gemm1NbGmmaInstK; k++) { + BoundedVal const kOffsetInGrains{grainsPerInstK * k}; + auto const descX = + addAddr(descXBase, + &xBuf[kOffsetInGrains.template divBy().get()]( + gmma::instM * m, + kOffsetInGrains.template mod().get())); + auto const descVT = + addAddr(descVTBase, + &vtBuf(0, kOffsetInGrains.template mod().get())); + gmma::mma_async_shmA( + reinterpret_cast(acc(m, 0)), + descX, descVT, true); + } + } + gmma::commit_group(); + //@fixme: delay wait and consumption to next tile. Note that fragA must also persist until + // finish of gmma. + gmma::wait_group<0>(); +#endif + if (idxIter == nbIters - 1) { + // gmma::wait_group should have already synchronized threads, so this may be unnecessary. 
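// Sketch of how the per-sub-sequence partials written just below
// (ScratchMem::SumMax plus per-head partial outputs) can be combined by the
// follow-up multi-block reduction. This is the standard log-sum-exp merge,
// assuming each partial output was already normalized by its local sum; the
// names are illustrative, not the patch's reduction kernel.
struct SumMaxSketch {
  float sum;
  float max;
};

__device__ void mergePartialSketch(SumMaxSketch a, float outA,  // partial 1
                                   SumMaxSketch b, float outB,  // partial 2
                                   SumMaxSketch& merged, float& out) {
  float const m = fmaxf(a.max, b.max);
  float const wA = a.sum * expf(a.max - m);  // effective weight of partial 1
  float const wB = b.sum * expf(b.max - m);  // effective weight of partial 2
  merged.max = m;
  merged.sum = wA + wB;
  out = (outA * wA + outB * wB) / merged.sum;  // re-normalized final output
}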
+ smem.gemm1WarpGrpBar.arrive_and_wait(); + assert(idxXBuf == idxVBuf); + if (isMultiBlockMode) { + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxSubSeq; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + // save row max/sum + static_assert(ctaNbValidQHeads <= gmmaWarpsPerGrp * warp_size); + if (threadIdx.x < ctaNbValidQHeads) { + float const colMax = smem.gemm1AccColMax[threadIdx.x]; + float const colSum = smem.gemm1AccColSum[threadIdx.x]; + ScratchMem::SumMax sumMax; + sumMax.sum = colSum; + sumMax.max = colMax; + (scratchMem.rowSumMax() + idxChunk).template cast()[threadIdx.x] = + sumMax; + } + // compute scratch ptr for output writing + IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); #if SWAP_AB - finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr); -#else - finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1AccColSum, 1, ctaNbValidTokens); -#endif - } - else - { - uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); - OutputHead* const dst = &output[outOffset]; - ShmQWiseVec const* attentionSinksVec = nullptr; - if (attentionSinks != nullptr) - { - attentionSinksVec - = reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); - } + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, nullptr); +#else + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, 1, ctaNbValidTokens); +#endif + } else { + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) { + attentionSinksVec = + reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } #if SWAP_AB - finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, - xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec, - nbKHeads); + finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, + smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, + smem.gemm1AccColMax, attentionSinksVec, nbKHeads); #else - finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); + finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, + smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); #endif - } - } - unused(xBar.consumed.arrive()); + } + } + unused(xBar.consumed.arrive()); #if SWAP_AB - unused(vBar.consumed.arrive()); + unused(vBar.consumed.arrive()); #else - unused(vtBar.consumed.arrive()); + unused(vtBar.consumed.arrive()); #endif - } } - else - { - // IO warps - static_assert(beamWidth == 1); + } else { + // IO warps + static_assert(beamWidth == 1); #if ENABLE_PDL - preExit(); + preExit(); #endif #if ENABLE_PDL == 1 - acqBulk(); -#endif - assert(warpIdx.z == 2); - uint32_t const newTokenPos = cacheSeqLen - 1; - if (warpIdx.x < nbQLdWarps) - { - // load Q. 
Use register to load fp16 data and store fp8 to shared mem. - // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead. - using QCvt = F16QToF8Converter; - static_assert(beamWidth == 1); + acqBulk(); +#endif + assert(warpIdx.z == 2); + uint32_t const newTokenPos = cacheSeqLen - 1; + if (warpIdx.x < nbQLdWarps) { + // load Q. Use register to load fp16 data and store fp8 to shared mem. + // @fixme: If register pressure is high and shared mem pressure is low, switch to TMA instead. + using QCvt = F16QToF8Converter; + static_assert(beamWidth == 1); #if USE_INPUT_KV - TinyPtr const qData{qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq}; - constexpr bool isNeox = (ROPE_STYLE == 1); - constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U)); - uint32_t const lane = laneId(); - uint32_t const idxThrd = warpIdx.x * warp_size + lane; - uint32_t const idxThrdGrp = (thrdsPerHead % 32 == 0 ? makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) - : idxThrd / thrdsPerHead); - constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); - uint32_t const tid = idxThrd % thrdsPerHead; - smem.qBar.consumed.arrive_and_wait(); + TinyPtr const qData{ + qkv, headGrpSize * idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq}; + constexpr bool isNeox = (ROPE_STYLE == 1); + constexpr uint32_t thrdsPerHead = mha::min(warp_size, divUp(headElems, 4U)); + uint32_t const lane = laneId(); + uint32_t const idxThrd = warpIdx.x * warp_size + lane; + uint32_t const idxThrdGrp = + (thrdsPerHead % 32 == 0 ? makeWarpUniform(this_warp(), idxThrd / thrdsPerHead) + : idxThrd / thrdsPerHead); + constexpr uint32_t nbThrdGrps = exactDiv(warp_size * nbQLdWarps, thrdsPerHead); + uint32_t const tid = idxThrd % thrdsPerHead; + smem.qBar.consumed.arrive_and_wait(); #if ROPE_STYLE != 0 - auto const& ropeCosSinHead - = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); - auto const cosSinPairs = loadHead(ropeCosSinHead, tid); + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, tid); #endif #if ENABLE_PDL == 2 - acqBulk(); + acqBulk(); #endif #pragma unroll - for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) - { - uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; - if (idxHead >= headGrpSize) - { - break; - } + for (uint32_t iter = 0; iter < divUp(headGrpSize, nbThrdGrps); iter++) { + uint32_t const idxHead = nbThrdGrps * iter + idxThrdGrp; + if (idxHead >= headGrpSize) { + break; + } #if ROPE_STYLE == 0 - auto const rotatedPairs = loadHead(qData[idxHead], tid); + auto const rotatedPairs = + loadHead(qData[idxHead], tid); #else - auto const pairs = loadHead(qData[idxHead], tid); - auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); + auto const pairs = loadHead(qData[idxHead], tid); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); #endif - storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); - } + storeRotatedPairsForQ(smem.q, rotatedPairs, idxHead, tid); + } #else - TinyPtr const qData{q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; + TinyPtr const qData{ + q, headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp)}; #if ENABLE_PDL == 2 - acqBulk(); + acqBulk(); #endif - auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); + auto const f16QData = QCvt::load(threadIdx.x, qData, nbKHeads, ctaNbValidTokens); - 
smem.qBar.consumed.arrive_and_wait();
-    QCvt::store(threadIdx.x, smem.q, f16QData);
+      smem.qBar.consumed.arrive_and_wait();
+      QCvt::store(threadIdx.x, smem.q, f16QData);
#endif
-    // the release semantics of arrive does not work for async consumers like gmma. additional fence is
-    // needed.
-    asm volatile("fence.proxy.async.shared::cta;\n");
-    unused(smem.qBar.produced.arrive());
-    }
-    else if (warpIdx.x == nbQLdWarps)
-    { // load k
-        KVTilePartLoader kTilePartLoader
-        {
-            true, nbKHeads, cacheList, idxReq, idxHeadGrp,
+      // The release semantics of arrive() do not work for async consumers like gmma; an
+      // additional fence is needed.
+      asm volatile("fence.proxy.async.shared::cta;\n");
+      unused(smem.qBar.produced.arrive());
+    } else if (warpIdx.x == nbQLdWarps) {  // load k
+      KVTilePartLoader kTilePartLoader{true, nbKHeads, cacheList, idxReq, idxHeadGrp,
#if USE_PAGED_KV_CACHE
#if PAGED_KV_CACHE_LAYOUT == 1
-            tensorMapVLLMK,
+                                       tensorMapVLLMK,
#else
-            tensorMap,
+                                       tensorMap,
#endif
-            nbPages, smem.pages[0]
+                                       nbPages, smem.pages[0]
#else
-            tensorMap
+                                       tensorMap
#endif
-        };
-        for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++)
-        {
-            uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq;
-            kTilePartLoader.loadPages(idxKTile);
+      };
+      for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) {
+        uint32_t const idxKTile = idxKTileInit + idxIter * nbSubSeq;
+        kTilePartLoader.loadPages(idxKTile);
#if USE_INPUT_KV || ENABLE_PDL == 2
#if SPEC_DEC
-            bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen);
+        bool const anyNewTokens =
+            (gemm0CtaTileNbTokens * (idxKTile + 1) > cacheSeqLen - inputSeqLen);
#else
-            bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen);
+        bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxKTile + 1) >= cacheSeqLen);
#endif
-            if (anyNewTokens)
-            {
+        if (anyNewTokens) {
#if ENABLE_PDL == 2
-                acqBulk();
+          acqBulk();
#endif
#if USE_INPUT_KV
-                static_assert(beamWidth == 1);
-                uint32_t const inputKHeadOffset
-                    = headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
-                IOHead const& inKHead = qkv[inputKHeadOffset];
-                uint32_t const lane = laneId();
-                float const rcpKScale = 1.F / kvCacheScale[0];
+          static_assert(beamWidth == 1);
+          uint32_t const inputKHeadOffset =
+              headGrpSize * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq;
+          IOHead const& inKHead = qkv[inputKHeadOffset];
+          uint32_t const lane = laneId();
+          float const rcpKScale = 1.F / kvCacheScale[0];
#if ROPE_STYLE == 0
-                constexpr bool isNeox = false;
-                auto const pairs = loadHead(inKHead, lane) * rcpKScale;
-                Vec, decltype(pairs)::size> convertedPairs;
-                constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size;
-                reinterpret_cast&>(convertedPairs)
-                    = convert(reinterpret_cast const&>(pairs));
-                storeRotatedPairsForKV(
-                    kTilePartLoader.getHead(newTokenPos), convertedPairs, lane);
-#else
-                constexpr bool isNeox = (ROPE_STYLE == 1);
-                auto const pairs = loadHead(inKHead, lane) * rcpKScale;
-                auto const& ropeCosSinHead
-                    = reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]);
-                auto const cosSinPairs = loadHead(ropeCosSinHead, lane);
-                auto const rotatedPairs = applyRoPE(pairs, cosSinPairs);
-                storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), rotatedPairs, lane);
-#endif
-                static_assert(inputSeqLen == 1);
-                __syncwarp();
-#endif
-            }
-#endif
-            for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++)
-            {
-                auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % 
SharedMem::nbKBuf; - auto& kBar = smem.kBar[idxKBuf]; - kBar.consumed.arrive_and_wait(); - if (warpElectSync()) - { - kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); - } - __syncwarp(); - } - } - } - else if (warpIdx.x == nbQLdWarps + 1) - { // load v - KVTilePartLoader vTileLoader - { - false, nbKHeads, cacheList, idxReq, idxHeadGrp, + constexpr bool isNeox = false; + auto const pairs = + loadHead(inKHead, lane) * rcpKScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + convertedPairs, lane); +#else + constexpr bool isNeox = (ROPE_STYLE == 1); + auto const pairs = loadHead(inKHead, lane) * rcpKScale; + auto const& ropeCosSinHead = + reinterpret_cast const&>(ropeCosSin[cacheSeqLen - 1]); + auto const cosSinPairs = loadHead(ropeCosSinHead, lane); + auto const rotatedPairs = applyRoPE(pairs, cosSinPairs); + storeRotatedPairsForKV(kTilePartLoader.getHead(newTokenPos), + rotatedPairs, lane); +#endif + static_assert(inputSeqLen == 1); + __syncwarp(); +#endif + } +#endif + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + auto const idxKBuf = (idxIter * cacheHeadNbParts + idxPart) % SharedMem::nbKBuf; + auto& kBar = smem.kBar[idxKBuf]; + kBar.consumed.arrive_and_wait(); + if (warpElectSync()) { + kTilePartLoader.loadData(smem.k[idxKBuf], idxKTile, idxPart, kBar.produced); + } + __syncwarp(); + } + } + } else if (warpIdx.x == nbQLdWarps + 1) { // load v + KVTilePartLoader vTileLoader{false, nbKHeads, cacheList, idxReq, idxHeadGrp, #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMV, + tensorMapVLLMV, #else - tensorMap, + tensorMap, #endif - nbPages, smem.pages[1] + nbPages, smem.pages[1] #else - tensorMap + tensorMap #endif - }; - for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) - { - uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; - vTileLoader.loadPages(idxVTile); + }; + for (uint32_t idxIter = 0; idxIter < nbIters; idxIter++) { + uint32_t const idxVTile = idxVTileInit + idxIter * nbSubSeq; + vTileLoader.loadPages(idxVTile); #if USE_INPUT_KV || ENABLE_PDL == 2 #if SPEC_DEC - bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); + bool const anyNewTokens = + (gemm0CtaTileNbTokens * (idxVTile + 1) > cacheSeqLen - inputSeqLen); #else - bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); + bool const anyNewTokens = (gemm0CtaTileNbTokens * (idxVTile + 1) >= cacheSeqLen); #endif - if (anyNewTokens) - { + if (anyNewTokens) { #if ENABLE_PDL == 2 - acqBulk(); + acqBulk(); #endif #if USE_INPUT_KV - static_assert(beamWidth == 1); - uint32_t const inputVHeadOffset - = (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; - IOHead const& inVHead = qkv[inputVHeadOffset]; - uint32_t const lane = laneId(); - float const rcpVScale = 1.F / kvCacheScale[0]; - constexpr bool isNeox = false; - auto const pairs = loadHead(inVHead, lane) * rcpVScale; - Vec, decltype(pairs)::size> convertedPairs; - constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; - reinterpret_cast&>(convertedPairs) - = convert(reinterpret_cast const&>(pairs)); - static_assert(SPEC_DEC == 0); - storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), convertedPairs, lane); - __syncwarp(); -#endif - } + 
static_assert(beamWidth == 1); + uint32_t const inputVHeadOffset = + (headGrpSize + 1) * nbKHeads + idxHeadGrp + (headGrpSize + 2) * nbKHeads * idxReq; + IOHead const& inVHead = qkv[inputVHeadOffset]; + uint32_t const lane = laneId(); + float const rcpVScale = 1.F / kvCacheScale[0]; + constexpr bool isNeox = false; + auto const pairs = + loadHead(inVHead, lane) * rcpVScale; + Vec, decltype(pairs)::size> convertedPairs; + constexpr uint32_t nbElems = decltype(pairs)::Elem::size * decltype(pairs)::size; + reinterpret_cast&>(convertedPairs) = + convert(reinterpret_cast const&>(pairs)); + static_assert(SPEC_DEC == 0); + storeRotatedPairsForKV(vTileLoader.getHead(newTokenPos), + convertedPairs, lane); + __syncwarp(); +#endif + } +#endif + + uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; + auto& vBar = smem.vBar[idxVBuf]; + vBar.consumed.arrive_and_wait(); + if (warpElectSync()) { +#pragma unroll + for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) { + vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); + } + } + __syncwarp(); + } + } + } + __syncthreads(); + uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; + uint32_t const tid = + threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; + assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); + if (tid < nbBarriers) { + (&smem.qBar.produced)[tid].~CtaBarrier(); + } + if (!isMultiBlockMode) { + return; + } + bool& smemIsLastCta = smem.isLastCta; + if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) { + uint32_t const lastOld = nbSubSeq - 1; + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t old; + uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; + auto const pSemaphore = &semaphores[idxSemaphore]; + asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" + : "=r"(old) + : "l"(pSemaphore), "r"(lastOld)); + smemIsLastCta = (old == lastOld); + } + { + assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); +#ifndef __CUDACC_RTC__ + assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); #endif + auto& smem = *reinterpret_cast(&smemByteBuf[0]); + assert(blockDim.x >= MultiBlockSMem::nbBuf); + constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; - uint32_t const idxVBuf = idxIter % SharedMem::nbVBuf; - auto& vBar = smem.vBar[idxVBuf]; - vBar.consumed.arrive_and_wait(); - if (warpElectSync()) - { -#pragma unroll - for (uint32_t idxPart = 0; idxPart < cacheHeadNbParts; idxPart++) - { - vTileLoader.loadData(smem.vBuf(idxVBuf)[idxPart], idxVTile, idxPart, vBar.produced); - } - } - __syncwarp(); - } - } + static_assert(nbWarps >= MultiBlockSMem::nbBuf); + if (wid < MultiBlockSMem::nbBuf) { + if (warpElectSync()) { + smem.barriers[wid].initialize(isHeadPadded ? 
warp_size : 1U, nbMathWarps * warp_size); + smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); + } } __syncthreads(); - uint32_t const nbBarriers = &smem.gemm1WarpGrpBar - &smem.qBar.produced + 1; - uint32_t const tid = threadIdx.x + blockDim.x * threadIdx.y + blockDim.x * blockDim.y * threadIdx.z; - assert(nbBarriers <= blockDim.x * blockDim.y * blockDim.z); - if (tid < nbBarriers) - { - (&smem.qBar.produced)[tid].~CtaBarrier(); - } - if (!isMultiBlockMode) - { - return; - } - bool& smemIsLastCta = smem.isLastCta; - if (threadIdx.x == gemm1NbThrds - 1U && threadIdx.z == 0) - { - uint32_t const lastOld = nbSubSeq - 1; - ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; - uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; - uint32_t old; - uint32_t const idxSemaphore = idxSeq * nbInputSeqSplit + idxInputSubSeq; - auto const pSemaphore = &semaphores[idxSemaphore]; - asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" : "=r"(old) : "l"(pSemaphore), "r"(lastOld)); - smemIsLastCta = (old == lastOld); - } - { - assert(dynamicSmemSize() >= sizeof(MultiBlockSMem)); -#ifndef __CUDACC_RTC__ - assert(sizeof(MultiBlockSMem) < offsetof(SharedMem, isLastCta)); -#endif - auto& smem = *reinterpret_cast(&smemByteBuf[0]); - assert(blockDim.x >= MultiBlockSMem::nbBuf); - constexpr uint32_t nbMathWarps = gemm0NbWarps + gemm1NbWarps; - - static_assert(nbWarps >= MultiBlockSMem::nbBuf); - if (wid < MultiBlockSMem::nbBuf) - { - if (warpElectSync()) - { - smem.barriers[wid].initialize(isHeadPadded ? warp_size : 1U, nbMathWarps * warp_size); - smem.barriers[wid].consumed.arrive(nbMathWarps * warp_size); - } - } - __syncthreads(); - if (!smemIsLastCta) - { - return; + if (!smemIsLastCta) { + return; + } + if (wid < nbMathWarps) { + constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); + using Acc = Vec; + + struct HeadState { + Acc acc; + float sum; + float max; + }; + + Vec states{}; + for (auto& s : states.data) { + s.max = safeInitRowMax; + } + uint32_t const lane = laneId(); + for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) { + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; + } + HeadState& state = states[i]; + auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; + auto const data = convert(reinterpret_cast&>( + smem.tokens[idxBuf][idxHead][Acc::size * lane])); + if (sumMax.max > state.max) { + float const scale = expf(state.max - sumMax.max); + state.max = sumMax.max; + state.sum = state.sum * scale + sumMax.sum; + state.acc = state.acc * scale + data * sumMax.sum; + } else { + float const scale = expf(sumMax.max - state.max); + state.sum = state.sum + sumMax.sum * scale; + state.acc = state.acc + data * (sumMax.sum * scale); + } + } + unused(bar.consumed.arrive()); + } + // Add the attention sinks. 
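+      // A sink behaves like one extra attention logit that carries no value vector:
+      // it inflates the softmax denominator of its head but leaves the numerator
+      // (state.acc) untouched. Illustrative sketch only (never compiled; `sinkLogit`
+      // is a hypothetical stand-in for the per-head attentionSinks[] entry used below):
+#if 0
+      void addAttentionSink(float sinkLogit, HeadState& state) {
+        state.sum += expf(sinkLogit - state.max);  // acc is deliberately unchanged
+      }
+#endif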
+ if (attentionSinks != nullptr) { + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = + expf(attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - + states[i].max); + states[i].sum += sink; + } + } + __syncthreads(); + uint32_t const outOffset = + headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); + auto const dst = &output[outOffset]; + for (uint32_t i = 0; i < headsPerWarp; i++) { + uint32_t const idxHead = wid + nbMathWarps * i; + if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) { + break; } - if (wid < nbMathWarps) - { - constexpr uint32_t headsPerWarp = divUp(ctaNbValidQHeads, nbMathWarps); - using Acc = Vec; - - struct HeadState - { - Acc acc; - float sum; - float max; - }; - - Vec states{}; - for (auto& s : states.data) - { - s.max = safeInitRowMax; - } - uint32_t const lane = laneId(); - for (uint32_t idxBlock = 0; idxBlock < nbSubSeq; idxBlock++) - { - uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; - auto& bar = smem.barriers[idxBuf]; - bar.produced.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); - for (uint32_t i = 0; i < headsPerWarp; i++) - { - uint32_t const idxHead = wid + nbMathWarps * i; - if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) - { - break; - } - HeadState& state = states[i]; - auto const sumMax = smem.rowSumMax[idxBuf][idxHead]; - auto const data = convert( - reinterpret_cast&>(smem.tokens[idxBuf][idxHead][Acc::size * lane])); - if (sumMax.max > state.max) - { - float const scale = expf(state.max - sumMax.max); - state.max = sumMax.max; - state.sum = state.sum * scale + sumMax.sum; - state.acc = state.acc * scale + data * sumMax.sum; - } - else - { - float const scale = expf(sumMax.max - state.max); - state.sum = state.sum + sumMax.sum * scale; - state.acc = state.acc + data * (sumMax.sum * scale); - } - } - unused(bar.consumed.arrive()); - } - // Add the attention sinks. 
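+      // The writeback below remaps each CTA-local head index to the global output
+      // layout and then normalizes by the merged denominator (acc * (1 / sum)).
+      // Illustrative sketch only (never compiled; `dstHead` is a hypothetical helper
+      // mirroring the SPEC_DEC index math that follows):
+#if 0
+      // CTA-local heads pack headGrpSize heads per draft token for one K-head
+      // group; globally each token owns nbKHeads * headGrpSize heads, so we skip
+      // the other (nbKHeads - 1) groups for every preceding token.
+      uint32_t dstHead(uint32_t idxHead, uint32_t nbKHeads) {
+        uint32_t const idxToken = idxHead / headGrpSize;
+        uint32_t const tokenPad = headGrpSize * (nbKHeads - 1);
+        // e.g. headGrpSize == 8, nbKHeads == 4: idxHead 9 -> token 1, dst 33.
+        return idxHead + idxToken * tokenPad;
+      }
+#endif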
- if (attentionSinks != nullptr) - { - for (uint32_t i = 0; i < headsPerWarp; i++) - { - uint32_t const idxHead = wid + nbMathWarps * i; - float sink = expf( - attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max); - states[i].sum += sink; - } - } - __syncthreads(); - uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); - auto const dst = &output[outOffset]; - for (uint32_t i = 0; i < headsPerWarp; i++) - { - uint32_t const idxHead = wid + nbMathWarps * i; - if ((ctaNbValidQHeads % nbMathWarps != 0) && (idxHead >= ctaNbValidQHeads)) - { - break; - } #if SPEC_DEC - uint32_t const idxToken = idxHead / headGrpSize; - if (idxToken >= ctaNbValidTokens) - { - break; - } - uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); - uint32_t const idxDstHead = idxHead + idxToken * tokenPad; -#else - uint32_t const idxDstHead = idxHead; -#endif - auto const& s = states[i]; - auto const outData = convert(s.acc * (1.f / s.sum)); - if (Acc::size * lane < validElemsPerHead) - { - reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = outData; - } - } - } - else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) - { - static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); - ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; - uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; - uint32_t const initIdxBlock = wid - nbMathWarps; - // each warp loads data for a block - for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; idxBlock += MultiBlockSMem::nbIOWarps) - { - uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; - uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; - uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; - auto& bar = smem.barriers[idxBuf]; - bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); - auto const lane = laneId(); -#pragma unroll - for (uint32_t iter = 0; iter < divUp(ctaNbValidQHeads, warp_size); iter++) - { - uint32_t const i = iter * warp_size + lane; - if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) - { - break; - } - ldgsts::copyAsync( - &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); - } - ldgsts::barArrive(bar.produced, false); - if constexpr (isHeadPadded) - { - static_assert(grainsPerPaddedInputHead <= warp_size); - constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); - constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); - constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; -#pragma unroll - for (uint32_t i = 0; i < nbIters; i++) - { - uint32_t const idxHead = headsPerIter * i - + BoundedVal{lane}.template divBy().get(); - uint32_t const idxGrain - = BoundedVal{lane}.template mod().get(); - if (i < nbWholeIters || idxHead < ctaNbValidQHeads) - { - constexpr uint32_t nbElemsPerGrain = exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); - auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; - auto const src = idxGrain < grainsPerIOHead - ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] - : nullptr; - ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? 
grainBytes : 0U); - } - } - ldgsts::barArrive(bar.produced, true); - } - else - { - if (warpElectSync()) - { - tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], - sizeof(smem.tokens[idxBuf]), bar.produced); - arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); - } - } - } - __syncthreads(); - uint32_t const idxBar = tid - warp_size * nbMathWarps; - if (idxBar < MultiBlockSMem::nbBuf * 2) - { - reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); + uint32_t const idxToken = idxHead / headGrpSize; + if (idxToken >= ctaNbValidTokens) { + break; + } + uint32_t const tokenPad = headGrpSize * (nbKHeads - 1); + uint32_t const idxDstHead = idxHead + idxToken * tokenPad; +#else + uint32_t const idxDstHead = idxHead; +#endif + auto const& s = states[i]; + auto const outData = convert(s.acc * (1.f / s.sum)); + if (Acc::size * lane < validElemsPerHead) { + reinterpret_cast&>(dst[idxDstHead][Acc::size * lane]) = + outData; + } + } + } else if (wid < nbMathWarps + MultiBlockSMem::nbIOWarps) { + static_assert(MultiBlockSMem::nbIOWarps <= MultiBlockSMem::nbBuf); + ScratchMem const scratchMem{scratch, maxNbSubSeq * nbKHeads * batchSize, nbInputSeqSplit}; + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; + uint32_t const initIdxBlock = wid - nbMathWarps; + // each warp loads data for a block + for (uint32_t idxBlock = initIdxBlock; idxBlock < nbSubSeq; + idxBlock += MultiBlockSMem::nbIOWarps) { + uint32_t const idxAllSubSeq = maxNbSubSeq * idxSeq + idxBlock; + uint32_t const idxChunk = idxAllSubSeq * nbInputSeqSplit + idxInputSubSeq; + uint32_t const idxBuf = idxBlock % MultiBlockSMem::nbBuf; + auto& bar = smem.barriers[idxBuf]; + bar.consumed.wait_parity(idxBlock / MultiBlockSMem::nbBuf % 2 != 0); + auto const lane = laneId(); +#pragma unroll + for (uint32_t iter = 0; iter < divUp(ctaNbValidQHeads, warp_size); iter++) { + uint32_t const i = iter * warp_size + lane; + if (ctaNbValidQHeads % warp_size != 0 && i >= ctaNbValidQHeads) { + break; + } + ldgsts::copyAsync( + &smem.rowSumMax[idxBuf][i], &scratchMem.rowSumMax()[idxChunk][i]); + } + ldgsts::barArrive(bar.produced, false); + if constexpr (isHeadPadded) { + static_assert(grainsPerPaddedInputHead <= warp_size); + constexpr uint32_t headsPerIter = exactDiv(warp_size, grainsPerPaddedInputHead); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) { + uint32_t const idxHead = + headsPerIter * i + + BoundedVal{lane}.template divBy().get(); + uint32_t const idxGrain = + BoundedVal{lane}.template mod().get(); + if (i < nbWholeIters || idxHead < ctaNbValidQHeads) { + constexpr uint32_t nbElemsPerGrain = + exactDiv(grainBytes, sizeof(MultiBlockSMem::Elem)); + auto const dst = &smem.tokens[idxBuf][idxHead][nbElemsPerGrain * idxGrain]; + auto const src = + idxGrain < grainsPerIOHead + ? &scratchMem.tokens()[idxChunk][idxHead][nbElemsPerGrain * idxGrain] + : nullptr; + ldgsts::copyAsync(dst, src, idxGrain < grainsPerIOHead ? 
grainBytes : 0U); } - } - } + } + ldgsts::barArrive(bar.produced, true); + } else { + if (warpElectSync()) { + tma::loadLinearAsync(&smem.tokens[idxBuf], &scratchMem.tokens()[idxChunk], + sizeof(smem.tokens[idxBuf]), bar.produced); + arrive_tx(bar.produced, sizeof(smem.tokens[idxBuf]), 1); + } + } + } + __syncthreads(); + uint32_t const idxBar = tid - warp_size * nbMathWarps; + if (idxBar < MultiBlockSMem::nbBuf * 2) { + reinterpret_cast(&smem.barriers[0])[idxBar].~CtaBarrier(); + } + } + } #else #if GENERATE_CUBIN - static_assert("This kernel is for Hopper only"); + static_assert("This kernel is for Hopper only"); #else - asm volatile("trap;\n"); + asm volatile("trap;\n"); #endif -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && BEAM_WIDTH == 1 } #if CACHE_ELEM_ENUM == 0 || CACHE_ELEM_ENUM == 2 template -__device__ inline typename F16QToF8Converter::RegData F16QToF8Converter::load( - uint32_t tid, TinyPtr const& src, uint32_t const nbKHeads /*for beam search only*/, uint32_t nbTokens) -{ +__device__ inline typename F16QToF8Converter::RegData +F16QToF8Converter::load(uint32_t tid, TinyPtr const& src, + uint32_t const nbKHeads /*for beam search only*/, + uint32_t nbTokens) { #if !(SPEC_DEC) - assert(nbTokens == 1); - nbTokens = 1; + assert(nbTokens == 1); + nbTokens = 1; #endif - typename F16QToF8Converter::RegData dst; + typename F16QToF8Converter::RegData dst; #pragma unroll - for (uint32_t iter = 0; iter < nbIters; iter++) - { - uint32_t const idxGrain = nbThrds * iter + tid; - if (idxGrain >= totalGrains) - { - break; - } + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; + } #if SPEC_DEC - uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; - uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); - uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; - static_assert(beamWidth == 1); + uint32_t const idxToken = idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const tokenPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + tokenPad * idxToken; + static_assert(beamWidth == 1); #else - uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; - uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); - uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; + uint32_t const idxBeam = beamWidth == 1 ? 0 : idxGrain / grainsPerPaddedInputQHeadGrp; + uint32_t const beamPad = grainsPerPaddedInputQHeadGrp * (nbKHeads - 1); + uint32_t offsetInGrains = idxGrain + beamPad * idxBeam; #endif - bool isGrainInBound = true; - if constexpr (isHeadPadded) - { - uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; - offsetInGrains = offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; - isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); - } + bool isGrainInBound = true; + if constexpr (isHeadPadded) { + uint32_t const idxGrainInsideHead = offsetInGrains % grainsPerPaddedInputHead; + offsetInGrains = + offsetInGrains / grainsPerPaddedInputHead * grainsPerIOHead + idxGrainInsideHead; + isGrainInBound = (idxGrainInsideHead < grainsPerIOHead); + } #if SPEC_DEC - isGrainInBound = isGrainInBound && (idxToken < nbTokens); + isGrainInBound = isGrainInBound && (idxToken < nbTokens); #endif - LdGrain const srcGrain = isGrainInBound ? 
src.template cast()[offsetInGrains] : LdGrain{}; - static_assert(inputElemSize == 2); - auto const& fp16Data = reinterpret_cast const&>(srcGrain); - dst[iter] - = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead ? fp16Data : mha::decay_t{}; - } - return dst; + LdGrain const srcGrain = + isGrainInBound ? src.template cast()[offsetInGrains] : LdGrain{}; + static_assert(inputElemSize == 2); + auto const& fp16Data = + reinterpret_cast const&>(srcGrain); + dst[iter] = idxGrain % grainsPerPaddedInputHead < grainsPerIOHead + ? fp16Data + : mha::decay_t{}; + } + return dst; } template __device__ inline void F16QToF8Converter::store( - uint32_t tid, SharedMem::QBuffer& dst, F16QToF8Converter::RegData const& data) -{ + uint32_t tid, SharedMem::QBuffer& dst, + F16QToF8Converter::RegData const& data) { #pragma unroll - for (uint32_t iter = 0; iter < nbIters; iter++) - { - uint32_t const idxGrain = nbThrds * iter + tid; - if (idxGrain >= totalGrains) - { - break; - } -#if CACHE_ELEM_ENUM == 0 - static_assert(inputElemSize == cacheElemSize); - ShmVec const& shmData = data[iter]; - uint32_t const r = idxGrain / grainsPerPaddedInputHead; - BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; - - dst[c.template divBy().get()].template at(r, c.template mod().get()) - = reinterpret_cast(shmData); -#else - auto const& fp16Data = data[iter]; - ShmVec shmData; -#pragma unroll - for (uint32_t i = 0; i < fp16Data.size; i++) - { - shmData[i] = CacheElem{fp16Data[i]}; - } - uint32_t const dstIdxGrain = idxGrain / 2; - uint32_t const dstIdxHalfGrain = idxGrain % 2; - constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); - uint32_t const r = dstIdxGrain / grainsPerCacheHead; - BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; - reinterpret_cast&>(dst[c.template divBy().get()].template at( - r, c.template mod().get()))[dstIdxHalfGrain] - = shmData; -#endif + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxGrain = nbThrds * iter + tid; + if (idxGrain >= totalGrains) { + break; } +#if CACHE_ELEM_ENUM == 0 + static_assert(inputElemSize == cacheElemSize); + ShmVec const& shmData = data[iter]; + uint32_t const r = idxGrain / grainsPerPaddedInputHead; + BoundedVal const c = {idxGrain % grainsPerPaddedInputHead}; + + dst[c.template divBy().get()].template at( + r, c.template mod().get()) = reinterpret_cast(shmData); +#else + auto const& fp16Data = data[iter]; + ShmVec shmData; +#pragma unroll + for (uint32_t i = 0; i < fp16Data.size; i++) { + shmData[i] = CacheElem{fp16Data[i]}; + } + uint32_t const dstIdxGrain = idxGrain / 2; + uint32_t const dstIdxHalfGrain = idxGrain % 2; + constexpr uint32_t grainsPerCacheHead = exactDiv(paddedCacheHeadBytes, grainBytes); + uint32_t const r = dstIdxGrain / grainsPerCacheHead; + BoundedVal const c = {dstIdxGrain % grainsPerCacheHead}; + reinterpret_cast&>( + dst[c.template divBy().get()].template at( + r, c.template mod().get()))[dstIdxHalfGrain] = shmData; +#endif + } } #endif __device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads, - KVCacheList const& cacheList, uint32_t idxReq, uint32_t idxHeadGrp, CUtensorMap const& tensorMap + KVCacheList const& cacheList, + uint32_t idxReq, uint32_t idxHeadGrp, + CUtensorMap const& tensorMap #if USE_PAGED_KV_CACHE - , - uint32_t nbPages, Vec& pageBuf -#endif - ) - : nbKHeads{nbKHeads} - , cacheList{cacheList} - , idxReq{idxReq} - , idxHeadGrp{idxHeadGrp} - , tensorMap{tensorMap} + , + uint32_t nbPages, + Vec& pageBuf +#endif + ) + : nbKHeads{nbKHeads}, + 
cacheList{cacheList}, + idxReq{idxReq}, + idxHeadGrp{idxHeadGrp}, + tensorMap{tensorMap} #if USE_PAGED_KV_CACHE - , nbPages{nbPages} - , pages{pageBuf} + , + nbPages{nbPages}, + pages{pageBuf} #if PAGED_KV_CACHE_LAYOUT == 1 - , baseOffset{idxReq * cacheList.maxNbPagesPerSeq} + , + baseOffset{idxReq * cacheList.maxNbPagesPerSeq} #else - , baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} + , + baseOffset{((idxReq * beamWidth) * 2 + (isK ? 0 : 1)) * cacheList.maxNbPagesPerSeq} #endif #else - , baseOffset{(idxReq * beamWidth) * 2 + (isK ? 0 : 1)} + , + baseOffset{(idxReq * beamWidth) * 2 + (isK ? 0 : 1)} #endif { } @@ -1850,79 +1763,76 @@ __device__ inline KVTilePartLoader::KVTilePartLoader(bool isK, uint32_t nbKHeads // tensorMap is for one whole page ([nbKHeads*tokensPerPage][headElems]) or whole cache template __device__ inline void KVTilePartLoader::loadData( - Array2D& dst, uint32_t idxTile, - uint32_t idxPart, CtaBarrier& bar) -{ - static_assert(nbTokens == gemm0CtaTileNbTokens); + Array2D& dst, + uint32_t idxTile, uint32_t idxPart, CtaBarrier& bar) { + static_assert(nbTokens == gemm0CtaTileNbTokens); #if USE_PAGED_KV_CACHE - assert(idxTile == idxTileRef); - if constexpr (nbTokens < tokensPerPage) - { - assert(nbPagesPerTile == 1); - uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); + assert(idxTile == idxTileRef); + if constexpr (nbTokens < tokensPerPage) { + assert(nbPagesPerTile == 1); + uint32_t const offset = nbTokens * (idxTile % exactDiv(tokensPerPage, nbTokens)); #if PAGED_KV_CACHE_LAYOUT == 1 - tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t) pages[0]}, bar); + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, offset, (uint32_t)pages[0]}, bar); #else - tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t) pages[0]}, bar); + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, offset, idxHeadGrp, (uint32_t)pages[0]}, bar); #endif - } - else - { + } else { #pragma unroll - for (uint32_t i = 0; i < nbPagesPerTile; i++) - { + for (uint32_t i = 0; i < nbPagesPerTile; i++) { #if PAGED_KV_CACHE_LAYOUT == 1 - tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, - DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t) pages[i]}, bar); + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, idxHeadGrp, 0, (uint32_t)pages[i]}, bar); #else - tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, - DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t) pages[i]}, bar); + tma::loadAsync(&dst(tokensPerPage * i, 0), tensorMap, + DimsLE<4>{partElems * idxPart, 0, idxHeadGrp, (uint32_t)pages[i]}, bar); #endif - } } + } #else - tma::loadAsync(&dst, tensorMap, DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); + tma::loadAsync(&dst, tensorMap, + DimsLE<4>{partElems * idxPart, nbTokens * idxTile, idxHeadGrp, baseOffset}, bar); #endif } -__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) -{ +__device__ inline void KVTilePartLoader::loadPages(uint32_t idxTile) { #if USE_PAGED_KV_CACHE - uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage - ? nbPagesPerTile * idxTile - : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); -#pragma unroll - for (uint32_t i = 0; i < nbPagesPerTile; i++) - { - uint32_t const idxPage = idxPageBeg + i; - auto const page = idxPage < nbPages ? 
cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; - if (warpElectSync()) - { - pages[i] = page; - } + uint32_t const idxPageBeg = gemm0CtaTileNbTokens >= tokensPerPage + ? nbPagesPerTile * idxTile + : idxTile / exactDiv(tokensPerPage, gemm0CtaTileNbTokens); +#pragma unroll + for (uint32_t i = 0; i < nbPagesPerTile; i++) { + uint32_t const idxPage = idxPageBeg + i; + auto const page = + idxPage < nbPages ? cacheList.kvCachePageList[baseOffset + idxPage] : kBAD_PAGE_INDEX; + if (warpElectSync()) { + pages[i] = page; } - idxTileRef = idxTile; - __syncwarp(); + } + idxTileRef = idxTile; + __syncwarp(); #endif } -__device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) -{ - constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; +__device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) { + constexpr uint32_t nbTokens = gemm0CtaTileNbTokens; #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 - // Raise a runtime error indicating not implemented - assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); - __trap(); + // Raise a runtime error indicating not implemented + assert(false && "KVTilePartLoader::getHead is not implemented for PAGED_KV_CACHE_LAYOUT == 1"); + __trap(); #else - uint32_t const idxTile = pos / nbTokens; - assert(idxTile == idxTileRef); - uint32_t const offset = pos % tokensPerPage; - return cacheList.pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + offset]; + uint32_t const idxTile = pos / nbTokens; + assert(idxTile == idxTileRef); + uint32_t const offset = pos % tokensPerPage; + return cacheList + .pool[tokensPerPage * (nbKHeads * pages[pos % nbTokens / tokensPerPage] + idxHeadGrp) + + offset]; #endif #else - // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] - return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; + // shape: KVCacheHead[batchSize][beamWidth][2][nbKHeads][capacity] + return cacheList.data[cacheList.capacity * (baseOffset * nbKHeads + idxHeadGrp) + pos]; #endif } @@ -1930,917 +1840,803 @@ __device__ inline GMemKVCacheHead& KVTilePartLoader::getHead(uint32_t pos) #if SPEC_DEC __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - int32_t tok0WinBeg, + int32_t tok0WinBeg, #endif - uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) -{ - constexpr uint32_t tileSize = gemm0CtaTileNbTokens; - static_assert(SPEC_Q_SEQ_LEN <= sizeof(MaskType) * 8, "not implemented"); + uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + static_assert(SPEC_Q_SEQ_LEN <= sizeof(MaskType) * 8, "not implemented"); - assert(cacheSeqLen >= SPEC_Q_SEQ_LEN); - uint32_t const maskStartRow = cacheSeqLen - SPEC_Q_SEQ_LEN; - uint32_t const tileStartRow = tileSize * idxTile; - if (tileStartRow + tileSize < maskStartRow) - { - return; - } + assert(cacheSeqLen >= SPEC_Q_SEQ_LEN); + uint32_t const maskStartRow = cacheSeqLen - SPEC_Q_SEQ_LEN; + uint32_t const tileStartRow = tileSize * idxTile; + if (tileStartRow + tileSize < maskStartRow) { + return; + } - uint32_t const idxInQuad = laneId() % 4; - uint32_t const idxQuad = laneId() / 4; + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) 
- { - uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; - uint32_t const maskCol = col / headGrpSize; - MaskType const bit_mask = (1ULL << (maskCol + 1)) - 1; + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + uint32_t const maskCol = col / headGrpSize; + MaskType const bit_mask = (1ULL << (maskCol + 1)) - 1; #pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { + for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; - uint32_t const globalRow = tileStartRow + row; - if (globalRow >= cacheSeqLen) - { - acc(m, n)(i, j) = safeInitRowMax; - continue; - } - if (globalRow >= maskStartRow) - { - uint32_t const maskRow = globalRow - maskStartRow; - if ((bit_mask >> maskRow) == 0) - { - acc(m, n)(i, j) = safeInitRowMax; - } - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const globalRow = tileStartRow + row; + if (globalRow >= cacheSeqLen) { + acc(m, n)(i, j) = safeInitRowMax; + continue; + } + if (globalRow >= maskStartRow) { + uint32_t const maskRow = globalRow - maskStartRow; + if ((bit_mask >> maskRow) == 0) { + acc(m, n)(i, j) = safeInitRowMax; } + } } + } } + } } -#endif // SPEC_DEC +#endif // SPEC_DEC // smemColMax is persistent across multiple iterations -__device__ inline RegColWiseVec computeWarpGrpColMax_sync( - CtaBarrier& warpGrpBar, ShmQWiseVec& smemColMax, Gemm0Acc const& src) -{ - auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); +__device__ inline RegColWiseVec computeWarpGrpColMax_sync(CtaBarrier& warpGrpBar, + ShmQWiseVec& smemColMax, + Gemm0Acc const& src) { + auto colMax = RegColWiseVec::filled(Vec::filled(safeInitRowMax)); #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { + for (uint32_t n = 0; n < src.cols; n++) { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll - for (uint32_t m = 0; m < src.rows; m++) - { + for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - colMax[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colMax[n][j] = (m == 0 && i == 0) ? 
src(m, n)(i, j) : fmax(colMax[n][j], src(m, n)(i, j)); } + } } + } #pragma unroll - for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) - { + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < 2; j++) - { - auto& x = colMax[n][j]; - x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); - } - } + for (uint32_t j = 0; j < 2; j++) { + auto& x = colMax[n][j]; + x = fmax(x, __shfl_xor_sync(~0U, x, xorMask)); + } } + } - uint32_t const lane = laneId(); - if (lane < 4) - { + uint32_t const lane = laneId(); + if (lane < 4) { #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < 2; j++) - { - atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); - } - } + for (uint32_t j = 0; j < 2; j++) { + atomicMax(&smemColMax[8 * n + 2 * lane + j], colMax[n][j]); + } } - warpGrpBar.arrive_and_wait(); - uint32_t const idxInQuad = lane % 4; + } + warpGrpBar.arrive_and_wait(); + uint32_t const idxInQuad = lane % 4; #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); - colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + assert(colMax[n][j] <= smemColMax[8 * n + 2 * idxInQuad + j]); + colMax[n][j] = smemColMax[8 * n + 2 * idxInQuad + j]; } - warpGrpBar.arrive_and_wait(); - return colMax; + } + warpGrpBar.arrive_and_wait(); + return colMax; } -__device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) -{ - RegColWiseVec ret; - constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); - auto const idx = laneId() % nbThrdsPerInstNBase; -#pragma unroll - for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) - { - static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); - ret[i] = reinterpret_cast< - Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( - smemVec)[i * nbThrdsPerInstNBase + idx]; - } - return ret; +__device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + smemVec)[i * nbThrdsPerInstNBase + idx]; + } + return ret; } -__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound) -{ - RegColWiseVec ret; - constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); - auto const idx = laneId() % nbThrdsPerInstNBase; -#pragma unroll - for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) - { - static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); - ret[i] = reinterpret_cast< - Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( - 
gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; - } - return ret; +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, + uint32_t bound) { + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast, + exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; } -__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) -{ - uint32_t const idxInQuad = laneId() % 4; - uint32_t const idxQuad = laneId() / 4; -#pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { -#pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; - if (row >= validRowBeg && row < validRowEnd) - { - continue; - } +__device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, + uint32_t validRowEnd) { + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - acc(m, n)(i, j) = safeInitRowMax; - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = 64 * m + 16 * warpRank + 8 * i + idxQuad; + if (row >= validRowBeg && row < validRowEnd) { + continue; + } +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) = safeInitRowMax; } + } } + } } -__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) -{ +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegColWiseVec const& colMax) { #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - float const maxVal = colMax[n][j]; - float const bias = maxVal * log2e; + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float const maxVal = colMax[n][j]; + float const bias = maxVal * log2e; #pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { + for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - float& elem = acc(m, n)(i, j); - assert(maxVal >= elem); - elem = exp2f(elem * log2e - bias); - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); } + } } + } } -__device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) -{ - auto colSum = RegColWiseVec::filled(Vec::filled(0)); +__device__ inline RegColWiseVec computeWarpColSum(Gemm0Acc& src) { + auto colSum = RegColWiseVec::filled(Vec::filled(0)); #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll - for (uint32_t m = 0; m < src.rows; m++) - { 
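+    // The m/i loops below accumulate this lane's private accumulator fragments
+    // into colSum; the xor-shuffle loop at the end of the function then adds
+    // partial sums across quads. Lanes within a quad (xorMask <= 2) already own
+    // distinct columns, so the butterfly stops at 4. Reference sketch only
+    // (never compiled):
+#if 0
+    __device__ float warpColButterflySum(float x) {
+      for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2)
+        x += __shfl_xor_sync(~0U, x, xorMask);
+      return x;
+    }
+#endif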
+ for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + colSum[n][j] = (m == 0 && i == 0) ? src(m, n)(i, j) : colSum[n][j] + src(m, n)(i, j); } + } } + } #pragma unroll - for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) - { + for (uint32_t xorMask = 16; xorMask > 2; xorMask /= 2) { #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - auto& x = colSum[n][j]; - x += __shfl_xor_sync(~0U, x, xorMask); - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + auto& x = colSum[n][j]; + x += __shfl_xor_sync(~0U, x, xorMask); + } } - return colSum; + } + return colSum; } -__device__ inline void storeGemm0AccToShm( - uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) -{ +__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { #if CACHE_ELEM_ENUM == 0 - using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; - F16Acc f16Acc; - reinterpret_cast&>(f16Acc) - = convert(reinterpret_cast const&>(acc)); - static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); - uint32_t const idxHalf = lane / 16; - uint32_t const idxInHalf = lane % 16; - uint32_t const idxOctInsideHalf = idxInHalf / 8; - uint32_t const idxRowInsideOct = lane % 8; - uint32_t const warpBaseC = 16 * warpRank; - auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair - { - uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; - uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; - return {accR, accC}; - }; - auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* - { - auto const [accR, accC] = toAccCoords(idxAccCoreMat); - static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); - uint32_t const idxPart = 0; - uint32_t const dstR = accC * 8 + idxRowInsideOct; - uint32_t const dstC = exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); - assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); - return &smemX[idxPart].template at(dstR, dstC); - }; - auto const getAccData = [&](uint32_t idxAccCoreMat) - { - auto const [accR, accC] = toAccCoords(idxAccCoreMat); - return f16Acc(accR, accC); - }; - - barConsumed.arrive_and_wait(); -#pragma unroll - for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) - { - auto const dstAddr = getDstAddr(iter * 2 + idxHalf); - Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; - stmatrix(dstAddr, reinterpret_cast(data)); - } - if constexpr (Gemm0Acc::size % 2 != 0) - { - auto const dstAddr = lane < 16 ? 
getDstAddr(Gemm0Acc::size - 1) : nullptr; - stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); - } + using F16Acc = Array2D, Gemm0Acc::rows, Gemm0Acc::cols>; + F16Acc f16Acc; + reinterpret_cast&>(f16Acc) = + convert(reinterpret_cast const&>(acc)); + static_assert(Gemm0Acc::size == 1 || Gemm0Acc::size % 2 == 0); + uint32_t const idxHalf = lane / 16; + uint32_t const idxInHalf = lane % 16; + uint32_t const idxOctInsideHalf = idxInHalf / 8; + uint32_t const idxRowInsideOct = lane % 8; + uint32_t const warpBaseC = 16 * warpRank; + auto const toAccCoords = [](uint32_t const idxAccCoreMat) -> std::pair { + uint32_t const accR = idxAccCoreMat / Gemm0Acc::cols; + uint32_t const accC = idxAccCoreMat % Gemm0Acc::cols; + return {accR, accC}; + }; + auto const getDstAddr = [&](uint32_t idxAccCoreMat) -> LdGrain* { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + static_assert(sizeof(MathElem) * gemm0CtaTileNbTokens == xPartBytes); + uint32_t const idxPart = 0; + uint32_t const dstR = accC * 8 + idxRowInsideOct; + uint32_t const dstC = + exactDiv(gmma::instM * accR + warpBaseC + 8 * idxOctInsideHalf, cacheElemsPerGrain); + assert(dstC / exactDiv(xPartBytes, grainBytes) == idxPart); + return &smemX[idxPart].template at(dstR, dstC); + }; + auto const getAccData = [&](uint32_t idxAccCoreMat) { + auto const [accR, accC] = toAccCoords(idxAccCoreMat); + return f16Acc(accR, accC); + }; + + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t iter = 0; iter < Gemm0Acc::size / 2; iter++) { + auto const dstAddr = getDstAddr(iter * 2 + idxHalf); + Vec const data[2] = {getAccData(iter * 2), getAccData(iter * 2 + 1)}; + stmatrix(dstAddr, reinterpret_cast(data)); + } + if constexpr (Gemm0Acc::size % 2 != 0) { + auto const dstAddr = lane < 16 ? getDstAddr(Gemm0Acc::size - 1) : nullptr; + stmatrix(dstAddr, getAccData(Gemm0Acc::size - 1)); + } #elif CACHE_ELEM_ENUM == 2 - using F8Acc = Array2D; - F8Acc f8Acc; -#pragma unroll - for (uint32_t i = 0; i < acc.rows; i++) - { -#pragma unroll - for (uint32_t j = 0; j < acc.cols; j++) - { - auto const& core = acc(i, j); - static_assert(mha::is_same_v); - Vec const f8Data - = {__nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), - __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; - f8Acc(i, j) = reinterpret_cast(f8Data); - } - } - - if constexpr (F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) - { - LdGrain* dst = nullptr; - if (F8Acc::size == 4 || lane < 8 * F8Acc::size) - { - uint32_t const idxCore = lane / 8; - uint32_t const srcRow = idxCore / F8Acc::cols; - uint32_t const srcCol = idxCore % F8Acc::cols; - uint32_t const dstCoreRow = lane % 8; - uint32_t const dstRow = srcCol * 8 + dstCoreRow; - BoundedVal const dstCol{srcRow * 4 + warpRank}; - dst = &smemX[dstCol.template divBy().get()].template at( - dstRow, dstCol.template mod().get()); - } - barConsumed.arrive_and_wait(); - stmatrix(dst, reinterpret_cast const&>(f8Acc)); - } - else - { - // we need to use loops - assert(false); - trap(); + using F8Acc = Array2D; + F8Acc f8Acc; +#pragma unroll + for (uint32_t i = 0; i < acc.rows; i++) { +#pragma unroll + for (uint32_t j = 0; j < acc.cols; j++) { + auto const& core = acc(i, j); + static_assert(mha::is_same_v); + Vec const f8Data = { + __nv_cvt_float2_to_fp8x2(float2{core(0, 0), core(1, 0)}, __NV_SATFINITE, __NV_E4M3), + __nv_cvt_float2_to_fp8x2(float2{core(0, 1), core(1, 1)}, __NV_SATFINITE, __NV_E4M3)}; + f8Acc(i, j) = reinterpret_cast(f8Data); + } + } + + if constexpr 
(F8Acc::size == 4 || F8Acc::size == 2 || F8Acc::size == 1) { + LdGrain* dst = nullptr; + if (F8Acc::size == 4 || lane < 8 * F8Acc::size) { + uint32_t const idxCore = lane / 8; + uint32_t const srcRow = idxCore / F8Acc::cols; + uint32_t const srcCol = idxCore % F8Acc::cols; + uint32_t const dstCoreRow = lane % 8; + uint32_t const dstRow = srcCol * 8 + dstCoreRow; + BoundedVal const dstCol{ + srcRow * 4 + warpRank}; + dst = &smemX[dstCol.template divBy().get()].template at( + dstRow, dstCol.template mod().get()); } + barConsumed.arrive_and_wait(); + stmatrix(dst, reinterpret_cast const&>(f8Acc)); + } else { + // we need to use loops + assert(false); + trap(); + } #endif } #else -__device__ inline RegRowWiseVec warpRowWiseReduce( - RegRowWiseVec const& init, Gemm0Acc const& src, float (*op)(float, float)) -{ - RegRowWiseVec vec = init; +__device__ inline RegRowWiseVec warpRowWiseReduce(RegRowWiseVec const& init, Gemm0Acc const& src, + float (*op)(float, float)) { + RegRowWiseVec vec = init; #pragma unroll - for (uint32_t m = 0; m < src.rows; m++) - { + for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { #pragma unroll - for (uint32_t n = 0; n < src.cols; n++) - { + for (uint32_t n = 0; n < src.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - // @fixme: check if compiler is reordering these op to hide latency. - vec[m][i] = op(vec[m][i], src(m, n)(i, j)); - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + // @fixme: check if compiler is reordering these op to hide latency. + vec[m][i] = op(vec[m][i], src(m, n)(i, j)); } + } } + } #pragma unroll - for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) - { + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { #pragma unroll - for (uint32_t m = 0; m < src.rows; m++) - { + for (uint32_t m = 0; m < src.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - auto& x = vec[m][i]; - x = op(x, __shfl_xor_sync(~0U, x, xorMask)); - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + auto& x = vec[m][i]; + x = op(x, __shfl_xor_sync(~0U, x, xorMask)); + } } - return vec; + } + return vec; } -__device__ inline RegRowWiseVec computeWarpGrpRowMax_sync( - uint32_t warpRank, ShmQWiseVec& smemRowMax, Gemm0Acc const& src) -{ - assert(warpRank < 4); - RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); - RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); +__device__ inline RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, + ShmQWiseVec& smemRowMax, + Gemm0Acc const& src) { + assert(warpRank < 4); + RegRowWiseVec const init = loadShmRowWiseVecWithDup(warpRank, smemRowMax); + RegRowWiseVec rowMax = warpRowWiseReduce(init, src, fmax); - storeShmRowWiseVec(warpRank, smemRowMax, rowMax); - __syncwarp(); - return rowMax; + storeShmRowWiseVec(warpRank, smemRowMax, rowMax); + __syncwarp(); + return rowMax; } #if SPEC_DEC __device__ inline void warpGrpApplyMask(Gemm0Acc& acc, SpecDec const& specDec, #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - int32_t tok0WinBeg, -#endif - uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) -{ - constexpr uint32_t tileSize = gemm0CtaTileNbTokens; - auto const inputSeqLen = specDec.inputSeqLen; - auto const idxInputSubSeq = specDec.idxInputSubSeq; - constexpr uint64_t fullMask = ~uint64_t{0}; - static_assert(tileSize == sizeof(fullMask) * 8); + int32_t tok0WinBeg, +#endif + 
uint32_t cacheSeqLen, uint32_t idxTile, uint32_t warpRank) { + constexpr uint32_t tileSize = gemm0CtaTileNbTokens; + auto const inputSeqLen = specDec.inputSeqLen; + auto const idxInputSubSeq = specDec.idxInputSubSeq; + constexpr uint64_t fullMask = ~uint64_t{0}; + static_assert(tileSize == sizeof(fullMask) * 8); #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; - Range const tileRange = {tileSize * idxTile, tileSize * idxTile + tileSize}; - Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (inputTokensPerCta - 1)}; - bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; - assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); - int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(tileSize * idxTile); -#else - constexpr bool ctaNeedBegMask = false; - uint64_t const begMask = fullMask; - int32_t const tok0NbMaskOut = -2147483648; -#endif - uint32_t const offset = tileSize * idxTile; - uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? cacheSeqLen - offset : 0U, tileSize); - bool const ctaNeedEndMask = (nbValidCols < tileSize); - bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); - bool const needMask = ctaNeedBegMask || ctaNeedEndMask || ctaNeedSpecDecMask; - if (!needMask) - { - return; - } - static_assert(tileSize == 64, "not implemented"); - auto const endMask = fullMask >> (tileSize - nbValidCols); - - uint32_t const idxInQuad = laneId() % 4; - uint32_t const idxQuad = laneId() / 4; -#pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { -#pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; - uint32_t const idxQTokInCta = row / headGrpSize; - bool const isQTokValid - = (headGrpSize * inputTokensPerCta == ctaNbQHeads) || (idxQTokInCta < inputTokensPerCta); - auto const specDecMask = (isQTokValid && specDec.needMask(idxTile, idxQTokInCta)) - ? specDec.loadTileMaskRow(idxTile, idxQTokInCta) - : SpecDec::TileMaskRow{~0U, ~0U}; + uint32_t const ctaTokOffset = inputTokensPerCta * idxInputSubSeq; + Range const tileRange = {tileSize * idxTile, tileSize * idxTile + tileSize}; + Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (inputTokensPerCta - 1)}; + bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; + assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); + int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(tileSize * idxTile); +#else + constexpr bool ctaNeedBegMask = false; + uint64_t const begMask = fullMask; + int32_t const tok0NbMaskOut = -2147483648; +#endif + uint32_t const offset = tileSize * idxTile; + uint32_t const nbValidCols = mha::min(offset < cacheSeqLen ? 
cacheSeqLen - offset : 0U, tileSize); + bool const ctaNeedEndMask = (nbValidCols < tileSize); + bool const ctaNeedSpecDecMask = specDec.needMask(idxTile, 0); + bool const needMask = ctaNeedBegMask || ctaNeedEndMask || ctaNeedSpecDecMask; + if (!needMask) { + return; + } + static_assert(tileSize == 64, "not implemented"); + auto const endMask = fullMask >> (tileSize - nbValidCols); + + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint32_t const row = gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad; + uint32_t const idxQTokInCta = row / headGrpSize; + bool const isQTokValid = + (headGrpSize * inputTokensPerCta == ctaNbQHeads) || (idxQTokInCta < inputTokensPerCta); + auto const specDecMask = (isQTokValid && specDec.needMask(idxTile, idxQTokInCta)) + ? specDec.loadTileMaskRow(idxTile, idxQTokInCta) + : SpecDec::TileMaskRow{~0U, ~0U}; #if SLIDING_WINDOW && !IS_SPEC_DEC_TREE - int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); - uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask); + int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); + uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask); #else - uint64_t const begMask = fullMask; + uint64_t const begMask = fullMask; #endif - auto const mask = begMask & endMask & reinterpret_cast(specDecMask); - if (mask == ~uint64_t{0}) - { - continue; - } + auto const mask = begMask & endMask & reinterpret_cast(specDecMask); + if (mask == ~uint64_t{0}) { + continue; + } #if DBG_PRINT - if (idxInQuad == 0) - { - printf("mask at row %d: %lx\n", row, mask); - } + if (idxInQuad == 0) { + printf("mask at row %d: %lx\n", row, mask); + } #endif #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; - assert((col < nbValidCols) == bool(endMask & (1ULL << col))); - if ((mask & (1ULL << col)) == 0) - { - acc(m, n)(i, j) = safeInitRowMax; - } - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + assert((col < nbValidCols) == bool(endMask & (1ULL << col))); + if ((mask & (1ULL << col)) == 0) { + acc(m, n)(i, j) = safeInitRowMax; + } } + } } + } } #else -__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) -{ - uint32_t const idxInQuad = laneId() % 4; +__device__ inline void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd) { + uint32_t const idxInQuad = laneId() % 4; #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; - if (col >= validColBeg && col < validColEnd) - { - continue; - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + uint32_t const col = GmmaAccCoreMat::cols * (4 * n + idxInQuad) + j; + if (col >= validColBeg && col < validColEnd) { + continue; + } #pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { + for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - acc(m, n)(i, j) 
= safeInitRowMax; - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) = safeInitRowMax; } + } } + } } #endif -__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) -{ +__device__ inline void warpGrpOnlineSoftmax(Gemm0Acc& acc, RegRowWiseVec const& rowMax) { #pragma unroll - for (uint32_t m = 0; m < acc.rows; m++) - { + for (uint32_t m = 0; m < acc.rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - float const maxVal = rowMax[m][i]; - float const bias = maxVal * log2e; + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + float const maxVal = rowMax[m][i]; + float const bias = maxVal * log2e; #pragma unroll - for (uint32_t n = 0; n < acc.cols; n++) - { + for (uint32_t n = 0; n < acc.cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - float& elem = acc(m, n)(i, j); - assert(maxVal >= elem); - elem = exp2f(elem * log2e - bias); - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); } + } } + } } -__device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) -{ - return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); +__device__ inline RegRowWiseVec computeWarpRowSum(Gemm0Acc& src) { + return warpRowWiseReduce(RegRowWiseVec{}, src, [](float a, float b) { return a + b; }); } -__device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, ShmQWiseVec const& smemVec) -{ - RegRowWiseVec vec; - uint32_t const idxQuad = laneId() / 4; +__device__ inline RegRowWiseVec loadShmRowWiseVecWithDup(uint32_t warpRank, + ShmQWiseVec const& smemVec) { + RegRowWiseVec vec; + uint32_t const idxQuad = laneId() / 4; #pragma unroll - for (uint32_t m = 0; m < RegRowWiseVec::size; m++) - { + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { #pragma unroll - for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) - { - vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; - } + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + vec[m][i] = smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad]; } - return vec; + } + return vec; } -__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, RegRowWiseVec const& regVec) -{ - uint32_t const lane = laneId(); - uint32_t const idxQuad = lane / 4; - uint32_t const idxInQuad = lane % 4; - bool const enable = (idxInQuad == 0); -#pragma unroll - for (uint32_t m = 0; m < RegRowWiseVec::size; m++) - { -#pragma unroll - for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) - { - assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); - if (enable) - { - smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; - } - } +__device__ void storeShmRowWiseVec(uint32_t warpRank, ShmQWiseVec& smemVec, + RegRowWiseVec const& regVec) { + uint32_t const lane = laneId(); + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; + bool const enable = (idxInQuad == 0); +#pragma unroll + for (uint32_t m = 0; m < RegRowWiseVec::size; m++) { +#pragma unroll + for (uint32_t i = 0; i < RegRowWiseVec::Elem::size; i++) { + assert(__shfl_sync(~0U, regVec[m][i], idxQuad * 4) == regVec[m][i]); + if (enable) { + smemVec[gmma::instM * m + gmma::instM / 4 * warpRank + 8 * i + idxQuad] = regVec[m][i]; + } } + } } // for X // order: 0,8,1,9, 2,10,3,11, 4,12,5,13, 6,14,7,15, ... 
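+// (this interleaving falls out of storeGemm0AccToShm below: each __nv_fp8x2_e4m3 packs a pair of
+// accumulator elements whose logical columns are 8 apart, e.g. 0 with 8, 1 with 9, ...)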
-__device__ inline void storeGemm0AccToShm( - uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc) -{ - uint32_t const idxMat = lane / 8; - uint32_t const idxRow = lane % 8; - barConsumed.arrive_and_wait(); -#pragma unroll - for (uint32_t m = 0; m < Gemm0Acc::rows; m++) - { -#pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - Vec fp8Data; -#pragma unroll - for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) - { - reinterpret_cast&>(fp8Data[n]) - = {__nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), - __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; - } - static_assert(decltype(fp8Data)::size == 4); - stmatrix_4x( - this_warp(), &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), fp8Data); - } - } +__device__ inline void storeGemm0AccToShm(uint32_t warpRank, uint32_t lane, + SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, + Gemm0Acc const& acc) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; + barConsumed.arrive_and_wait(); +#pragma unroll + for (uint32_t m = 0; m < Gemm0Acc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + Vec fp8Data; +#pragma unroll + for (uint32_t n = 0; n < exactDiv(Gemm0Acc::cols, 2); n++) { + reinterpret_cast&>(fp8Data[n]) = { + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0)}), + __nv_fp8x2_e4m3(float2{acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)})}; + } + static_assert(decltype(fp8Data)::size == 4); + stmatrix_4x(this_warp(), + &smemX[m].template at(16 * warpRank + 8 * i + idxRow, idxMat), + fp8Data); + } + } } #endif #if SWAP_AB __device__ inline Vec loadVTileTransposed( - uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) -{ - Vec fragA; - constexpr uint32_t instK = gmma::instK; + uint32_t warpRank, uint32_t lane, SharedMem::VBuffer const& smemV, uint32_t idxGmmaInstK) { + Vec fragA; + constexpr uint32_t instK = gmma::instK; #pragma unroll - for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) - { - static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); - constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); + for (uint32_t i = 0; i < gemm1NbGmmaInstM; i++) { + static_assert(exactDiv(gmma::instM, gmmaWarpsPerGrp) == grainBytes); + constexpr uint32_t grainsPerPart = exactDiv(cacheHeadPartBytes, grainBytes); #if CACHE_ELEM_ENUM == 0 - uint32_t idxRow = lane % 8; - uint32_t idxMat = lane / 8; - uint32_t c = idxMat % 2; - uint32_t r = idxMat / 2; - auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{2 * (gmmaWarpsPerGrp * i + warpRank) + c}; - auto const src = &smemV[col.template divBy().get()].template at( - instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); - auto const data = ldmatrix(src); - fragA[i] = reinterpret_cast(data); + uint32_t idxRow = lane % 8; + uint32_t idxMat = lane / 8; + uint32_t c = idxMat % 2; + uint32_t r = idxMat / 2; + auto const col = BoundedVal<2 * gmmaWarpsPerGrp * gemm1NbGmmaInstM>{ + 2 * (gmmaWarpsPerGrp * i + warpRank) + c}; + auto const src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + 8 * r + idxRow, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i] = reinterpret_cast(data); #elif CACHE_ELEM_ENUM == 2 - auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; - LdGrain const* src = &smemV[col.template divBy().get()].template at( - instK * idxGmmaInstK + 
lane, col.template mod().get()); - auto const data = ldmatrix(src); - fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); - fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); - fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); - fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); -#endif - } - return fragA; + auto const col = BoundedVal{gmmaWarpsPerGrp * i + warpRank}; + LdGrain const* src = &smemV[col.template divBy().get()].template at( + instK * idxGmmaInstK + lane, col.template mod().get()); + auto const data = ldmatrix(src); + fragA[i](0, 0)(0, 0) = prmt(data[0], data[1], {0, 4, 2, 6}); + fragA[i](0, 0)(1, 0) = prmt(data[0], data[1], {1, 5, 3, 7}); + fragA[i](0, 1)(0, 0) = prmt(data[2], data[3], {0, 4, 2, 6}); + fragA[i](0, 1)(1, 0) = prmt(data[2], data[3], {1, 5, 3, 7}); +#endif + } + return fragA; } #else -__device__ inline void transposeVTile( - uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src) -{ - uint32_t const idxMat = lane / 8; - uint32_t const idxRow = lane % 8; -#pragma unroll - for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) - { - static_assert(cacheHeadPartElems >= gmma::instM); - uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; - constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); -#pragma unroll - for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) - { - LdGrain const a = ldmatrix_4x(this_warp(), - &src[idxPart].template at(32 * n + lane, - exactDiv(gmma::instM, cacheElemsPerGrain) * m - grainsPerCacheHeadPart * idxPart + warpRank)); - LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), - prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; - uint32_t const i = idxMat % 2; - uint32_t const j = idxMat / 2; - stmatrix_4x( - this_warp(), &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); - } - } +__device__ inline void transposeVTile(uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, + SharedMem::VBuffer const& src) { + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::VTBuffer::rows, gmma::instM); m++) { + static_assert(cacheHeadPartElems >= gmma::instM); + uint32_t const idxPart = gmma::instM * m / cacheHeadPartElems; + constexpr uint32_t grainsPerCacheHeadPart = exactDiv(cacheHeadPartElems, cacheElemsPerGrain); +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::VTBuffer::cols, 2); n++) { + LdGrain const a = ldmatrix_4x( + this_warp(), &src[idxPart].template at( + 32 * n + lane, exactDiv(gmma::instM, cacheElemsPerGrain) * m - + grainsPerCacheHeadPart * idxPart + warpRank)); + LdGrain const b = {prmt(a[0], a[1], {0, 4, 2, 6}), prmt(a[0], a[1], {1, 5, 3, 7}), + prmt(a[2], a[3], {0, 4, 2, 6}), prmt(a[2], a[3], {1, 5, 3, 7})}; + uint32_t const i = idxMat % 2; + uint32_t const j = idxMat / 2; + stmatrix_4x( + this_warp(), + &dst.template at(gmma::instM * m + 16 * warpRank + 8 * i + idxRow, 2 * n + j), b); + } + } } #endif #if SWAP_AB -__device__ inline Vec loadShmColWiseVecNoDup(ShmQWiseVec const& shmVec) -{ - Vec ret; -#pragma unroll - for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) - { - uint32_t const idx = i * warp_size + laneId(); - bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); - ret[i] = (inBound ? 
shmVec[idx] : 0); - } - return ret; +__device__ inline Vec loadShmColWiseVecNoDup( + ShmQWiseVec const& shmVec) { + Vec ret; +#pragma unroll + for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? shmVec[idx] : 0); + } + return ret; } __device__ inline void storeShmColWiseVecNoDup( - ShmQWiseVec& shmVec, Vec const& src) -{ + ShmQWiseVec& shmVec, Vec const& src) { #pragma unroll - for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) - { - uint32_t const idx = i * warp_size + laneId(); - bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); - if (inBound) - { - shmVec[idx] = src[i]; - } + for (uint32_t i = 0; i < divUp(ShmQWiseVec::size, warp_size); i++) { + uint32_t const idx = i * warp_size + laneId(); + bool const inBound = ((ShmQWiseVec::size % warp_size == 0) || (idx < ShmQWiseVec::size)); + if (inBound) { + shmVec[idx] = src[i]; } + } } #else -__device__ inline Vec -loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) -{ - constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); - Vec ret; - uint32_t const lane = laneId(); - uint32_t const idxHalf = lane / (gmma::instM / 4); - uint32_t const idxInHalf = lane % (gmma::instM / 4); -#pragma unroll - for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) - { - uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; - bool const inBound - = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); - ret[i] = (inBound ? shmVec[idx] : 0); - } - return ret; +__device__ inline Vec +loadShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec const& shmVec) { + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { + uint32_t const idx = + gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || + (idx < ShmQWiseVec::size)); + ret[i] = (inBound ? 
shmVec[idx] : 0); + } + return ret; } -__device__ inline void storeShmRowWiseVecNoDup(uint32_t warpRank, ShmQWiseVec& shmVec, - Vec const& src) -{ - constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); - Vec ret; - uint32_t const lane = laneId(); - uint32_t const idxHalf = lane / (gmma::instM / 4); - uint32_t const idxInHalf = lane % (gmma::instM / 4); -#pragma unroll - for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) - { - uint32_t const idx = gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; - bool const inBound - = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || (idx < ShmQWiseVec::size)); - if (inBound) - { - shmVec[idx] = src[i]; - } - } +__device__ inline void storeShmRowWiseVecNoDup( + uint32_t warpRank, ShmQWiseVec& shmVec, + Vec const& src) { + constexpr uint32_t const nbElems = exactDiv(ShmQWiseVec::size, gmma::instM) * (gmma::instM / 4); + Vec ret; + uint32_t const lane = laneId(); + uint32_t const idxHalf = lane / (gmma::instM / 4); + uint32_t const idxInHalf = lane % (gmma::instM / 4); +#pragma unroll + for (uint32_t i = 0; i < divUp(nbElems, warp_size); i++) { + uint32_t const idx = + gmma::instM * 2 * i + gmma::instM * idxHalf + (gmma::instM / 4) * warpRank + idxInHalf; + bool const inBound = ((nbElems % warp_size == 0) || (i + 1 < divUp(nbElems, warp_size)) || + (idx < ShmQWiseVec::size)); + if (inBound) { + shmVec[idx] = src[i]; + } + } } #endif #if SWAP_AB -__device__ inline void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXColMax, - ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, - CtaBarrier& gemm1WarpGrpBar) -{ - auto accColSum = loadShmColWiseVecNoDup(shmAccColSum); - - auto const xColMax = loadShmColWiseVecNoDup(shmXColMax); - auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax); - auto token = gemm1WarpGrpBar.arrive(); - auto const needRescaleVec = (accColMax < xColMax); - UniformNeedRescaleMask rescaleMask; - bool anyNeedRescale = false; -#pragma unroll - for (uint32_t i = 0; i < rescaleMask.size; i++) - { - assert(accColMax[i] <= xColMax[i]); - rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); - anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); - } - if (anyNeedRescale) - { - auto const scaleVec = expf(accColMax - xColMax); - auto const lane = laneId(); +__device__ inline void rescaleGemm1AccForNewColMax_sync( + uint32_t warpRank, ShmQWiseVec const& shmXColMax, ShmQWiseVec const (&shmXColSum)[gemm0NbWarps], + ShmQWiseVec& shmAccColMax, Gemm1Acc& acc, ShmQWiseVec& shmAccColSum, + CtaBarrier& gemm1WarpGrpBar) { + auto accColSum = loadShmColWiseVecNoDup(shmAccColSum); + + auto const xColMax = loadShmColWiseVecNoDup(shmXColMax); + auto const accColMax = loadShmColWiseVecNoDup(shmAccColMax); + auto token = gemm1WarpGrpBar.arrive(); + auto const needRescaleVec = (accColMax < xColMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + assert(accColMax[i] <= xColMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + if (anyNeedRescale) { + auto const scaleVec = expf(accColMax - xColMax); + auto const lane = laneId(); #pragma unroll - for (uint32_t n = 0; n < Gemm1Acc::cols; n++) - { - uint32_t const vecIdx = gmma::instNBase * n / warp_size; - uint32_t const offset = gmma::instNBase 
* n % warp_size;
-            constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
-#pragma unroll
-            for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++)
-            {
-                auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U);
-                auto getScale = [&] {
-                    return __shfl_sync(
-                        ~0U, scaleVec[vecIdx], offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j);
-                };
-                assert((getScale() != 1) == ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U));
-                bool const needRescale = (mask != 0);
-                if (!needRescale)
-                { // this branch is warp-uniform
-                    continue;
-                }
-                float const scale = getScale();
+    for (uint32_t n = 0; n < Gemm1Acc::cols; n++) {
+      uint32_t const vecIdx = gmma::instNBase * n / warp_size;
+      uint32_t const offset = gmma::instNBase * n % warp_size;
+      constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols);
#pragma unroll
-            for (uint32_t m = 0; m < Gemm1Acc::rows; m++)
-            {
+      for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) {
+        auto const mask = ((rescaleMask[vecIdx] >> (offset + j)) & 0b01010101U);
+        auto getScale = [&] {
+          return __shfl_sync(~0U, scaleVec[vecIdx],
+                             offset + lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols + j);
+        };
+        assert((getScale() != 1) ==
+               ((mask >> lane % nbThrdsPerInstNBase * GmmaAccCoreMat::cols) & 0x1U));
+        bool const needRescale = (mask != 0);
+        if (!needRescale) {  // this branch is warp-uniform
+          continue;
+        }
+        float const scale = getScale();
#pragma unroll
-            for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++)
-            {
-                acc(m, n)(i, j) *= scale;
-            }
-        }
-    }
+        for (uint32_t m = 0; m < Gemm1Acc::rows; m++) {
+#pragma unroll
+          for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) {
+            acc(m, n)(i, j) *= scale;
+          }
+        }
-        accColSum = accColSum * scaleVec;
+      }
+    }
-    gemm1WarpGrpBar.wait(mha::move(token));
+    accColSum = accColSum * scaleVec;
+  }
+  gemm1WarpGrpBar.wait(mha::move(token));

-    // @fixme: with atomic, we can let the first warp reaching here to do the update, instead of always warp 3.
-    uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1;
-    if (warpRank == warpRankForUpdate)
-    {
-        if (anyNeedRescale)
-        {
-            storeShmColWiseVecNoDup(shmAccColMax, xColMax);
-        }
+  // @fixme: with atomic, we can let the first warp reaching here do the update, instead of
+  // always warp 3. 
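+  // For now, the last gmma warp alone publishes the updated col max and folds the per-warp
+  // partial col sums from gemm0 into shmAccColSum; the whole warp group then re-syncs below.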
+ uint32_t const warpRankForUpdate = gmmaWarpsPerGrp - 1; + if (warpRank == warpRankForUpdate) { + if (anyNeedRescale) { + storeShmColWiseVecNoDup(shmAccColMax, xColMax); + } #pragma unroll - for (uint32_t i = 0; i < gemm0NbWarps; i++) - { - accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); - } - storeShmColWiseVecNoDup(shmAccColSum, accColSum); + for (uint32_t i = 0; i < gemm0NbWarps; i++) { + accColSum = accColSum + loadShmColWiseVecNoDup(shmXColSum[i]); } - gemm1WarpGrpBar.arrive_and_wait(); + storeShmColWiseVecNoDup(shmAccColSum, accColSum); + } + gemm1WarpGrpBar.arrive_and_wait(); } #else -__device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, ShmQWiseVec const& shmXRowMax, - ShmQWiseVec const& shmXRowSum, ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, ShmQWiseVec& shmAccRowSum) -{ - auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); - auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); - auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); - assert(all(xRowMax >= accRowMax)); - auto const needRescaleVec = (accRowMax < xRowMax); - UniformNeedRescaleMask rescaleMask; - bool anyNeedRescale = false; -#pragma unroll - for (uint32_t i = 0; i < rescaleMask.size; i++) - { - assert(accRowMax[i] <= xRowMax[i]); - rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); - anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); - } - - if (anyNeedRescale) - { - auto const scaleVec = expf(accRowMax - xRowMax); - auto const lane = laneId(); +__device__ inline void rescaleGemm1AccForNewRowMax_sync(uint32_t warpRank, + ShmQWiseVec const& shmXRowMax, + ShmQWiseVec const& shmXRowSum, + ShmQWiseVec& shmAccRowMax, Gemm1Acc& acc, + ShmQWiseVec& shmAccRowSum) { + auto accRowSum = loadShmRowWiseVecNoDup(warpRank, shmAccRowSum); + auto const xRowMax = loadShmRowWiseVecNoDup(warpRank, shmXRowMax); + auto const accRowMax = loadShmRowWiseVecNoDup(warpRank, shmAccRowMax); + assert(all(xRowMax >= accRowMax)); + auto const needRescaleVec = (accRowMax < xRowMax); + UniformNeedRescaleMask rescaleMask; + bool anyNeedRescale = false; +#pragma unroll + for (uint32_t i = 0; i < rescaleMask.size; i++) { + assert(accRowMax[i] <= xRowMax[i]); + rescaleMask[i] = __ballot_sync(~0U, needRescaleVec[i]); + anyNeedRescale = anyNeedRescale || (rescaleMask[i] != 0); + } + + if (anyNeedRescale) { + auto const scaleVec = expf(accRowMax - xRowMax); + auto const lane = laneId(); #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; - bool const needRescale = (mask != 0); - if (needRescale) - { // this branch is warp-uniform - float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + uint8_t const mask = reinterpret_cast(rescaleMask[m / 2])[m % 2][i]; + bool const needRescale = (mask != 0); + if (needRescale) { // this branch is warp-uniform + float const scale = __shfl_sync(~0U, scaleVec[m / 2], 16 * (m % 2) + 8 * i + lane / 4); #pragma unroll - for (uint32_t n = 0; n < Gemm1Acc::cols; n++) - { + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - acc(m, n)(i, j) *= scale; - } - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale; } + } } - accRowSum = accRowSum * 
scaleVec; + } } - __syncwarp(); - auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); - storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); - storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); - __syncwarp(); + accRowSum = accRowSum * scaleVec; + } + __syncwarp(); + auto const xRowSum = loadShmRowWiseVecNoDup(warpRank, shmXRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowSum, accRowSum + xRowSum); + storeShmRowWiseVecNoDup(warpRank, shmAccRowMax, xRowMax); + __syncwarp(); } #endif #if SWAP_AB -__device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) -{ +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegColWiseVec const& scale) { #pragma unroll - for (uint32_t n = 0; n < Gemm1Acc::cols; n++) - { + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - acc(m, n)(i, j) *= scale[n][j]; - } - } + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + acc(m, n)(i, j) *= scale[n][j]; } + } } + } } #else -__device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) -{ +__device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) { #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { #pragma unroll - for (uint32_t n = 0; n < Gemm1Acc::cols; n++) - { + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { #pragma unroll - for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) - { - acc(m, n)(i, j) *= scale[m][i]; - } - } + for (uint32_t j = 0; j < GmmaAccCoreMat::cols; j++) { + acc(m, n)(i, j) *= scale[m][i]; } + } } + } } #endif @@ -2848,381 +2644,357 @@ __device__ inline void rescaleAcc(Gemm1Acc& acc, RegRowWiseVec const& scale) // @fixme: consider make this noinline template __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRank, DstHead* dst, - SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc const& acc, CtaBarrier& warpGrpBar, uint32_t nbKHeads) -{ - uint32_t const lane = laneId(); + SharedMem::OutSwizzleBuf& swizzleBuf, + Gemm1Acc const& acc, CtaBarrier& warpGrpBar, + uint32_t nbKHeads) { + uint32_t const lane = laneId(); #if CACHE_ELEM_ENUM == 0 - uint32_t const idxMat = lane / 8; - uint32_t const idxRow = lane % 8; + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; #elif CACHE_ELEM_ENUM == 2 - uint32_t const idxQuad = lane / 4; - uint32_t const idxInQuad = lane % 4; + uint32_t const idxQuad = lane / 4; + uint32_t const idxInQuad = lane % 4; #endif #pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { + for (uint32_t m = 0; m < Gemm1Acc::rows; m++) { #pragma unroll - for (uint32_t n = 0; n < Gemm1Acc::cols; n++) - { - auto const& core = acc(m, n); + for (uint32_t n = 0; n < Gemm1Acc::cols; n++) { + auto const& core = acc(m, n); #if CACHE_ELEM_ENUM == 0 - Vec f16Core; - reinterpret_cast&>(f16Core) - = convert(reinterpret_cast const&>(core)); - auto const dst = idxMat < 2 - ? 
&swizzleBuf.template at(8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) - : nullptr; - stmatrix(dst, f16Core); + Vec f16Core; + reinterpret_cast&>(f16Core) = + convert(reinterpret_cast const&>(core)); + auto const dst = idxMat < 2 + ? &swizzleBuf.template at( + 8 * n + idxRow, 2 * (gmmaWarpsPerGrp * m + warpRank) + idxMat) + : nullptr; + stmatrix(dst, f16Core); #elif CACHE_ELEM_ENUM == 2 - // each row is part of a b16 8x8 matrix and is transposed - Array2D coreTrans; - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); - InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); - auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); - reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; - } - // expect compiler to generate two PRMT instructions - Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), coreTrans(1, 1)}; - swizzleBuf.template at(gmma::instNBase * n + idxQuad, - (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] - = data; -#endif - } - } - warpGrpBar.arrive_and_wait(); - - constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); - constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); - constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; - constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); - uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; - uint32_t const idxGrain = threadRank % nbGrainsPerHead; -#pragma unroll - for (uint32_t iter = 0; iter < nbIters; iter++) - { - uint32_t const idxHead = idxHeadBase + iter * headsPerIter; - if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && (!isHeadPadded || idxGrain < grainsPerIOHead)) - { + // each row is part of a b16 8x8 matrix and is transposed + Array2D coreTrans; + for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) { + static_assert(GmmaAccCoreMat::cols == 2 && sizeof(InputElem) == 2); + InputElem2 const coreRow = float2ToInputElem2({core(i, 0), core(i, 1)}); + auto const coreRowTrans = movmatrix(reinterpret_cast(coreRow)); + reinterpret_cast(coreTrans(i, 0)) = coreRowTrans; + } + // expect compiler to generate two PRMT instructions + Vec const data = {coreTrans(0, 0), coreTrans(1, 0), coreTrans(0, 1), + coreTrans(1, 1)}; + swizzleBuf.template at( + gmma::instNBase * n + idxQuad, + (gmma::instM * m + exactDiv(gmma::instM, gmmaWarpsPerGrp) * warpRank) / 16)[idxInQuad] = + data; +#endif + } + } + warpGrpBar.arrive_and_wait(); + + constexpr uint32_t headsPerIter = exactDiv(grainBytes * gemm1NbThrds, paddedInputHeadBytes); + constexpr uint32_t nbIters = divUp(ctaNbValidQHeads, headsPerIter); + constexpr uint32_t nbWholeIters = ctaNbValidQHeads / headsPerIter; + constexpr uint32_t nbGrainsPerHead = exactDiv(paddedInputHeadBytes, grainBytes); + uint32_t const idxHeadBase = threadRank / nbGrainsPerHead; + uint32_t const idxGrain = threadRank % nbGrainsPerHead; +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t const idxHead = idxHeadBase + iter * headsPerIter; + if ((iter < nbWholeIters || idxHead < ctaNbValidQHeads) && + (!isHeadPadded || idxGrain < grainsPerIOHead)) { #if CACHE_ELEM_ENUM == 0 - auto const data = swizzleBuf.template at(idxHead, idxGrain); + auto const data = swizzleBuf.template at(idxHead, idxGrain); #elif CACHE_ELEM_ENUM == 2 - auto const data - = reinterpret_cast&>(swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 
2];
+      auto const data = reinterpret_cast&>(
+          swizzleBuf.template at(idxHead, idxGrain / 2))[idxGrain % 2];
#endif
-            constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize);
-            auto const outVec
-                = convert(reinterpret_cast const&>(data));
-            uint32_t dstHeadIdx = idxHead;
+      constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize);
+      auto const outVec = convert(
+          reinterpret_cast const&>(data));
+      uint32_t dstHeadIdx = idxHead;
#ifdef SPEC_Q_SEQ_LEN
-            if constexpr (dstIsStrided)
-            {
-                uint32_t const idxToken = idxHead / headGrpSize;
-                if (idxToken < SPEC_Q_SEQ_LEN)
-                {
-                    uint32_t const strideBetweenTokens = nbKHeads * headGrpSize;
-                    dstHeadIdx = idxToken * strideBetweenTokens + (idxHead % headGrpSize);
-                }
-            }
-#endif
-            reinterpret_cast, nbGrainsPerHead>&>(dst[dstHeadIdx])[idxGrain] = outVec;
+      if constexpr (dstIsStrided) {
+        uint32_t const idxToken = idxHead / headGrpSize;
+        if (idxToken < SPEC_Q_SEQ_LEN) {
+          uint32_t const strideBetweenTokens = nbKHeads * headGrpSize;
+          dstHeadIdx = idxToken * strideBetweenTokens + (idxHead % headGrpSize);
+        }
+      }
+#endif
+      reinterpret_cast, nbGrainsPerHead>&>(
+          dst[dstHeadIdx])[idxGrain] = outVec;
    }
  }
}

template
-__device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst,
-    SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar,
-    ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads)
-{
-    // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp
-    // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of
-    // mufu.rcp");
-    auto regColSum = loadShmColWiseVecWithDup(accColSum);
-    if (attentionSinksVec != nullptr)
-    {
-        auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
-        auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
-        auto regColSinks = expf(regAttentionSinks - regAccColMax);
-        regColSum = regColSum + regColSinks;
-    }
-    auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
-    rescaleAcc(acc, regOutScale);
-
-    saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc, warpGrpBar, nbKHeads);
-    warpGrpBar.arrive_and_wait();
+__device__ inline void finalizeAndWriteOut_sync(
+    uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf,
+    Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, ShmQWiseVec const& accColSum,
+    ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) {
+  // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste
+  // of mufu.rcp
+  // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl
+  //               to avoid 8x waste of mufu.rcp");
+  auto regColSum = loadShmColWiseVecWithDup(accColSum);
+  if (attentionSinksVec != nullptr) {
+    auto const regAccColMax = loadShmColWiseVecWithDup(accColMax);
+    auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1);
+    auto regColSinks = expf(regAttentionSinks - regAccColMax);
+    regColSum = regColSum + regColSinks;
+  }
+  auto const regOutScale = __frcp_rn(regColSum) * xvoScale;
+  rescaleAcc(acc, regOutScale);
+
+  saveTransposedOutput(threadRank, warpRank, dst, swizzleBuf, acc,
+                                     warpGrpBar, nbKHeads);
+  warpGrpBar.arrive_and_wait();
}
#else
template
-__device__ inline void 
finalizeAndWriteOut_sync(uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, - Gemm1Acc& acc, float xvoScale, ShmQWiseVec const& accRowSum, - uint32_t nbKHeads /* for spec dec. set to 1 for workspace*/, uint32_t ctaNbValidTokens) -{ - auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum); - auto const regOutScale = __frcp_rn(regRowSum) * xvoScale; - rescaleAcc(acc, regOutScale); - - using DstElem = typename DstHead::Elem; - auto const lane = laneId(); - uint32_t const idxQuad = lane / 4; - uint32_t const idxInQuad = lane % 4; - using Atom = Vec, 4>; - using SwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>; - static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf)); - auto& buf = reinterpret_cast(swizzleBuf); -#pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { -#pragma unroll - for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) - { - uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad; - static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2)); -#pragma unroll - for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++) - { - Vec const v = convert(Vec{ - acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0), acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)}); - //@fixme: without reinterpret_cast to V, the compiler generates wrong code, and require a __syncwarp() - // after rescaleAcc() to work around. Likely a bug of the compiler. - //@todo: report a compiler bug. - using V = Vec; - reinterpret_cast(buf.template at(r, n)[idxInQuad]) = reinterpret_cast(v); - // buf.template at(r, n)[idxInQuad] = v; - } - } - } - __syncwarp(); - -#pragma unroll - for (uint32_t m = 0; m < Gemm1Acc::rows; m++) - { - constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems; - constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes); - constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize); - uint32_t const idxGrp = lane / grpSize; - constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes); - uint32_t const rowBase = gmma::instM * m + 16 * warpRank; - constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16; - uint32_t const nbIters = divUp(totalNbGrains, nbGrps); - constexpr bool wholeIters = (totalNbGrains % nbGrps == 0); - constexpr bool wholeHeads = (validElemsPerHead == headElems); -#pragma unroll - for (uint32_t iter = 0; iter < nbIters; iter++) - { - uint32_t const idxGrain = nbGrps * iter + idxGrp; - constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes); - uint32_t const r = idxGrain / grainsPerSrcHead; - if (!wholeIters && r >= 16) - { - break; - } - uint32_t const cGrain = idxGrain % grainsPerSrcHead; - uint32_t const cAtom = cGrain / grainsPerAtom; - constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes); - uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r; - if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) - { - break; - } - if (wholeHeads || cGrain < grainsPerDstHead) - { - uint32_t const srcRow = rowBase + r; - auto const data = reinterpret_cast( - buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom]; +__device__ inline void finalizeAndWriteOut_sync( + uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, + float xvoScale, ShmQWiseVec const& accRowSum, + uint32_t nbKHeads /* for spec dec. 
set to 1 for workspace*/, uint32_t ctaNbValidTokens) {
+  auto const regRowSum = loadShmRowWiseVecWithDup(warpRank, accRowSum);
+  auto const regOutScale = __frcp_rn(regRowSum) * xvoScale;
+  rescaleAcc(acc, regOutScale);
+
+  using DstElem = typename DstHead::Elem;
+  auto const lane = laneId();
+  uint32_t const idxQuad = lane / 4;
+  uint32_t const idxInQuad = lane % 4;
+  using Atom = Vec, 4>;
+  using SwizzleBuf = Array2D, 4>, ctaNbQHeads, exactDiv(headElems, 4 * 4)>;
+  static_assert(sizeof(SwizzleBuf) <= sizeof(swizzleBuf));
+  auto& buf = reinterpret_cast(swizzleBuf);
+#pragma unroll
+  for (uint32_t m = 0; m < Gemm1Acc::rows; m++) {
+#pragma unroll
+    for (uint32_t i = 0; i < GmmaAccCoreMat::rows; i++) {
+      uint32_t const r = gmma::instM * m + 16 * warpRank + 8 * i + idxQuad;
+      static_assert(SwizzleBuf::cols == exactDiv(Gemm1Acc::cols, 2));
+#pragma unroll
+      for (uint32_t n = 0; n < exactDiv(Gemm1Acc::cols, 2); n++) {
+        Vec const v =
+            convert(Vec{acc(m, n * 2)(i, 0), acc(m, n * 2 + 1)(i, 0),
+                                        acc(m, n * 2)(i, 1), acc(m, n * 2 + 1)(i, 1)});
+        //@fixme: without reinterpret_cast to V, the compiler generates wrong code and requires a
+        //__syncwarp() after rescaleAcc() to work around it. Likely a compiler bug.
+        //@todo: report a compiler bug.
+        using V = Vec;
+        reinterpret_cast(buf.template at(r, n)[idxInQuad]) =
+            reinterpret_cast(v);
+        // buf.template at(r, n)[idxInQuad] = v;
+      }
+    }
+  }
+  __syncwarp();
+
+#pragma unroll
+  for (uint32_t m = 0; m < Gemm1Acc::rows; m++) {
+    constexpr uint32_t srcHeadBytes = sizeof(DstElem) * headElems;
+    constexpr uint32_t grpSize = exactDiv(srcHeadBytes, grainBytes);
+    constexpr uint32_t nbGrps = exactDiv(warp_size, grpSize);
+    uint32_t const idxGrp = lane / grpSize;
+    constexpr uint32_t grainsPerAtom = exactDiv(sizeof(Atom), grainBytes);
+    uint32_t const rowBase = gmma::instM * m + 16 * warpRank;
+    constexpr uint32_t totalNbGrains = grainsPerAtom * SwizzleBuf::cols * 16;
+    uint32_t const nbIters = divUp(totalNbGrains, nbGrps);
+    constexpr bool wholeIters = (totalNbGrains % nbGrps == 0);
+    constexpr bool wholeHeads = (validElemsPerHead == headElems);
+#pragma unroll
+    for (uint32_t iter = 0; iter < nbIters; iter++) {
+      uint32_t const idxGrain = nbGrps * iter + idxGrp;
+      constexpr uint32_t grainsPerSrcHead = exactDiv(srcHeadBytes, grainBytes);
+      uint32_t const r = idxGrain / grainsPerSrcHead;
+      if (!wholeIters && r >= 16) {
+        break;
+      }
+      uint32_t const cGrain = idxGrain % grainsPerSrcHead;
+      uint32_t const cAtom = cGrain / grainsPerAtom;
+      constexpr uint32_t grainsPerDstHead = exactDiv(sizeof(DstHead), grainBytes);
+      uint32_t const glbRow = gmma::instM * m + 16 * warpRank + r;
+      if (ctaNbValidQHeads != ctaNbQHeads && glbRow >= ctaNbValidQHeads) {
+        break;
+      }
+      if (wholeHeads || cGrain < grainsPerDstHead) {
+        uint32_t const srcRow = rowBase + r;
+        auto const data = reinterpret_cast(
+            buf.template at(srcRow, cAtom))[cGrain % grainsPerAtom];
#if SPEC_DEC
-            static_assert(beamWidth == 1);
-            uint32_t const idxToken = srcRow / headGrpSize; // inside CTA
-            if (idxToken >= ctaNbValidTokens)
-            {
-                break;
-            }
-            uint32_t const tokenPad = headGrpSize * (nbKHeads - 1);
-            uint32_t const dstRow = srcRow + idxToken * tokenPad;
+        static_assert(beamWidth == 1);
+        uint32_t const idxToken = srcRow / headGrpSize;  // inside CTA
+        if (idxToken >= ctaNbValidTokens) {
+          break;
+        }
+        uint32_t const tokenPad = headGrpSize * (nbKHeads - 1);
+        uint32_t const dstRow = srcRow + idxToken * tokenPad;
#else
-            uint32_t const dstRow = srcRow;
+        uint32_t const dstRow = srcRow;
#endif
-            
reinterpret_cast(dst[dstRow])[cGrain] = data; - } - } + reinterpret_cast(dst[dstRow])[cGrain] = data; + } } + } } #endif template __device__ inline Vec, ropeNbPairsPerThrd> loadHead( - Vec const& head, uint32_t tid) -{ - constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); - constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; - constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); - bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); - static_assert(nbPairs % nbPairsPerThrd == 0); - Vec, nbPairsPerThrd> ret; - if constexpr (forNeox) - { - auto const& pairs = reinterpret_cast, nbWorkingThrds>, 2> const&>(head); - auto const data = isWorkingThrd ? Vec, 2>{pairs[0][tid], pairs[1][tid]} - : Vec, 2>{}; - Vec, 2> const tmp = {convert(data[0]), convert(data[1])}; -#pragma unroll - for (uint32_t i = 0; i < nbPairsPerThrd; i++) - { - ret[i][0] = tmp[0][i]; - ret[i][1] = tmp[1][i]; - } - } - else - { - auto const data = isWorkingThrd ? reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] - : Vec, nbPairsPerThrd>{}; -#pragma unroll - for (uint32_t i = 0; i < nbPairsPerThrd; i++) - { - ret[i] = convert(data[i]); - } - } - return ret; + Vec const& head, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + Vec, nbPairsPerThrd> ret; + if constexpr (forNeox) { + auto const& pairs = + reinterpret_cast, nbWorkingThrds>, 2> const&>(head); + auto const data = isWorkingThrd + ? Vec, 2>{pairs[0][tid], pairs[1][tid]} + : Vec, 2>{}; + Vec, 2> const tmp = {convert(data[0]), + convert(data[1])}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i][0] = tmp[0][i]; + ret[i][1] = tmp[1][i]; + } + } else { + auto const data = + isWorkingThrd ? 
reinterpret_cast, nbPairsPerThrd> const*>(&head)[tid] + : Vec, nbPairsPerThrd>{}; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(data[i]); + } + } + return ret; } template __device__ inline mha::conditional_t, 2>, - Vec, nbPairsPerThrd>> -applyRoPE(Vec, nbPairsPerThrd> const& data, Vec, nbPairsPerThrd> const& ropeCosSin) -{ - Vec, nbPairsPerThrd> r; -#pragma unroll - for (uint32_t i = 0; i < nbPairsPerThrd; i++) - { - float const x = data[i][0]; - float const y = data[i][1]; - float const c = ropeCosSin[i][0]; - float const s = ropeCosSin[i][1]; - r[i] = Vec{c * x - s * y, s * x + c * y}; - } - if constexpr (forNeox) - { - Vec, 2> tmp; -#pragma unroll - for (uint32_t i = 0; i < nbPairsPerThrd; i++) - { - tmp[0][i] = r[i][0]; - tmp[1][i] = r[i][1]; - } - return Vec, 2>{convert(tmp[0]), convert(tmp[1])}; - } - else - { - Vec, nbPairsPerThrd> ret; -#pragma unroll - for (uint32_t i = 0; i < nbPairsPerThrd; i++) - { - ret[i] = convert(r[i]); - } - return ret; + Vec, nbPairsPerThrd>> +applyRoPE(Vec, nbPairsPerThrd> const& data, + Vec, nbPairsPerThrd> const& ropeCosSin) { + Vec, nbPairsPerThrd> r; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + float const x = data[i][0]; + float const y = data[i][1]; + float const c = ropeCosSin[i][0]; + float const s = ropeCosSin[i][1]; + r[i] = Vec{c * x - s * y, s * x + c * y}; + } + if constexpr (forNeox) { + Vec, 2> tmp; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + tmp[0][i] = r[i][0]; + tmp[1][i] = r[i][1]; + } + return Vec, 2>{convert(tmp[0]), + convert(tmp[1])}; + } else { + Vec, nbPairsPerThrd> ret; +#pragma unroll + for (uint32_t i = 0; i < nbPairsPerThrd; i++) { + ret[i] = convert(r[i]); } + return ret; + } } template -__device__ inline void storeRotatedPairsForKV(GMemCacheHead& dst, +__device__ inline void storeRotatedPairsForKV( + GMemCacheHead& dst, mha::conditional_t>, 2>, - Vec, ropeNbPairsPerThrd>> const& src, - uint32_t tid) -{ - constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); - constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; - constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); - bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); - static_assert(nbPairs % nbPairsPerThrd == 0); - if (!isWorkingThrd) - { - return; - } - if constexpr (forNeox) - { - auto& pairs = reinterpret_cast, nbWorkingThrds>, 2>&>(dst); - pairs[0][tid] = src[0]; - pairs[1][tid] = src[1]; - } - else - { - reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; - } + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (!isWorkingThrd) { + return; + } + if constexpr (forNeox) { + auto& pairs = + reinterpret_cast, nbWorkingThrds>, 2>&>(dst); + pairs[0][tid] = src[0]; + pairs[1][tid] = src[1]; + } else { + reinterpret_cast, nbPairsPerThrd>*>(&dst)[tid] = src; + } } template -__device__ inline void storeRotatedPairsForQ(SharedMem::QBuffer& dst, +__device__ inline void storeRotatedPairsForQ( + SharedMem::QBuffer& dst, mha::conditional_t>, 2>, - Vec, ropeNbPairsPerThrd>> const& src, - uint32_t row, uint32_t tid) -{ - constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); - constexpr uint32_t 
nbPairsPerThrd = ropeNbPairsPerThrd; - constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); - bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); - static_assert(nbPairs % nbPairsPerThrd == 0); - if (isWorkingThrd) - { - if constexpr (forNeox) - { -#pragma unroll - for (uint32_t i = 0; i < 2; i++) - { - auto const byteOffset - = BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; - uint32_t const idxPart = byteOffset.template divBy().get(); - auto const byteOffsetInsidePart = byteOffset.template mod(); - uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); - LdGrain& grain = dst[idxPart].template at(row, idxGrain); - uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); - static_assert( - cacheElemSize * nbPairsPerThrd <= grainBytes && grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); - reinterpret_cast&>( - reinterpret_cast(&grain)[byteOffsetInsideGrain]) - = src[i]; - } - } - else - { - auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; - uint32_t const idxPart = byteOffset.template divBy().get(); - auto const byteOffsetInsidePart = byteOffset.template mod(); - uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); - LdGrain& grain = dst[idxPart].template at(row, idxGrain); - uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); - static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes - && grainBytes % (cacheElemSize * 2 * nbPairsPerThrd) == 0); - reinterpret_cast, nbPairsPerThrd>&>( - reinterpret_cast(&grain)[byteOffsetInsideGrain]) - = src; - } - } - static_assert(validElemsPerHead % 16 == 0); - __syncwarp(); - if constexpr (validElemsPerHead < headElems) - { - static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); - constexpr uint32_t nbPadGrainsPerHead = exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); - constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; - uint32_t const nbIters = divUp(nbPadGrains, nbThrds); -#pragma unroll - for (uint32_t iter = 0; iter < nbIters; iter++) - { - uint32_t idx = tid + nbThrds * iter; - if (idx >= nbPadGrains) - { - break; - } - uint32_t const r = idx / nbPadGrainsPerHead; - uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; - dst[dst.size - 1].template at(r, c) = LdGrain{}; - } - } + Vec, ropeNbPairsPerThrd>> const& src, + uint32_t row, uint32_t tid) { + constexpr uint32_t nbPairs = exactDiv(validElemsPerHead, 2); + constexpr uint32_t nbPairsPerThrd = ropeNbPairsPerThrd; + constexpr uint32_t nbWorkingThrds = exactDiv(nbPairs, nbPairsPerThrd); + bool const isWorkingThrd = (nbWorkingThrds == nbThrds || tid < nbWorkingThrds); + static_assert(nbPairs % nbPairsPerThrd == 0); + if (isWorkingThrd) { + if constexpr (forNeox) { +#pragma unroll + for (uint32_t i = 0; i < 2; i++) { + auto const byteOffset = + BoundedVal{cacheElemSize * nbPairsPerThrd * (nbWorkingThrds * i + tid)}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = + byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * nbPairsPerThrd) == 0); + reinterpret_cast&>( + 
reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src[i]; + } + } else { + auto const byteOffset = BoundedVal{cacheElemSize * 2 * nbPairsPerThrd * tid}; + uint32_t const idxPart = byteOffset.template divBy().get(); + auto const byteOffsetInsidePart = byteOffset.template mod(); + uint32_t const idxGrain = byteOffsetInsidePart.template divBy().get(); + LdGrain& grain = dst[idxPart].template at(row, idxGrain); + uint32_t const byteOffsetInsideGrain = byteOffsetInsidePart.template mod().get(); + static_assert(cacheElemSize * 2 * nbPairsPerThrd <= grainBytes && + grainBytes % (cacheElemSize * 2 * nbPairsPerThrd) == 0); + reinterpret_cast, nbPairsPerThrd>&>( + reinterpret_cast(&grain)[byteOffsetInsideGrain]) = src; + } + } + static_assert(validElemsPerHead % 16 == 0); + __syncwarp(); + if constexpr (validElemsPerHead < headElems) { + static_assert(validElemsPerHead >= headElems - exactDiv(headElems, nbQParts)); + constexpr uint32_t nbPadGrainsPerHead = + exactDiv(headElems - validElemsPerHead, cacheElemsPerGrain); + constexpr uint32_t nbPadGrains = nbPadGrainsPerHead * ctaNbQHeads; + uint32_t const nbIters = divUp(nbPadGrains, nbThrds); +#pragma unroll + for (uint32_t iter = 0; iter < nbIters; iter++) { + uint32_t idx = tid + nbThrds * iter; + if (idx >= nbPadGrains) { + break; + } + uint32_t const r = idx / nbPadGrainsPerHead; + uint32_t const c = grainsPerQPart - nbPadGrainsPerHead + idx % nbPadGrainsPerHead; + dst[dst.size - 1].template at(r, c) = LdGrain{}; + } + } } #ifndef GENERATE_CUBIN -void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, +void launchHopperF8MHA( + cudaDeviceProp const& prop, uint32_t nbKHeads, #if SLIDING_WINDOW uint32_t slidingWinSize, #endif @@ -3238,15 +3010,16 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif - float const* attentionSinks, // [headGrpSize] + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, #else - GMemCacheHead* pool, // global pool of pages + GMemCacheHead* pool, // global pool of pages #endif KVCachePageIndex const* - kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + kvCachePageList, // device pointer. shape: + // KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] #else GMemKVCacheHead* kvCacheData, #endif @@ -3255,250 +3028,244 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, BeamSearchParams const& beamSearchParams, #endif uint32_t batchSize, - float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for - // int8/fp8 KV cache. + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. + // Used only for int8/fp8 KV cache. 
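+                                             // (per the "device memory scalar" note, a
+                                             // one-element device buffer suffices; the value is
+                                             // unused for fp16/bf16 caches)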
#if SPEC_DEC SpecDecParams const& specDecParams, #endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) -{ - if (beamWidth != 1) - { - throw std::runtime_error("not implemented"); - } - static uint32_t const hostSmemSize = [&]() - { - uint32_t size; - checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); - checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); - return size; - }(); - // printf("smemSize = %u\n", hostSmemSize); - uint32_t const nbVHeads = nbKHeads; - uint32_t const nbQHeads = nbKHeads * headGrpSize; - uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; - uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t - { - auto const env = std::getenv("XQA_NB_SUB_SEQ"); - if (env != nullptr) - { - int32_t const val = std::stoi(env); - if (val > 0) - { - return val; - } - } - float const factor = 0.25f; - return mha::min( - mha::max(1U, (uint32_t) round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), - divUp(maxSeqLen, gemm0CtaTileNbTokens)); - }(); + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + if (beamWidth != 1) { + throw std::runtime_error("not implemented"); + } + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + // printf("smemSize = %u\n", hostSmemSize); + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) { + int32_t const val = std::stoi(env); + if (val > 0) { + return val; + } + } + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(prop.multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); #if SPEC_DEC - uint32_t const qSeqLen = specDecParams.qSeqLen; + uint32_t const qSeqLen = specDecParams.qSeqLen; #else - uint32_t const qSeqLen = 1; + uint32_t const qSeqLen = 1; #endif - // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == nbInputSeqSplit - dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; - dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); + // gridDim.z == nbKHeads * batchSize && gridDim.y == nbSubSeqPerSeq && gridDim.x == + // nbInputSeqSplit + dim3 const dimGrid{divUp(qSeqLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); #if USE_PAGED_KV_CACHE - uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); - auto const dtype = [] - { - if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - else if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } - else if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } - throw std::runtime_error("unsupported cache element type"); - }(); + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if 
(std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); #if PAGED_KV_CACHE_LAYOUT == 1 - KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; - auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( - kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); - auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( - vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); #else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMap = makeTensorMapForPagedKVCache( - pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); #endif - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, #if SLIDING_WINDOW - slidingWinSize, + slidingWinSize, #endif - qScale, output, + qScale, output, #if LOW_PREC_OUTPUT - rcpOutScale, + rcpOutScale, #endif #if USE_INPUT_KV - qkv, + qkv, #if ROPE_STYLE != 0 - ropeCosSin, + ropeCosSin, #endif #else - q, + q, #endif - attentionSinks, cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH - beamSearchParams, + beamSearchParams, #endif - batchSize, kvCacheScale, + batchSize, kvCacheScale, #if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMK, tensorMapVLLMV, + tensorMapVLLMK, tensorMapVLLMV, #else - tensorMap, + tensorMap, #endif #if SPEC_DEC - specDecParams, -#endif - semaphores, scratch); -#else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; - static_assert(!usePagedKVCache); - assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, - validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); + cudaError_t const err = + cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, #if SLIDING_WINDOW - slidingWinSize, + slidingWinSize, #endif - qScale, output, + qScale, output, #if LOW_PREC_OUTPUT - rcpOutScale, + rcpOutScale, #endif #if USE_INPUT_KV - qkv, + qkv, #if ROPE_STYLE 
!= 0 - ropeCosSin, + ropeCosSin, #endif #else - q, + q, #endif - attentionSinks, cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH - beamSearchParams, + beamSearchParams, #endif - batchSize, kvCacheScale, tensorMap, semaphores, scratch); + batchSize, kvCacheScale, tensorMap, semaphores, scratch); #endif - checkCuda(err); + checkCuda(err); } #endif -void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, uint32_t slidingWinSize, - float qScale, OutputHead* output, +void launchHopperF8MHAFlashInfer(uint32_t multiProcessorCount, uint32_t nbKHeads, + uint32_t slidingWinSize, float qScale, OutputHead* output, #if LOW_PREC_OUTPUT - float const* rcpOutScale, + float const* rcpOutScale, #endif - InputHead const* q, float const* attentionSinks, GMemCacheHead* pool, - KVCachePageIndex const* kvCachePageList, uint32_t maxSeqLen, - uint32_t const* seqLen, uint32_t batchSize, - float const* __restrict__ kvCacheScale, + InputHead const* q, float const* attentionSinks, + GMemCacheHead* pool, KVCachePageIndex const* kvCachePageList, + uint32_t maxSeqLen, uint32_t const* seqLen, uint32_t batchSize, + float const* __restrict__ kvCacheScale, #if SPEC_DEC - uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, -#endif - uint32_t* semaphores, void* scratch, cudaStream_t stream) - { - static uint32_t const hostSmemSize = [&]() - { - uint32_t size; - checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); - checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); - return size; - }(); - uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t - { - float const factor = 0.25f; - return mha::min( - mha::max(1U, (uint32_t) round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), - divUp(maxSeqLen, gemm0CtaTileNbTokens)); - }(); - #if SPEC_DEC - auto specDecParams = SpecDecParams{qSeqLen, qCuSeqLens, mask}; - uint32_t const qLen = qSeqLen; - #else - uint32_t const qLen = 1; - #endif - dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; - dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; - auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); - #if USE_PAGED_KV_CACHE - uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); - auto const dtype = [] - { - if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - else if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; - } - else if (std::is_same_v) - { - return CU_TENSOR_MAP_DATA_TYPE_UINT8; - } - throw std::runtime_error("unsupported cache element type"); - }(); - - #if PAGED_KV_CACHE_LAYOUT == 1 - KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; - - auto const tensorMapVLLMK = makeTensorMapForPagedKVCache( - kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); - auto const tensorMapVLLMV = makeTensorMapForPagedKVCache( - vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); - #else - KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; - auto const tensorMap = makeTensorMapForPagedKVCache( - pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, cacheHeadPartElems, gemm0CtaTileNbTokens); - #endif - - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, - #if SLIDING_WINDOW - slidingWinSize, - #endif - qScale, output, - #if LOW_PREC_OUTPUT 
- rcpOutScale, - #endif - q, attentionSinks, cacheList, batchSize, kvCacheScale, - #if PAGED_KV_CACHE_LAYOUT == 1 - tensorMapVLLMK, tensorMapVLLMV, - #else - tensorMap, - #endif - #if SPEC_DEC - specDecParams, - #endif - semaphores, scratch); - #else - KVCacheList const cacheList{kvCacheData, seqLen, maxSeqLen}; - static_assert(!usePagedKVCache); - assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); - auto const tensorMap = makeTensorMapForContiguousKVCache(kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, - validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); - cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, - #if SLIDING_WINDOW - slidingWinSize, - #endif - qScale, output, - #if LOW_PREC_OUTPUT - rcpOutScale, - #endif - q, attentionSinks, cacheList, batchSize, kvCacheScale, tensorMap, semaphores, scratch); - #endif - checkCuda(err); - } + uint32_t qSeqLen, uint32_t const* qCuSeqLens, MaskType const* mask, +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbSubSeqPerSeq = [&]() -> uint32_t { + float const factor = 0.25f; + return mha::min( + mha::max( + 1U, (uint32_t)round(multiProcessorCount * 3 / (batchSize * nbKHeads) * factor)), + divUp(maxSeqLen, gemm0CtaTileNbTokens)); + }(); +#if SPEC_DEC + auto specDecParams = SpecDecParams{qSeqLen, qCuSeqLens, mask}; + uint32_t const qLen = qSeqLen; +#else + uint32_t const qLen = 1; +#endif + dim3 const dimGrid{divUp(qLen, inputTokensPerCta), nbSubSeqPerSeq, nbKHeads * batchSize}; + dim3 const dimCta{warp_size * gmmaWarpsPerGrp, 1, 3}; + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); + auto const dtype = [] { + if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (std::is_same_v) { + return CU_TENSOR_MAP_DATA_TYPE_UINT8; + } + throw std::runtime_error("unsupported cache element type"); + }(); + +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, + maxNbPagesPerSeq}; + + auto const tensorMapVLLMK = + makeTensorMapForPagedKVCache(kCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); + auto const tensorMapVLLMV = + makeTensorMapForPagedKVCache(vCacheVLLM, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; + auto const tensorMap = + makeTensorMapForPagedKVCache(pool, dtype, validElemsPerHead, nbKHeads, tokensPerPage, + cacheHeadPartElems, gemm0CtaTileNbTokens); +#endif + + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, &kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, +#if PAGED_KV_CACHE_LAYOUT == 1 + tensorMapVLLMK, tensorMapVLLMV, +#else + tensorMap, +#endif +#if SPEC_DEC + specDecParams, +#endif + semaphores, scratch); +#else + KVCacheList const 
cacheList{kvCacheData, seqLen, maxSeqLen}; + static_assert(!usePagedKVCache); + assert(gemm0CtaTileNbTokens == gemm1CtaTileNbTokens); + auto const tensorMap = makeTensorMapForContiguousKVCache( + kvCacheData, CU_TENSOR_MAP_DATA_TYPE_UINT8, validElemsPerHead, nbKHeads, maxSeqLen, beamWidth, + batchSize, cacheHeadPartElems, gemm0CtaTileNbTokens); + cudaError_t const err = cudaLaunchKernelEx(&launchCfg, kernel_mha, nbKHeads, +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, attentionSinks, cacheList, batchSize, kvCacheScale, + tensorMap, semaphores, scratch); +#endif + checkCuda(err); +} #endif diff --git a/csrc/xqa/tensorMap.cpp b/csrc/xqa/tensorMap.cpp index 58a608aada..e79272b018 100644 --- a/csrc/xqa/tensorMap.cpp +++ b/csrc/xqa/tensorMap.cpp @@ -1,93 +1,117 @@ #include "tensorMap.h" -#include "utils.h" + #include #include + #include -uint32_t getElemBytes(CUtensorMapDataType_enum dataType) -{ - switch (dataType) - { - case CU_TENSOR_MAP_DATA_TYPE_UINT8: return 1; - case CU_TENSOR_MAP_DATA_TYPE_UINT16: return 2; - case CU_TENSOR_MAP_DATA_TYPE_UINT32: return 4; - case CU_TENSOR_MAP_DATA_TYPE_INT32: return 4; - case CU_TENSOR_MAP_DATA_TYPE_UINT64: return 8; - case CU_TENSOR_MAP_DATA_TYPE_INT64: return 8; - case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: return 2; - case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: return 4; - case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: return 8; - case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: return 2; - case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: return 4; - case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: return 4; - case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: return 4; - default: throw std::runtime_error("unsupported data type"); - } +#include "utils.h" + +uint32_t getElemBytes(CUtensorMapDataType_enum dataType) { + switch (dataType) { + case CU_TENSOR_MAP_DATA_TYPE_UINT8: + return 1; + case CU_TENSOR_MAP_DATA_TYPE_UINT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_UINT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_INT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_UINT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_INT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT64: + return 8; + case CU_TENSOR_MAP_DATA_TYPE_BFLOAT16: + return 2; + case CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32: + return 4; + case CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ: + return 4; + default: + throw std::runtime_error("unsupported data type"); + } } -CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, - uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t partElems, - uint32_t nbTokens) -{ - CUtensorMap tensorMap{}; - uint64_t const globalDims[] = {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; - uint32_t elemBytes = getElemBytes(dataType); - uint32_t const headBytes = elemBytes * headElems; - uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, headBytes * maxCacheLen * nbKHeads}; - uint32_t const boxDims[] = {partElems, nbTokens, 1, 1}; - uint32_t const elemStrides[] = {1, 1, 1, 1}; +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens) { + CUtensorMap tensorMap{}; + uint64_t const globalDims[] 
= {headElems, maxCacheLen, nbKHeads, 2 * beamWidth * batchSize}; + uint32_t elemBytes = getElemBytes(dataType); + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * maxCacheLen, + headBytes * maxCacheLen * nbKHeads}; + uint32_t const boxDims[] = {partElems, nbTokens, 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; - auto const swizzle = [&] - { - switch (partElems) - { - case 128: return CU_TENSOR_MAP_SWIZZLE_128B; - case 64: return CU_TENSOR_MAP_SWIZZLE_64B; - default: throw std::runtime_error("unsupported cache head size"); - } - }(); + auto const swizzle = [&] { + switch (partElems) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); - checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, globalStrides, boxDims, - elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - return tensorMap; + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; } -CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, - uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, uint32_t nbTokensPerTile) -{ - CUtensorMap tensorMap{}; - uint32_t elemBytes = getElemBytes(dataType); +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile) { + CUtensorMap tensorMap{}; + uint32_t elemBytes = getElemBytes(dataType); // VLLM Layout #if PAGED_KV_CACHE_LAYOUT == 1 - uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; - uint32_t const headBytes = elemBytes * headElems; - uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, headBytes * nbKHeads * tokensPerPage}; - uint32_t const partBytes = partElems * elemBytes; - uint32_t const boxDims[] = {partElems, 1, mha::min(tokensPerPage, nbTokensPerTile), 1}; - uint32_t const elemStrides[] = {1, 1, 1, 1}; - // XQA Original Layout + uint64_t const globalDims[] = {headElems, nbKHeads, tokensPerPage, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * nbKHeads, + headBytes * nbKHeads * tokensPerPage}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, 1, mha::min(tokensPerPage, nbTokensPerTile), 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; + // XQA Original Layout #else - uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; - uint32_t const headBytes = elemBytes * headElems; - uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, headBytes * tokensPerPage * nbKHeads}; - uint32_t const partBytes = partElems * elemBytes; - uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; - uint32_t const elemStrides[] = {1, 1, 1, 1}; + uint64_t const globalDims[] = {headElems, tokensPerPage, nbKHeads, 1U << 31}; + uint32_t const headBytes = elemBytes * headElems; + uint64_t const globalStrides[] = {headBytes, headBytes * tokensPerPage, + headBytes * 
tokensPerPage * nbKHeads}; + uint32_t const partBytes = partElems * elemBytes; + uint32_t const boxDims[] = {partElems, mha::min(tokensPerPage, nbTokensPerTile), 1, 1}; + uint32_t const elemStrides[] = {1, 1, 1, 1}; #endif - auto const swizzle = [&] - { - switch (partBytes) - { - case 128: return CU_TENSOR_MAP_SWIZZLE_128B; - case 64: return CU_TENSOR_MAP_SWIZZLE_64B; - default: throw std::runtime_error("unsupported cache head size"); - } - }(); + auto const swizzle = [&] { + switch (partBytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache head size"); + } + }(); - checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, globalStrides, boxDims, - elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); - return tensorMap; + checkCu(cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast(addr), globalDims, + globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); + return tensorMap; } diff --git a/csrc/xqa/tensorMap.h b/csrc/xqa/tensorMap.h index 83ecb54252..d0b2c76b96 100644 --- a/csrc/xqa/tensorMap.h +++ b/csrc/xqa/tensorMap.h @@ -4,9 +4,13 @@ uint32_t getElemBytes(CUtensorMapDataType_enum dataType); -CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, - uint32_t nbKHeads, uint32_t maxCacheLen, uint32_t beamWidth, uint32_t batchSize, uint32_t partElems, - uint32_t nbTokens); +CUtensorMap makeTensorMapForContiguousKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t maxCacheLen, uint32_t beamWidth, + uint32_t batchSize, uint32_t partElems, + uint32_t nbTokens); -CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, uint32_t headElems, - uint32_t nbKHeads, uint32_t tokensPerPage, uint32_t partElems, uint32_t nbTokensPerTile); +CUtensorMap makeTensorMapForPagedKVCache(void const* addr, CUtensorMapDataType_enum dataType, + uint32_t headElems, uint32_t nbKHeads, + uint32_t tokensPerPage, uint32_t partElems, + uint32_t nbTokensPerTile); diff --git a/csrc/xqa/tma.h b/csrc/xqa/tma.h index 38d7e43928..5cf67238a2 100644 --- a/csrc/xqa/tma.h +++ b/csrc/xqa/tma.h @@ -15,305 +15,288 @@ #include "cuda_hint.cuh" #include "utils.h" #ifndef GENERATE_CUBIN -#include #include #include + +#include #endif #include "barriers.cuh" -enum class StateSpace -{ - kCONSTANT, - kPARAMETER, - kGENERIC -}; +enum class StateSpace { kCONSTANT, kPARAMETER, kGENERIC }; #ifdef GENERATE_CUBIN #define CU_TENSOR_MAP_NUM_QWORDS 16 -typedef struct CUtensorMap_st -{ +typedef struct CUtensorMap_st { #if defined(__cplusplus) && (__cplusplus >= 201103L) - alignas(64) + alignas(64) #elif __STDC_VERSION__ >= 201112L - _Alignas(64) + _Alignas(64) #endif - uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; + uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; } CUtensorMap; #endif -namespace tma -{ +namespace tma { -__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, CtaBarrier& bar) -{ - asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), - "l"(__cvta_generic_to_shared(&bar)) - : "memory"); +__device__ inline void 
loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, + CtaBarrier& bar) { + asm volatile( + "cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); } -__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) -{ - asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), "r"(nbBytes) - : "memory"); +__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) { + asm volatile( + "cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), + "r"(nbBytes) + : "memory"); } // dsr and &bar must be remote address generated by mapa and src must be local address -__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes, CgaBarrier& bar) -{ - asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), - "l"(__cvta_generic_to_shared(&bar)) - : "memory"); +__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes, + CgaBarrier& bar) { + asm volatile( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, " + "[%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); } template -__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, CtaBarrier& bar) -{ - if constexpr (nbDims == 1) - { - // nbDims==1 does not need tensormap and should just use cp.async.bulk - asm volatile( - "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2}], [%3];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbDims == 2) - { - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3}], [%4];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbDims == 3) - { - asm volatile( - "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4}], " - "[%5];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbDims == 4) - { - asm volatile( - "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, " - "%5}], [%6];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbDims == 5) - { - asm volatile( - "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, %5, " - "%6}], [%7];\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else - { - static_assert(nbDims >= 1 && 
nbDims <= 5); - } +__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, + CtaBarrier& bar) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2}], [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3}], [%4];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4}], " + "[%5];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 4) { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4, " + "%5}], [%6];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, " + "{%2, %3, %4, %5, " + "%6}], [%7];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else { + static_assert(nbDims >= 1 && nbDims <= 5); + } } template -__device__ inline void loadAsync( - void* dst, CUtensorMap const& tensorMap, DimsLE offset, CtaBarrier& bar, uint64_t cacheHint) -{ - if constexpr (nbDims == 1) - { - // nbDims==1 does not need tensormap and should just use cp.async.bulk - asm volatile( - "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " - "{%2}], [%3], %4;\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) - : "memory"); - } - else if constexpr (nbDims == 2) - { - asm volatile( - "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " - "{%2, %3}], [%4], %5;\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) - : "memory"); - } - else if constexpr (nbDims == 3) - { - asm volatile( - "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " - "{%2, %3, %4}], [%5], %6;\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) - : "memory"); - } - else if constexpr (nbDims == 4) - { - asm volatile( - "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " - "{%2, %3, %4, 
%5}], [%6], %7;\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) - : "memory"); - } - else if constexpr (nbDims == 5) - { - asm volatile( - "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " - "{%2, %3, %4, %5, %6}], [%7], %8;\n" - : - : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), - "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)), - "l"(cacheHint) - : "memory"); - } - else - { - static_assert(nbDims >= 1 && nbDims <= 5); - } +__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, + CtaBarrier& bar, uint64_t cacheHint) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_" + "hint [%0], [%1, " + "{%2}], [%3], %4;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_" + "hint [%0], [%1, " + "{%2, %3}], [%4], %5;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_" + "hint [%0], [%1, " + "{%2, %3, %4}], [%5], %6;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)), + "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 4) { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_" + "hint [%0], [%1, " + "{%2, %3, %4, %5}], [%6], %7;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), + "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_" + "hint [%0], [%1, " + "{%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), + "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), + "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else { + static_assert(nbDims >= 1 && nbDims <= 5); + } } // shared::cta -> global -__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes) -{ - asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" - : - : "l"(reinterpret_cast(dst)), "l"(__cvta_generic_to_shared(src)), "r"(nbBytes)); +__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes) { + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(reinterpret_cast(dst)), "l"(__cvta_generic_to_shared(src)), + "r"(nbBytes)); } template -__device__ inline void storeAsync(CUtensorMap 
const& tensorMap, DimsLE const& offset, void* src)
-{
-    if constexpr (nbDims == 1)
-    {
-        // nbDims==1 does not need tensormap and should just use cp.async.bulk
-        asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n"
-                     :
-                     : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "l"(__cvta_generic_to_shared(src))
-                     : "memory");
-    }
-    else if constexpr (nbDims == 2)
-    {
-        asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n"
-                     :
-                     : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
-                     "l"(__cvta_generic_to_shared(src))
-                     : "memory");
-    }
-    else if constexpr (nbDims == 3)
-    {
-        asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
-                     :
-                     : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
-                     "l"(__cvta_generic_to_shared(src))
-                     : "memory");
-    }
-    else if constexpr (nbDims == 4)
-    {
-        asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
-                     :
-                     : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
-                     "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
-                     : "memory");
-    }
-    else if constexpr (nbDims == 5)
-    {
-        asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n"
-                     :
-                     : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]),
-                     "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
-                     : "memory");
-    }
-    else
-    {
-        static_assert(nbDims >= 1 && nbDims <= 5);
-    }
+__device__ inline void storeAsync(CUtensorMap const& tensorMap, DimsLE const& offset,
+                                  void* src) {
+  if constexpr (nbDims == 1) {
+    // nbDims==1 does not need tensormap and should just use cp.async.bulk
+    asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n"
+                 :
+                 : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]),
+                   "l"(__cvta_generic_to_shared(src))
+                 : "memory");
+  } else if constexpr (nbDims == 2) {
+    asm volatile(
+        "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 3) {
+    asm volatile(
+        "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 4) {
+    asm volatile(
+        "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else if constexpr (nbDims == 5) {
+    asm volatile(
+        "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], "
+        "[%6];\n"
+        :
+        : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]),
+          "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src))
+        : "memory");
+  } else {
+    static_assert(nbDims >= 1 && nbDims <= 5);
+  }
 }

-__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr)
-{
-    asm volatile("tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap), "l"(ptr)
-                 : "memory");
+__device__ inline void
setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) { + asm volatile( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap), + "l"(ptr) + : "memory"); } -__device__ inline void commitGroup() -{ - asm volatile("cp.async.bulk.commit_group;\n" : : : "memory"); +__device__ inline void commitGroup() { + asm volatile("cp.async.bulk.commit_group;\n" : : : "memory"); } // wait until only targetNbInFlightGroups groups are still in-flight. template -__device__ inline void waitGroup() -{ - asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory"); +__device__ inline void waitGroup() { + asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory"); } -__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap, StateSpace loc = StateSpace::kGENERIC) -{ - assert(reinterpret_cast(&tensorMap) % alignof(CUtensorMap) == 0); - switch (loc) - { +__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap, + StateSpace loc = StateSpace::kGENERIC) { + assert(reinterpret_cast(&tensorMap) % alignof(CUtensorMap) == 0); + switch (loc) { case StateSpace::kCONSTANT: - asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap)) : "memory"); - break; + asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap)) + : "memory"); + break; case StateSpace::kPARAMETER: - asm volatile("prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap)) : "memory"); - break; + asm volatile( + "prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap)) + : "memory"); + break; case StateSpace::kGENERIC: - asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast(&tensorMap)) : "memory"); - break; - default: asm volatile("trap;\n"); - } + asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast(&tensorMap)) + : "memory"); + break; + default: + asm volatile("trap;\n"); + } } template -__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar) -{ - constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t)); - Vec const& srcVec = reinterpret_cast const&>(src); - if constexpr (nbWords == 1) - { - asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"( - __cvta_generic_to_shared(dst)), - "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbWords == 2) - { - asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, [%3];\n" ::"l"( - __cvta_generic_to_shared(dst)), - "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else if constexpr (nbWords == 4) - { - asm volatile( - "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, [%5];\n" ::"l"( - __cvta_generic_to_shared(dst)), - "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), "l"(__cvta_generic_to_shared(&bar)) - : "memory"); - } - else - { - static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4, "src size must be 4, 8 or 16 bytes"); - } +__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar) { + constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t)); + Vec const& srcVec = reinterpret_cast const&>(src); + if constexpr (nbWords == 1) { + asm volatile( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"( + __cvta_generic_to_shared(dst)), + 
"r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbWords == 2) { + asm volatile( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, " + "[%3];\n" ::"l"(__cvta_generic_to_shared(dst)), + "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbWords == 4) { + asm volatile( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, " + "[%5];\n" ::"l"(__cvta_generic_to_shared(dst)), + "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else { + static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4, + "src size must be 4, 8 or 16 bytes"); + } } -} // namespace tma +} // namespace tma From 22441b084eaefa303ef62d05645ea3bfca435fb6 Mon Sep 17 00:00:00 2001 From: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> Date: Thu, 25 Sep 2025 01:49:07 -0700 Subject: [PATCH 3/3] remove duplicate code Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com> --- csrc/xqa/xqa_wrapper.cu | 45 ++++++++++++----------------------------- tests/test_xqa.py | 2 ++ 2 files changed, 15 insertions(+), 32 deletions(-) diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index a71967da3d..2be90d9d23 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -33,41 +33,22 @@ void xqa_wrapper(bool run_fp8_mha, int64_t multiProcessorCount, int64_t nbKHeads float const* attentionSinksPtr = attentionSinks.defined() ? reinterpret_cast(attentionSinks.data_ptr()) : nullptr; + auto const mha_func = run_fp8_mha ? &launchHopperF8MHAFlashInfer : &launchMHAFlashInfer; - if (run_fp8_mha) { - launchHopperF8MHAFlashInfer( - multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output.data_ptr()), + mha_func(multiProcessorCount, nbKHeads, slidingWinSize, qScale, + reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale.data_ptr()), + reinterpret_cast(rcpOutScale.data_ptr()), #endif - reinterpret_cast(q.data_ptr()), attentionSinksPtr, - reinterpret_cast(pool.data_ptr()), - reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, - reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), + reinterpret_cast(q.data_ptr()), attentionSinksPtr, + reinterpret_cast(pool.data_ptr()), + reinterpret_cast(kvCachePageList.data_ptr()), maxSeqLen, + reinterpret_cast(seqLen.data_ptr()), batchSize, + reinterpret_cast(kvCacheScale.data_ptr()), #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), - reinterpret_cast(mask.data_ptr()), + qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), + reinterpret_cast(mask.data_ptr()), #endif - reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); - } else { - launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output.data_ptr()), -#if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale.data_ptr()), -#endif - reinterpret_cast(q.data_ptr()), attentionSinksPtr, - reinterpret_cast(pool.data_ptr()), - reinterpret_cast(kvCachePageList.data_ptr()), - maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, - reinterpret_cast(kvCacheScale.data_ptr()), -#if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), - reinterpret_cast(mask.data_ptr()), -#endif - reinterpret_cast(semaphores.data_ptr()), - reinterpret_cast(scratch.data_ptr()), stream); - } + 
reinterpret_cast(semaphores.data_ptr()), + reinterpret_cast(scratch.data_ptr()), stream); } diff --git a/tests/test_xqa.py b/tests/test_xqa.py index 51d0d2159a..c7f2a58819 100644 --- a/tests/test_xqa.py +++ b/tests/test_xqa.py @@ -228,6 +228,8 @@ def test_xqa( ) cache_heads.normal_(0, 1) if fp8_kv_cache: + # Scale down the cache heads to keep values within the representable range of FP8 + # and prevent overflow during computation. The factor 4.0 is chosen empirically. cache_heads /= 4.0 nb_pages_per_seq = div_up(max_seq_len, tokens_per_page)
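
A note on the fp8 test scaling in the final hunk: the `cache_heads /= 4.0` is an empirical guard, not a derived bound. The sketch below is illustrative only and not part of the patch; it assumes PyTorch >= 2.1 for torch.float8_e4m3fn, the e4m3 format fp8 KV caches are typically quantized to. It shows that N(0, 1) samples scaled by 1/4 sit far inside the e4m3 dynamic range, so the cache round trip loses only rounding precision and never saturates; the extra headroom matters for the accumulations inside the kernel rather than for storing the cache itself.

    import torch

    # Hypothetical sanity check for the empirical 4.0 factor (not part of the patch).
    x = torch.randn(1 << 16) / 4.0               # mirrors cache_heads.normal_(0, 1) then /= 4.0
    q = x.to(torch.float8_e4m3fn).to(torch.float32)
    print(torch.finfo(torch.float8_e4m3fn).max)  # 448.0, the e4m3 saturation point
    print(x.abs().max().item())                  # ~1.1, far below 448
    print((q - x).abs().max().item())            # pure rounding error, no clipping

If the factor ever needs retuning, a range check like this is a cheap first probe before rerunning the full accuracy comparison in test_xqa.py.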