diff --git a/06-block-copy/block_copy.cu b/06-block-copy/block_copy.cu
index d06ea00..f80ba61 100644
--- a/06-block-copy/block_copy.cu
+++ b/06-block-copy/block_copy.cu
@@ -29,9 +29,9 @@ __global__ void block_copy(void *__restrict__ Cptr,
   using SmemLayoutC = typename Spec::SmemLayoutC;
   using SmemLayoutO = typename Spec::SmemLayoutO;
 
-  constexpr int kTileM = Spec::kTileM;
-  constexpr int kTileN = Spec::kTileN;
-  constexpr int kTileK = Spec::kTileK;
+  constexpr int kBlockM = Spec::kBlockM;
+  constexpr int kBlockN = Spec::kBlockN;
+  constexpr int kBlockK = Spec::kBlockK;
 
   constexpr int kShmSizeA = Spec::kShmSizeA;
   constexpr int kShmSizeB = Spec::kShmSizeB;
@@ -49,18 +49,18 @@ __global__ void block_copy(void *__restrict__ Cptr,
   Tensor mC = make_tensor(make_gmem_ptr((ComputeTypeC *)Cptr), make_shape(m, n), make_stride(n, Int<1>{}));  // (M, N)
   Tensor mO = make_tensor(make_gmem_ptr((OutType *)Outptr), make_shape(m, n), make_stride(n, Int<1>{}));     // (M, N)
 
-  auto tiler = make_tile(Int<kTileM>{}, Int<kTileN>{}, Int<kTileK>{});
+  auto tiler = make_tile(Int<kBlockM>{}, Int<kBlockN>{}, Int<kBlockK>{});
   auto coord = make_coord(0, 0, 0);
-  Tensor gA = local_tile(mA, tiler, coord, Step<_1, X, _1>{});  // (kTileM, kTileK)
-  Tensor gB = local_tile(mB, tiler, coord, Step<X, _1, _1>{});  // (kTileN, kTileK)
-  Tensor gC = local_tile(mC, tiler, coord, Step<_1, _1, X>{});  // (kTileM, kTileN)
-  Tensor gO = local_tile(mO, tiler, coord, Step<_1, _1, X>{});  // (kTileM, kTileN)
+  Tensor gA = local_tile(mA, tiler, coord, Step<_1, X, _1>{});  // (kBlockM, kBlockK)
+  Tensor gB = local_tile(mB, tiler, coord, Step<X, _1, _1>{});  // (kBlockN, kBlockK)
+  Tensor gC = local_tile(mC, tiler, coord, Step<_1, _1, X>{});  // (kBlockM, kBlockN)
+  Tensor gO = local_tile(mO, tiler, coord, Step<_1, _1, X>{});  // (kBlockM, kBlockN)
 
-  Tensor sA = make_tensor(make_smem_ptr((ComputeTypeA *)Aptr_smem), SmemLayoutA{});  // (kTileM, kTileK)
-  Tensor sB = make_tensor(make_smem_ptr((ComputeTypeB *)Bptr_smem), SmemLayoutB{});  // (kTileN, kTileK)
-  Tensor sC = make_tensor(make_smem_ptr((ComputeTypeC *)Cptr_smem), SmemLayoutC{});  // (kTileM, kTileN)
-  Tensor sO = make_tensor(make_smem_ptr((OutType *)Optr_smem), SmemLayoutO{});       // (kTileM, kTileN)
+  Tensor sA = make_tensor(make_smem_ptr((ComputeTypeA *)Aptr_smem), SmemLayoutA{});  // (kBlockM, kBlockK)
+  Tensor sB = make_tensor(make_smem_ptr((ComputeTypeB *)Bptr_smem), SmemLayoutB{});  // (kBlockN, kBlockK)
+  Tensor sC = make_tensor(make_smem_ptr((ComputeTypeC *)Cptr_smem), SmemLayoutC{});  // (kBlockM, kBlockN)
+  Tensor sO = make_tensor(make_smem_ptr((OutType *)Optr_smem), SmemLayoutO{});       // (kBlockM, kBlockN)
 
   typename Spec::TiledMMA tiled_mma;
   ThrMMA thr_mma = tiled_mma.get_slice(tid);
@@ -92,7 +92,7 @@ __global__ void block_copy(void *__restrict__ Cptr,
   copy(g2s_tiled_copy_a, tAgA_g2s, tAsA_g2s);
   copy(g2s_tiled_copy_b, tBgB_g2s, tBsB_g2s);
 
-  if constexpr (!is_gemm) {
+  if constexpr (!IsGemm) {
     copy(g2s_tiled_copy_c, tCgC_g2s, tCsC_g2s);
   }
@@ -140,9 +140,9 @@ __global__ void block_copy(void *__restrict__ Cptr,
   constexpr int kMmaTileN = Spec::kMmaTileN;
   constexpr int kMmaTileK = Spec::kMmaTileK;
 
-  constexpr int NTilesM = kTileM / kMmaTileM;
-  constexpr int NTilesN = kTileN / kMmaTileN;
-  constexpr int NTilesK = kTileK / kMmaTileK;
+  constexpr int NTilesM = kBlockM / kMmaTileM;
+  constexpr int NTilesN = kBlockN / kMmaTileN;
+  constexpr int NTilesK = kBlockK / kMmaTileK;
 
 #pragma unroll
   for (int m_tile = 0; m_tile < NTilesM; ++m_tile) {
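Note: for anyone tracing the kTile -> kBlock rename through the `local_tile` calls above, the `Step` argument projects the rank-3 (kBlockM, kBlockN, kBlockK) tiler onto each rank-2 tensor, with `X` dropping a mode. A minimal sketch of that semantics with invented extents; `X` aliasing `cute::Underscore` follows the CUTLASS examples and nothing here is taken from the patch beyond the `Step` patterns:

// Sketch: the Step<> projections used by gA/gB/gC/gO above (extents invented).
#include <cute/tensor.hpp>
using namespace cute;
using X = Underscore;  // placeholder meaning "drop this tiler mode"

template <class Engine, class Layout>
__device__ void projection_demo(Tensor<Engine, Layout> mA) {  // mA: (M, K), K contiguous
  auto tiler = make_tile(Int<128>{}, Int<128>{}, Int<64>{});  // (BlockM, BlockN, BlockK)
  auto coord = make_coord(0, 0, 0);                           // block (0, 0, 0)
  // Keep tiler modes 0 and 2, drop mode 1: A is tiled as (BlockM, BlockK),
  // so the N extent never touches it. Step<X, _1, _1> does the same for B
  // with (BlockN, BlockK); Step<_1, _1, X> yields (BlockM, BlockN) for C and O.
  Tensor gA = local_tile(mA, tiler, coord, Step<_1, X, _1>{});  // static shape (_128, _64)
  (void)gA;
}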
@@ -214,18 +214,18 @@ template <typename OutType_,
           typename ComputeTypeA_,
           typename ComputeTypeB_,
           typename ComputeTypeC_,
-          int kTileM_ = 128,
-          int kTileN_ = 128,
-          int kTileK_ = 64>
+          int kBlockM_ = 128,
+          int kBlockN_ = 128,
+          int kBlockK_ = 64>
 struct KernelSpec {
   using OutType = OutType_;
   using ComputeTypeA = ComputeTypeA_;
   using ComputeTypeB = ComputeTypeB_;
   using ComputeTypeC = ComputeTypeC_;
 
-  static constexpr int kTileM = kTileM_;
-  static constexpr int kTileN = kTileN_;
-  static constexpr int kTileK = kTileK_;
+  static constexpr int kBlockM = kBlockM_;
+  static constexpr int kBlockN = kBlockN_;
+  static constexpr int kBlockK = kBlockK_;
 
   using MMA_op = SM80_16x8x16_F32BF16BF16F32_TN;
   using MMA_traits = MMA_Traits<MMA_op>;
@@ -273,34 +273,24 @@ struct KernelSpec {
   using CopyO_S2G_atom = Copy_Atom<...>;
 
   static constexpr int kThreadNum = size(TiledMMA{});
-  static constexpr int kThreadsPerWarp = 32;
-  static constexpr int kTileM_Copy = cute::min(kThreadsPerWarp, kTileM);
-  static constexpr int kTileN_Copy = cute::min(kThreadsPerWarp, kTileN);
-
-  // Here we omit cases that `kAlignedCopyItems < 1`
-  static constexpr int kAlignedCopyItemsA =
-      cute::min(128 / 8 / sizeof(ComputeTypeA), kTileK *kTileM_Copy / kThreadNum);
-  static constexpr int kAlignedCopyItemsB =
-      cute::min(128 / 8 / sizeof(ComputeTypeB), kTileK *kTileN_Copy / kThreadNum);
-  static constexpr int kAlignedCopyItemsC =
-      cute::min(128 / 8 / sizeof(ComputeTypeC), kTileN *kTileM_Copy / kThreadNum);
-  static constexpr int kAlignedCopyItemsO = cute::min(128 / 8 / sizeof(OutType), kTileN *kTileM_Copy / kThreadNum);
+  static constexpr int kBlockK_Copy = cute::min(64, kBlockK) / 8;
+  static constexpr int kBlockN_Copy = cute::min(64, kBlockN) / 8;
 
   using TiledCopyA_G2S = decltype(make_tiled_copy(CopyA_G2S_atom{},
-      make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
-                  make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
-      make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsA>{}))));
+      make_layout(make_shape(Int<kThreadNum / kBlockK_Copy>{}, Int<kBlockK_Copy>{}),
+                  make_stride(Int<kBlockK_Copy>{}, Int<1>{})),
+      make_layout(make_shape(Int<1>{}, Int<8>{}))));
   using TiledCopyB_G2S = decltype(make_tiled_copy(CopyB_G2S_atom{},
-      make_layout(make_shape(Int<kTileN_Copy>{}, Int<kThreadNum / kTileN_Copy>{}),
-                  make_stride(Int<kThreadNum / kTileN_Copy>{}, Int<1>{})),
-      make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsB>{}))));
+      make_layout(make_shape(Int<kThreadNum / kBlockK_Copy>{}, Int<kBlockK_Copy>{}),
+                  make_stride(Int<kBlockK_Copy>{}, Int<1>{})),
+      make_layout(make_shape(Int<1>{}, Int<8>{}))));
   using TiledCopyC_G2S = decltype(make_tiled_copy(CopyC_G2S_atom{},
-      make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
-                  make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
-      make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsC>{}))));
+      make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
+                  make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
+      make_layout(make_shape(Int<1>{}, Int<8>{}))));
 
   using TiledCopyA_S2R = decltype(make_tiled_copy_A(CopyA_S2R_atom{}, TiledMMA{}));
   using TiledCopyB_S2R = decltype(make_tiled_copy_B(CopyB_S2R_atom{}, TiledMMA{}));
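Note: the rewritten G2S thread arrangements are easiest to check with numbers plugged in. Assuming `kThreadNum == 128` and bf16 operands (consistent with the SM80_16x8x16 MMA above, but still an assumption), `kBlockK_Copy = min(64, kBlockK) / 8 = 8` threads sit along the contiguous K mode, and the fixed `(1, 8)` value layout hands each thread 16 bytes, i.e. one 128-bit access. A standalone reconstruction; the atom choice is an assumption, since the real `CopyA_G2S_atom` is defined elsewhere in the file:

// Sketch of the new G2S geometry with concrete (assumed) numbers.
#include <cute/tensor.hpp>
using namespace cute;

using G2SAtom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<uint128_t>, bfloat16_t>;

// Thread layout (kThreadNum / kBlockK_Copy, kBlockK_Copy) = (16, 8), row-major:
// consecutive thread ids walk the contiguous K mode first, so the 8 threads of
// a row cover 8 * 8 = 64 contiguous bf16 values, a full 128-byte cache line.
using TiledCopyG2S = decltype(make_tiled_copy(
    G2SAtom{},
    make_layout(make_shape(Int<16>{}, Int<8>{}), make_stride(Int<8>{}, Int<1>{})),
    // (1, 8) values per thread: 8 * sizeof(bf16) = 16 B = one 128-bit cp.async.
    make_layout(make_shape(Int<1>{}, Int<8>{}))));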
@@ -311,30 +301,30 @@ struct KernelSpec {
   using TiledCopyC_S2G = decltype(make_tiled_copy(CopyC_S2G_atom{},
-      make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
-                  make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
-      make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsC>{}))));
+      make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
+                  make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
+      make_layout(make_shape(Int<1>{}, Int<8>{}))));
   using TiledCopyO_S2G = decltype(make_tiled_copy(CopyO_S2G_atom{},
-      make_layout(make_shape(Int<kTileM_Copy>{}, Int<kThreadNum / kTileM_Copy>{}),
-                  make_stride(Int<kThreadNum / kTileM_Copy>{}, Int<1>{})),
-      make_layout(make_shape(Int<1>{}, Int<kAlignedCopyItemsO>{}))));
+      make_layout(make_shape(Int<kThreadNum / kBlockN_Copy>{}, Int<kBlockN_Copy>{}),
+                  make_stride(Int<kBlockN_Copy>{}, Int<1>{})),
+      make_layout(make_shape(Int<1>{}, Int<8>{}))));
 
   using SmemLayoutA =
-      decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileK>{}), make_stride(Int<kTileK>{}, Int<1>{})));
+      decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockK>{}), make_stride(Int<kBlockK>{}, Int<1>{})));
   using SmemLayoutB =
-      decltype(make_layout(make_shape(Int<kTileN>{}, Int<kTileK>{}), make_stride(Int<kTileK>{}, Int<1>{})));
+      decltype(make_layout(make_shape(Int<kBlockN>{}, Int<kBlockK>{}), make_stride(Int<kBlockK>{}, Int<1>{})));
   using SmemLayoutC =
-      decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileN>{}), make_stride(Int<kTileN>{}, Int<1>{})));
+      decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockN>{}), make_stride(Int<kBlockN>{}, Int<1>{})));
   using SmemLayoutO =
-      decltype(make_layout(make_shape(Int<kTileM>{}, Int<kTileN>{}), make_stride(Int<kTileN>{}, Int<1>{})));
+      decltype(make_layout(make_shape(Int<kBlockM>{}, Int<kBlockN>{}), make_stride(Int<kBlockN>{}, Int<1>{})));
 
   static constexpr int kShmSizeA = cosize(SmemLayoutA{}) * sizeof(ComputeTypeA);
   static constexpr int kShmSizeB = cosize(SmemLayoutB{}) * sizeof(ComputeTypeB);
   static constexpr int kShmSizeC = cosize(SmemLayoutC{}) * sizeof(ComputeTypeC);
   static constexpr int kShmSizeO = cosize(SmemLayoutO{}) * sizeof(OutType);
-  static constexpr int kShmSize = kShmSizeA + kShmSizeB + kShmSizeC;
+  static constexpr int kShmSize = cute::max(kShmSizeA + kShmSizeB + kShmSizeC, kShmSizeO);
 };
 
 }  // namespace spec
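Note: the new `kShmSize` is what makes the launch guards below necessary in the first place. The output staging buffer is only live in the epilogue, so it can reuse the bytes that held A/B/C during the mainloop, and the allocation needs the larger of the two working sets rather than their sum. A sketch of the implied carving; the `*ptr_smem` names are the kernel's own, but the exact offsets are an assumption:

// Sketch: carving one dynamic allocation of kShmSize bytes (offsets assumed).
template <typename Spec>
__device__ void carve_smem(char *smem, char *&Aptr_smem, char *&Bptr_smem,
                           char *&Cptr_smem, char *&Optr_smem) {
  Aptr_smem = smem;                         // live during the mainloop
  Bptr_smem = Aptr_smem + Spec::kShmSizeA;  // live during the mainloop
  Cptr_smem = Bptr_smem + Spec::kShmSizeB;  // block-copy variant only
  // The O tile is written after the MMA loop (behind a __syncthreads()),
  // so it may alias the start of the pool; hence
  // kShmSize = max(kShmSizeA + kShmSizeB + kShmSizeC, kShmSizeO).
  Optr_smem = smem;
}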
@@ -438,7 +428,7 @@ torch::Tensor run_block_copy(const torch::Tensor a, const torch::Tensor b, std::
   using Spec = spec::KernelSpec<...>;
 
   dim3 block = Spec::kThreadNum;
-  dim3 grid((N + Spec::kTileN - 1) / Spec::kTileN, (M + Spec::kTileM - 1) / Spec::kTileM);
+  dim3 grid((N + Spec::kBlockN - 1) / Spec::kBlockN, (M + Spec::kBlockM - 1) / Spec::kBlockM);
   int shm_size = Spec::kShmSize;
 
   printf("Block Size: (%d, %d, %d) | Grid Size: (%d, %d, %d) | Shared Memory Size: %d Bytes\n", block.x, block.y,
@@ -458,6 +448,10 @@ torch::Tensor run_block_copy(const torch::Tensor a, const torch::Tensor b, std::
   // Kernel launch
   BOOL_SWITCH(is_gemm, IsGemm, [&] {
     cudaEventRecord(start, stream);
+    if (shm_size >= 48 * 1024) {
+      cudaFuncSetAttribute(block_copy<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           shm_size);
+    }
     block_copy<Spec, IsGemm>
         <<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
     cudaEventRecord(stop, stream);
diff --git a/07-swizzling/swizzling.cu b/07-swizzling/swizzling.cu
index 7ed7bfd..6374822 100644
--- a/07-swizzling/swizzling.cu
+++ b/07-swizzling/swizzling.cu
@@ -92,7 +92,7 @@ __global__ void swizzling(void *__restrict__ Cptr,
   copy(g2s_tiled_copy_a, tAgA_g2s, tAsA_g2s);
   copy(g2s_tiled_copy_b, tBgB_g2s, tBsB_g2s);
 
-  if constexpr (!is_gemm) {
+  if constexpr (!IsGemm) {
     copy(g2s_tiled_copy_c, tCgC_g2s, tCsC_g2s);
   }
@@ -489,6 +489,10 @@ torch::Tensor run_swizzling(const torch::Tensor a, const torch::Tensor b, std::o
   // Kernel launch
   BOOL_SWITCH(is_gemm, IsGemm, [&] {
     cudaEventRecord(start, stream);
+    if (shm_size >= 48 * 1024) {
+      cudaFuncSetAttribute(swizzling<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           shm_size);
+    }
     swizzling<Spec, IsGemm>
         <<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
     cudaEventRecord(stop, stream);
diff --git a/08-dynamic-mma/dynamic_mma.cu b/08-dynamic-mma/dynamic_mma.cu
index 0a088e7..9f18bff 100644
--- a/08-dynamic-mma/dynamic_mma.cu
+++ b/08-dynamic-mma/dynamic_mma.cu
@@ -582,8 +582,10 @@ torch::Tensor run_dynamic_mma(const torch::Tensor a, const torch::Tensor b, std:
   // Kernel launch
   BOOL_SWITCH(is_gemm, IsGemm, [&] {
     cudaEventRecord(start, stream);
-    cudaFuncSetAttribute(dynamic_mma<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                         shm_size);
+    if (shm_size >= 48 * 1024) {
+      cudaFuncSetAttribute(dynamic_mma<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           shm_size);
+    }
     dynamic_mma<Spec, IsGemm>
         <<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
     cudaEventRecord(stop, stream);
diff --git a/09-pipelining/pipelining.cu b/09-pipelining/pipelining.cu
index f55224e..c509a7c 100644
--- a/09-pipelining/pipelining.cu
+++ b/09-pipelining/pipelining.cu
@@ -634,8 +634,10 @@ torch::Tensor run_pipelining(const torch::Tensor a, const torch::Tensor b, std::
   // Kernel launch
   BOOL_SWITCH(is_gemm, IsGemm, [&] {
     cudaEventRecord(start, stream);
-    cudaFuncSetAttribute(pipelining<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                         shm_size);
+    if (shm_size >= 48 * 1024) {
+      cudaFuncSetAttribute(pipelining<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           shm_size);
+    }
     pipelining<Spec, IsGemm>
         <<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
     cudaEventRecord(stop, stream);
diff --git a/09-pipelining/pipelining_no_reg_prefetch.cu b/09-pipelining/pipelining_no_reg_prefetch.cu
index f811210..e9dd2fb 100644
--- a/09-pipelining/pipelining_no_reg_prefetch.cu
+++ b/09-pipelining/pipelining_no_reg_prefetch.cu
@@ -620,8 +620,10 @@ torch::Tensor run_pipelining(const torch::Tensor a, const torch::Tensor b, std::
   // Kernel launch
   BOOL_SWITCH(is_gemm, IsGemm, [&] {
     cudaEventRecord(start, stream);
-    cudaFuncSetAttribute(pipelining<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                         shm_size);
+    if (shm_size >= 48 * 1024) {
+      cudaFuncSetAttribute(pipelining<Spec, IsGemm>, cudaFuncAttributeMaxDynamicSharedMemorySize,
+                           shm_size);
+    }
     pipelining<Spec, IsGemm>
         <<<grid, block, shm_size, stream>>>(c.data_ptr(), a.data_ptr(), b.data_ptr(), M, N, K, out_ptr);
     cudaEventRecord(stop, stream);
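Note: the guard added to every launcher above reduces to the pattern below. Kernels get at most 48 KB of dynamic shared memory by default, and a launch requesting more fails unless the kernel has opted in via `cudaFuncAttributeMaxDynamicSharedMemorySize` (supported on Volta and newer). A self-contained sketch with a hypothetical kernel, not taken from this repo:

// Sketch of the >48 KB opt-in (demo_kernel is hypothetical).
#include <cuda_runtime.h>
#include <cstdio>

__global__ void demo_kernel(float *out) {
  extern __shared__ float buf[];  // sized by the launch's third <<<...>>> argument
  buf[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  out[threadIdx.x] = buf[threadIdx.x];
}

int main() {
  int shm_size = 64 * 1024;  // 64 KB, above the 48 KB default cap
  if (shm_size >= 48 * 1024) {
    // Without this call, the launch below fails with "invalid argument".
    cudaFuncSetAttribute(demo_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shm_size);
  }
  float *out = nullptr;
  cudaMalloc(&out, 256 * sizeof(float));
  demo_kernel<<<1, 256, shm_size>>>(out);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(out);
  return 0;
}

The attribute is set per function, so for a template such as `block_copy` each `<Spec, IsGemm>` instantiation needs its own call; that is why the guard lives inside the BOOL_SWITCH lambda, next to the launch it protects.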