Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 49 additions & 55 deletions 06-block-copy/block_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ __global__ void block_copy(void *__restrict__ Cptr,
using SmemLayoutC = typename Spec::SmemLayoutC;
using SmemLayoutO = typename Spec::SmemLayoutO;

constexpr int kTileM = Spec::kTileM;
constexpr int kTileN = Spec::kTileN;
constexpr int kTileK = Spec::kTileK;
constexpr int kBlockM = Spec::kBlockM;
constexpr int kBlockN = Spec::kBlockN;
constexpr int kBlockK = Spec::kBlockK;
constexpr int kShmSizeA = Spec::kShmSizeA;
constexpr int kShmSizeB = Spec::kShmSizeB;

Expand All @@ -49,18 +49,18 @@ __global__ void block_copy(void *__restrict__ Cptr,
Tensor mC = make_tensor(make_gmem_ptr((ComputeTypeC *)Cptr), make_shape(m, n), make_stride(n, Int<1>{})); // (M, N)
Tensor mO = make_tensor(make_gmem_ptr((OutType *)Outptr), make_shape(m, n), make_stride(n, Int<1>{})); // (M, N)

auto tiler = make_tile(Int<kTileM>{}, Int<kTileN>{}, Int<kTileK>{});
auto tiler = make_tile(Int<kBlockM>{}, Int<kBlockN>{}, Int<kBlockK>{});
auto coord = make_coord(0, 0, 0);

Tensor gA = local_tile(mA, tiler, coord, Step<_1, X, _1>{}); // (kTileM, kTileK)
Tensor gB = local_tile(mB, tiler, coord, Step<X, _1, _1>{}); // (kTileN, kTileK)
Tensor gC = local_tile(mC, tiler, coord, Step<_1, _1, X>{}); // (kTileM, kTileN)
Tensor gO = local_tile(mO, tiler, coord, Step<_1, _1, X>{}); // (kTileM, kTileN)
Tensor gA = local_tile(mA, tiler, coord, Step<_1, X, _1>{}); // (kBlockM, kBlockK)
Tensor gB = local_tile(mB, tiler, coord, Step<X, _1, _1>{}); // (kBlockN, kBlockK)
Tensor gC = local_tile(mC, tiler, coord, Step<_1, _1, X>{}); // (kBlockM, kBlockN)
Tensor gO = local_tile(mO, tiler, coord, Step<_1, _1, X>{}); // (kBlockM, kBlockN)

Tensor sA = make_tensor(make_smem_ptr((ComputeTypeA *)Aptr_smem), SmemLayoutA{}); // (kTileM, kTileK)
Tensor sB = make_tensor(make_smem_ptr((ComputeTypeB *)Bptr_smem), SmemLayoutB{}); // (kTileN, kTileK)
Tensor sC = make_tensor(make_smem_ptr((ComputeTypeC *)Cptr_smem), SmemLayoutC{}); // (kTileM, kTileN)
Tensor sO = make_tensor(make_smem_ptr((OutType *)Optr_smem), SmemLayoutO{}); // (kTileM, kTileN)
Tensor sA = make_tensor(make_smem_ptr((ComputeTypeA *)Aptr_smem), SmemLayoutA{}); // (kBlockM, kBlockK)
Tensor sB = make_tensor(make_smem_ptr((ComputeTypeB *)Bptr_smem), SmemLayoutB{}); // (kBlockN, kBlockK)
Tensor sC = make_tensor(make_smem_ptr((ComputeTypeC *)Cptr_smem), SmemLayoutC{}); // (kBlockM, kBlockN)
Tensor sO = make_tensor(make_smem_ptr((OutType *)Optr_smem), SmemLayoutO{}); // (kBlockM, kBlockN)

typename Spec::TiledMMA tiled_mma;
ThrMMA thr_mma = tiled_mma.get_slice(tid);
Expand Down Expand Up @@ -92,7 +92,7 @@ __global__ void block_copy(void *__restrict__ Cptr,

copy(g2s_tiled_copy_a, tAgA_g2s, tAsA_g2s);
copy(g2s_tiled_copy_b, tBgB_g2s, tBsB_g2s);
if constexpr (!is_gemm) {
if constexpr (!IsGemm) {
copy(g2s_tiled_copy_c, tCgC_g2s, tCsC_g2s);
}

Expand Down Expand Up @@ -140,9 +140,9 @@ __global__ void block_copy(void *__restrict__ Cptr,
constexpr int kMmaTileN = Spec::kMmaTileN;
constexpr int kMmaTileK = Spec::kMmaTileK;

constexpr int NTilesM = kTileM / kMmaTileM;
constexpr int NTilesN = kTileN / kMmaTileN;
constexpr int NTilesK = kTileK / kMmaTileK;
constexpr int NTilesM = kBlockM / kMmaTileM;
constexpr int NTilesN = kBlockN / kMmaTileN;
constexpr int NTilesK = kBlockK / kMmaTileK;

#pragma unroll
for (int m_tile = 0; m_tile < NTilesM; ++m_tile) {
Expand Down Expand Up @@ -214,18 +214,18 @@ template <typename OutType_,
typename ComputeTypeA_,
typename ComputeTypeB_,
typename ComputeTypeC_,
int kTileM_ = 128,
int kTileN_ = 128,
int kTileK_ = 64>
int kBlockM_ = 128,
int kBlockN_ = 128,
int kBlockK_ = 64>
struct KernelSpec {
using OutType = OutType_;
using ComputeTypeA = ComputeTypeA_;
using ComputeTypeB = ComputeTypeB_;
using ComputeTypeC = ComputeTypeC_;

static constexpr int kTileM = kTileM_;
static constexpr int kTileN = kTileN_;
static constexpr int kTileK = kTileK_;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kBlockK = kBlockK_;

using MMA_op = SM80_16x8x16_F32BF16BF16F32_TN;
using MMA_traits = MMA_Traits<MMA_op>;
Expand Down Expand Up @@ -273,34 +273,24 @@ struct KernelSpec {
using CopyO_S2G_atom = Copy_Atom<Copy_S2G_op, OutType>;

static constexpr int kThreadNum = size(TiledMMA{});
static constexpr int kThreadsPerWarp = 32;
static constexpr int kTileM_Copy = cute::min(kThreadsPerWarp, kTileM);
static constexpr int kTileN_Copy = cute::min(kThreadsPerWarp, kTileN);

// Here we omit cases that `kAlignedCopyItems < 1`
static constexpr int kAlignedCopyItemsA =
cute::min(128 / 8 / sizeof(ComputeTypeA), kTileK *kTileM_Copy / kThreadNum);
static constexpr int kAlignedCopyItemsB =
cute::min(128 / 8 / sizeof(ComputeTypeB), kTileK *kTileN_Copy / kThreadNum);
static constexpr int kAlignedCopyItemsC =
cute::min(128 / 8 / sizeof(ComputeTypeC), kTileN *kTileM_Copy / kThreadNum);
static constexpr int kAlignedCopyItemsO = cute::min(128 / 8 / sizeof(OutType), kTileN *kTileM_Copy / kThreadNum);
static constexpr int kBlockK_Copy = cute::min(64, kBlockK) / 8;
static constexpr int kBlockN_Copy = cute::min(64, kBlockN) / 8;

using TiledCopyA_G2S =
decltype(make_tiled_copy(CopyA_G2S_atom{},
make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsA>{}))));
make_layout(make_shape(Int<kThreadNum / kBlockK_Copy>{}, Int<kBlockK_Copy>{}),
make_stride(Int<kBlockK_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));
using TiledCopyB_G2S =
decltype(make_tiled_copy(CopyB_G2S_atom{},
make_layout(make_shape(Int<kTileN_Copy>{}, Int<kThreadNum / kTileN_Copy>{}),
make_stride(Int<kThreadNum / kTileN_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsB>{}))));
make_layout(make_shape(Int<kThreadNum / kBlockK_Copy>{}, Int<kBlockK_Copy>{}),
make_stride(Int<kBlockK_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));
using TiledCopyC_G2S =
decltype(make_tiled_copy(CopyC_G2S_atom{},
make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsC>{}))));
make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));

using TiledCopyA_S2R = decltype(make_tiled_copy_A(CopyA_S2R_atom{}, TiledMMA{}));
using TiledCopyB_S2R = decltype(make_tiled_copy_B(CopyB_S2R_atom{}, TiledMMA{}));
Expand All @@ -311,30 +301,30 @@ struct KernelSpec {

using TiledCopyC_S2G =
decltype(make_tiled_copy(CopyC_S2G_atom{},
make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsC>{}))));
make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));
using TiledCopyO_S2G =
decltype(make_tiled_copy(CopyO_S2G_atom{},
make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsO>{}))));
make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
make_layout(make_shape(Int<1>{}, Int<8>{}))));

using SmemLayoutA =
decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileK>{}), make_stride(Int<kTileK>{}, Int<1>{})));
decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockK>{}), make_stride(Int<kBlockK>{}, Int<1>{})));
using SmemLayoutB =
decltype(make_layout(make_shape(Int<kTileN>{}, Int<kTileK>{}), make_stride(Int<kTileK>{}, Int<1>{})));
decltype(make_layout(make_shape(Int<kBlockN>{}, Int<kBlockK>{}), make_stride(Int<kBlockK>{}, Int<1>{})));
using SmemLayoutC =
decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileN>{}), make_stride(Int<kTileN>{}, Int<1>{})));
decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockN>{}), make_stride(Int<kBlockN>{}, Int<1>{})));
using SmemLayoutO =
decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileN>{}), make_stride(Int<kTileN>{}, Int<1>{})));
decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockN>{}), make_stride(Int<kBlockN>{}, Int<1>{})));

static constexpr int kShmSizeA = cosize(SmemLayoutA{}) * sizeof(ComputeTypeA);
static constexpr int kShmSizeB = cosize(SmemLayoutB{}) * sizeof(ComputeTypeB);
static constexpr int kShmSizeC = cosize(SmemLayoutC{}) * sizeof(ComputeTypeC);
static constexpr int kShmSizeO = cosize(SmemLayoutO{}) * sizeof(OutType);

static constexpr int kShmSize = kShmSizeA + kShmSizeB + kShmSizeC;
static constexpr int kShmSize = cute::max(kShmSizeA + kShmSizeB + kShmSizeC, kShmSizeO);
};

} // namespace spec
Expand Down Expand Up @@ -438,7 +428,7 @@ torch::Tensor run_block_copy(const torch::Tensor a, const torch::Tensor b, std::
using Spec = spec::KernelSpec<OutType, ComputeTypeA, ComputeTypeB, ComputeTypeC, M, N, K>;

dim3 block = Spec::kThreadNum;
dim3 grid((N + Spec::kTileN - 1) / Spec::kTileN, (M + Spec::kTileM - 1) / Spec::kTileM);
dim3 grid((N + Spec::kBlockN - 1) / Spec::kBlockN, (M + Spec::kBlockM - 1) / Spec::kBlockM);
int shm_size = Spec::kShmSize;

printf("Block Size: (%d, %d, %d) | Grid Size: (%d, %d, %d) | Shared Memory Size: %d Bytes\n", block.x, block.y,
Expand All @@ -458,6 +448,10 @@ torch::Tensor run_block_copy(const torch::Tensor a, const torch::Tensor b, std::
// Kernel launch
BOOL_SWITCH(is_gemm, IsGemm, [&] {
cudaEventRecord(start, stream);
if (shm_size >= 48 * 1024) {
cudaFuncSetAttribute(block_copy<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
}
block_copy<Spec, IsGemm, IsCvtPrecision>
<<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
cudaEventRecord(stop, stream);
Expand Down
6 changes: 5 additions & 1 deletion 07-swizzling/swizzling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ __global__ void swizzling(void *__restrict__ Cptr,

copy(g2s_tiled_copy_a, tAgA_g2s, tAsA_g2s);
copy(g2s_tiled_copy_b, tBgB_g2s, tBsB_g2s);
if constexpr (!is_gemm) {
if constexpr (!IsGemm) {
copy(g2s_tiled_copy_c, tCgC_g2s, tCsC_g2s);
}

Expand Down Expand Up @@ -489,6 +489,10 @@ torch::Tensor run_swizzling(const torch::Tensor a, const torch::Tensor b, std::o
// Kernel launch
BOOL_SWITCH(is_gemm, IsGemm, [&] {
cudaEventRecord(start, stream);
if (shm_size >= 48 * 1024) {
cudaFuncSetAttribute(swizzling<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
}
swizzling<Spec, IsGemm, IsCvtPrecision>
<<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
cudaEventRecord(stop, stream);
Expand Down
6 changes: 4 additions & 2 deletions 08-dynamic-mma/dynamic_mma.cu
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,10 @@ torch::Tensor run_dynamic_mma(const torch::Tensor a, const torch::Tensor b, std:
// Kernel launch
BOOL_SWITCH(is_gemm, IsGemm, [&] {
cudaEventRecord(start, stream);
cudaFuncSetAttribute(dynamic_mma<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
if (shm_size >= 48 * 1024) {
cudaFuncSetAttribute(dynamic_mma<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
}
dynamic_mma<Spec, IsGemm, IsCvtPrecision>
<<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
cudaEventRecord(stop, stream);
Expand Down
6 changes: 4 additions & 2 deletions 09-pipelining/pipelining.cu
Original file line number Diff line number Diff line change
Expand Up @@ -634,8 +634,10 @@ torch::Tensor run_pipelining(const torch::Tensor a, const torch::Tensor b, std::
// Kernel launch
BOOL_SWITCH(is_gemm, IsGemm, [&] {
cudaEventRecord(start, stream);
cudaFuncSetAttribute(pipelining<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
if (shm_size >= 48 * 1024) {
cudaFuncSetAttribute(pipelining<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
}
pipelining<Spec, IsGemm, IsCvtPrecision>
<<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
cudaEventRecord(stop, stream);
Expand Down
6 changes: 4 additions & 2 deletions 09-pipelining/pipelining_no_reg_prefetch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,10 @@ torch::Tensor run_pipelining(const torch::Tensor a, const torch::Tensor b, std::
// Kernel launch
BOOL_SWITCH(is_gemm, IsGemm, [&] {
cudaEventRecord(start, stream);
cudaFuncSetAttribute(pipelining<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
if (shm_size >= 48 * 1024) {
cudaFuncSetAttribute(pipelining<Spec, IsGemm, IsCvtPrecision>, cudaFuncAttributeMaxDynamicSharedMemorySize,
shm_size);
}
pipelining<Spec, IsGemm, IsCvtPrecision>
<<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
cudaEventRecord(stop, stream);
Expand Down