diff --git a/tt_embeddings_cuda.cu b/tt_embeddings_cuda.cu
index dbbdbe9..7ffe57b 100644
--- a/tt_embeddings_cuda.cu
+++ b/tt_embeddings_cuda.cu
@@ -295,6 +295,7 @@ void init_batch_gemm_backward_cuda(
         a1_ptr,
         b1_ptr,
         c1_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else if (T == 3) {
     init_batch_gemm_backward_3T_kernel<<<
         num_blocks,
@@ -324,6 +325,7 @@ void init_batch_gemm_backward_cuda(
         a1_ptr,
         b1_ptr,
         c1_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else if (T == 4) {
     init_batch_gemm_backward_4T_kernel<<<
         num_blocks,
@@ -356,6 +358,7 @@ void init_batch_gemm_backward_cuda(
         a1_ptr,
         b1_ptr,
         c1_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
 
@@ -574,6 +577,7 @@ std::vector tt_embeddings_backward_cuda(
           &(tableidx.data_ptr()[start_idx]),
           tr_tt_cores[t + 1].packed_accessor32(),
           d_tt_cores[t + 1].packed_accessor32());
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
       cuda_gemm_batched_fp32_fp32(
           CUBLAS_OP_T,
           CUBLAS_OP_N,
@@ -604,6 +608,7 @@ std::vector tt_embeddings_backward_cuda(
             &(tableidx.data_ptr()[start_idx]),
             tr_tt_cores[0].packed_accessor32(),
             d_tt_cores[0].packed_accessor32());
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
       }
     } // for (int32_t t = T - 2; t >=0 ; --t)
   } // for (int32_t start_idx = 0; start_idx < nnz; start_idx += batch_count)
@@ -627,6 +632,7 @@ std::vector tt_embeddings_backward_cuda(
           (*optimizer_state)[t]
               .packed_accessor32(),
           tt_cores[t].packed_accessor32());
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
   } else if (optim == OPTIM_SGD) {
     for (int32_t t = 0; t < T; ++t) {
@@ -645,6 +651,7 @@ std::vector tt_embeddings_backward_cuda(
           learning_rate,
           d_tt_cores[t].packed_accessor32(),
           tt_cores[t].packed_accessor32());
+      C10_CUDA_KERNEL_LAUNCH_CHECK();
     }
   }
 
@@ -876,6 +883,7 @@ void init_batch_gemm_forward_cuda(
         a_ptr,
         b_ptr,
         c_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else if (T == 3) {
     init_batch_gemm_forward_3T_kernel<<<
         num_blocks,
@@ -894,6 +902,7 @@ void init_batch_gemm_forward_cuda(
         a_ptr,
         b_ptr,
         c_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else if (T == 4) {
     init_batch_gemm_forward_4T_kernel<<<
         num_blocks,
@@ -914,6 +923,7 @@ void init_batch_gemm_forward_cuda(
         a_ptr,
         b_ptr,
         c_ptr);
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
 
@@ -1069,6 +1079,7 @@ Tensor tt_embeddings_forward_cuda(
         &(tableidx.data_ptr()[start_idx]),
         tr[T - 2].data_ptr(),
         output.data_ptr());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } // for (int start_idx = 0; start_idx < nnz; start_idx += batch_count)
 
   return output;
@@ -1110,6 +1121,7 @@ void update_cache_state_cuda(Tensor colidx, Tensor hashtbl, Tensor cache_freq) {
       hashtbl.numel(),
       hashtbl.data_ptr(),
       cache_freq.data_ptr());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 __global__ void mark_popular_colidx_kernel(
@@ -1254,6 +1266,7 @@ void prefetch_cached_weights_cuda(
         start_idx,
         tr[T - 2].data_ptr(),
         cache_weight.data_ptr());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   } // for (int start_idx = 0; start_idx < nnz; start_idx += batch_count)
 }
 
@@ -1322,6 +1335,7 @@ void cache_populate_cuda(
       hashtbl.data_ptr(),
       cache_freq.data_ptr(),
       cache_state.data_ptr());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   int32_t batch_count = 200;
   prefetch_cached_weights_cuda(
@@ -1406,6 +1420,7 @@ preprocess_indices_sync_cuda(
       offsets.data_ptr(),
       rowidx.data_ptr(),
       tableidx.data_ptr());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 
   if (warmup || num_tables != 1) {
     // if in warmup phase or num_tables != 1, we do not lookup cache
@@ -1433,6 +1448,7 @@ preprocess_indices_sync_cuda(
         cache_state.data_ptr(),
         is_tt.data_ptr(),
         cache_locations.data_ptr());
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
     size_t temp_storage_bytes = 0;
     AT_CUDA_CHECK(cub::DevicePartition::Flagged(
         nullptr,
@@ -1569,6 +1585,7 @@ void cache_forward_cuda(
       cache_locations.data_ptr(),
       cache_weight.data_ptr(),
       output.data_ptr());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
 __global__ void cache_backward_sgd_kernel(
@@ -1652,7 +1669,7 @@ void cache_backward_sgd_cuda(
       rowidx.data_ptr(),
       learning_rate,
       cache_weight.data_ptr());
-  AT_CUDA_CHECK(cudaGetLastError());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return;
 }
 
@@ -1728,7 +1745,7 @@ Tensor cache_backward_dense_cuda(
       cache_locations.data_ptr(),
       rowidx.data_ptr(),
       grad_cache_weight.data_ptr());
-  AT_CUDA_CHECK(cudaGetLastError());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
   return grad_cache_weight;
 }
 
@@ -1831,5 +1848,5 @@ void cache_backward_rowwise_adagrad_approx_cuda(
       eps,
       cache_optimizer_state.data_ptr(),
       cache_weight.data_ptr());
-  AT_CUDA_CHECK(cudaGetLastError());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
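
For reference, a minimal sketch of the pattern this diff applies throughout the file: every kernel launch is immediately followed by C10_CUDA_KERNEL_LAUNCH_CHECK(), which checks cudaGetLastError() and raises a c10::Error on launch failure, replacing the older AT_CUDA_CHECK(cudaGetLastError()) idiom. The kernel and wrapper names below (scale_kernel, scale_cuda) are hypothetical and not taken from tt_embeddings_cuda.cu.

#include <c10/cuda/CUDAException.h>  // C10_CUDA_KERNEL_LAUNCH_CHECK
#include <c10/cuda/CUDAStream.h>     // c10::cuda::getCurrentCUDAStream

// Hypothetical kernel, used only to illustrate the launch-check pattern.
__global__ void scale_kernel(float* data, float alpha, int32_t n) {
  int32_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] *= alpha;
  }
}

void scale_cuda(float* data, float alpha, int32_t n) {
  constexpr int32_t kThreads = 256;
  const int32_t num_blocks = (n + kThreads - 1) / kThreads;
  scale_kernel<<<num_blocks, kThreads, 0, c10::cuda::getCurrentCUDAStream()>>>(
      data, alpha, n);
  // Surfaces kernel launch errors right at the call site instead of letting
  // them propagate silently to a later CUDA API call.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}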