tt_embeddings_cuda.cu: 23 changes (20 additions & 3 deletions)
@@ -295,6 +295,7 @@ void init_batch_gemm_backward_cuda(
a1_ptr,
b1_ptr,
c1_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (T == 3) {
init_batch_gemm_backward_3T_kernel<<<
num_blocks,
@@ -324,6 +325,7 @@ void init_batch_gemm_backward_cuda(
a1_ptr,
b1_ptr,
c1_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (T == 4) {
init_batch_gemm_backward_4T_kernel<<<
num_blocks,
@@ -356,6 +358,7 @@ void init_batch_gemm_backward_cuda(
a1_ptr,
b1_ptr,
c1_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

@@ -574,6 +577,7 @@ std::vector<Tensor> tt_embeddings_backward_cuda(
&(tableidx.data_ptr<int64_t>()[start_idx]),
tr_tt_cores[t + 1].packed_accessor32<float, 2, RestrictPtrTraits>(),
d_tt_cores[t + 1].packed_accessor32<float, 3, RestrictPtrTraits>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
cuda_gemm_batched_fp32_fp32(
CUBLAS_OP_T,
CUBLAS_OP_N,
@@ -604,6 +608,7 @@ std::vector<Tensor> tt_embeddings_backward_cuda(
&(tableidx.data_ptr<int64_t>()[start_idx]),
tr_tt_cores[0].packed_accessor32<float, 2, RestrictPtrTraits>(),
d_tt_cores[0].packed_accessor32<float, 3, RestrictPtrTraits>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} // for (int32_t t = T - 2; t >=0 ; --t)
} // for (int32_t start_idx = 0; start_idx < nnz; start_idx += batch_count)
@@ -627,6 +632,7 @@ std::vector<Tensor> tt_embeddings_backward_cuda(
(*optimizer_state)[t]
.packed_accessor32<float, 3, RestrictPtrTraits>(),
tt_cores[t].packed_accessor32<float, 3, RestrictPtrTraits>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
} else if (optim == OPTIM_SGD) {
for (int32_t t = 0; t < T; ++t) {
@@ -645,6 +651,7 @@ std::vector<Tensor> tt_embeddings_backward_cuda(
learning_rate,
d_tt_cores[t].packed_accessor32<float, 3, RestrictPtrTraits>(),
tt_cores[t].packed_accessor32<float, 3, RestrictPtrTraits>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

@@ -876,6 +883,7 @@ void init_batch_gemm_forward_cuda(
a_ptr,
b_ptr,
c_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (T == 3) {
init_batch_gemm_forward_3T_kernel<<<
num_blocks,
@@ -894,6 +902,7 @@ void init_batch_gemm_forward_cuda(
a_ptr,
b_ptr,
c_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} else if (T == 4) {
init_batch_gemm_forward_4T_kernel<<<
num_blocks,
@@ -914,6 +923,7 @@ void init_batch_gemm_forward_cuda(
a_ptr,
b_ptr,
c_ptr);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}

@@ -1069,6 +1079,7 @@ Tensor tt_embeddings_forward_cuda(
&(tableidx.data_ptr<int64_t>()[start_idx]),
tr[T - 2].data_ptr<float>(),
output.data_ptr<float>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} // for (int start_idx = 0; start_idx < nnz; start_idx += batch_count)

return output;
@@ -1110,6 +1121,7 @@ void update_cache_state_cuda(Tensor colidx, Tensor hashtbl, Tensor cache_freq) {
hashtbl.numel(),
hashtbl.data_ptr<int64_t>(),
cache_freq.data_ptr<int64_t>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void mark_popular_colidx_kernel(
@@ -1254,6 +1266,7 @@ void prefetch_cached_weights_cuda(
start_idx,
tr[T - 2].data_ptr<float>(),
cache_weight.data_ptr<float>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
} // for (int start_idx = 0; start_idx < nnz; start_idx += batch_count)
}

@@ -1322,6 +1335,7 @@ void cache_populate_cuda(
hashtbl.data_ptr<int64_t>(),
cache_freq.data_ptr<int64_t>(),
cache_state.data_ptr<int32_t>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();

int32_t batch_count = 200;
prefetch_cached_weights_cuda(
@@ -1406,6 +1420,7 @@ preprocess_indices_sync_cuda(
offsets.data_ptr<int64_t>(),
rowidx.data_ptr<int64_t>(),
tableidx.data_ptr<int64_t>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();

if (warmup || num_tables != 1) {
// if in warmup phase or num_tables != 1, we do not lookup cache
@@ -1433,6 +1448,7 @@ preprocess_indices_sync_cuda(
cache_state.data_ptr<int32_t>(),
is_tt.data_ptr<bool>(),
cache_locations.data_ptr<int32_t>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
size_t temp_storage_bytes = 0;
AT_CUDA_CHECK(cub::DevicePartition::Flagged(
nullptr,
@@ -1569,6 +1585,7 @@ void cache_forward_cuda(
cache_locations.data_ptr<int32_t>(),
cache_weight.data_ptr<float>(),
output.data_ptr<float>());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}

__global__ void cache_backward_sgd_kernel(
@@ -1652,7 +1669,7 @@ void cache_backward_sgd_cuda(
rowidx.data_ptr<int64_t>(),
learning_rate,
cache_weight.data_ptr<float>());
-AT_CUDA_CHECK(cudaGetLastError());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
return;
}

@@ -1728,7 +1745,7 @@ Tensor cache_backward_dense_cuda(
cache_locations.data_ptr<int32_t>(),
rowidx.data_ptr<int64_t>(),
grad_cache_weight.data_ptr<float>());
-AT_CUDA_CHECK(cudaGetLastError());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
return grad_cache_weight;
}

@@ -1831,5 +1848,5 @@ void cache_backward_rowwise_adagrad_approx_cuda(
eps,
cache_optimizer_state.data_ptr<float>(),
cache_weight.data_ptr<float>());
-AT_CUDA_CHECK(cudaGetLastError());
+C10_CUDA_KERNEL_LAUNCH_CHECK();
}
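
Note on the pattern applied throughout this diff: every raw kernel launch is now followed immediately by C10_CUDA_KERNEL_LAUNCH_CHECK(), and the three existing AT_CUDA_CHECK(cudaGetLastError()) calls are replaced by the same macro for consistency. C10_CUDA_KERNEL_LAUNCH_CHECK() (defined in c10/cuda/CUDAException.h) checks cudaGetLastError() and throws a c10::Error on failure, so launch-configuration problems such as an invalid grid size surface at the launch site instead of at some later synchronization point. A minimal sketch of the pattern, using a hypothetical scale_kernel that is not part of this file:

#include <c10/cuda/CUDAException.h> // C10_CUDA_KERNEL_LAUNCH_CHECK

// Hypothetical kernel, for illustration only.
__global__ void scale_kernel(float* x, float a, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    x[i] *= a;
  }
}

void scale(float* x, float a, int n, cudaStream_t stream) {
  constexpr int kThreads = 256;
  const int blocks = (n + kThreads - 1) / kThreads;
  scale_kernel<<<blocks, kThreads, 0, stream>>>(x, a, n);
  // Before this PR: AT_CUDA_CHECK(cudaGetLastError());
  // After: the macro below performs the same cudaGetLastError() check but
  // raises a c10::Error annotated with the file and line of this launch site.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}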