diff --git a/14-warp-specialization/warp_specialization_pipeline.py b/14-warp-specialization/warp_specialization_pipeline.py index 4887ac4..fd4c824 100644 --- a/14-warp-specialization/warp_specialization_pipeline.py +++ b/14-warp-specialization/warp_specialization_pipeline.py @@ -63,7 +63,7 @@ def relative_error(target: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8): num_failed = 0 -def compare_matrix(kernel_output: torch.Tensor, torch_output: torch.Tensor): +def compare_matrix(kernel_output: torch.Tensor, torch_output: torch.Tensor): kernel_output = kernel_output.float() torch_output = torch_output.float() diff --git a/14-warp-specialization/warp_specialization_pipeline_api.cu b/14-warp-specialization/warp_specialization_pipeline_api.cu index 3c13cdb..d55a11b 100644 --- a/14-warp-specialization/warp_specialization_pipeline_api.cu +++ b/14-warp-specialization/warp_specialization_pipeline_api.cu @@ -306,13 +306,13 @@ __global__ __launch_bounds__(Spec::kThreadNum) void warp_specialization(__grid_c Tensor tCsC_s2r = s2r_thr_copy_c.partition_S(sC); // (CPY, CPY_M, CPY_K) Tensor tCrC_s2r = s2r_thr_copy_c.retile_D(tCrC_load); // (CPY, CPY_M, CPY_K) - if (consumer_tid == 0) { - initialize_barrier(tma_load_c_mbarrier, /* arrival thread count */ 1); - cutlass::arch::fence_view_async_shared(); + if (consumer_tid == 0) { + initialize_barrier(tma_load_c_mbarrier, /* arrival thread count */ 1); + cutlass::arch::fence_view_async_shared(); - copy(tma_C.with(tma_load_c_mbarrier), tCgC, tCsC); - set_barrier_transaction_bytes(tma_load_c_mbarrier, tma_transaction_load_c_bytes); - } + copy(tma_C.with(tma_load_c_mbarrier), tCgC, tCsC); + set_barrier_transaction_bytes(tma_load_c_mbarrier, tma_transaction_load_c_bytes); + } warpgroup_sync(kNumMmaWarpGroups); wait_barrier(tma_load_c_mbarrier, /* phase */ 0);