From 755d8ad8863d569f68e9772882e216ec592cff89 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Mon, 30 Mar 2026 12:50:53 -0700 Subject: [PATCH 1/2] fix vamana serialization bug --- .../detail/vamana/vamana_serialize.cuh | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh index 887c9eb448..e39a9e7ae4 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh @@ -65,13 +65,20 @@ void serialize_dataset(raft::resources const& res, const auto* strided_dataset = dynamic_cast*>(dataset); if (strided_dataset) { - auto h_dataset = - raft::make_host_matrix(strided_dataset->n_rows(), strided_dataset->dim()); - raft::copy(res, - raft::make_host_vector_view(h_dataset.data_handle(), - strided_dataset->n_rows() * strided_dataset->dim()), - raft::make_device_vector_view(strided_dataset->view().data_handle(), - strided_dataset->n_rows() * strided_dataset->dim())); + auto nrows = strided_dataset->n_rows(); + auto dim = strided_dataset->dim(); + auto stride = strided_dataset->stride(); + auto d_data = strided_dataset->view(); + auto h_dataset = raft::make_host_matrix(nrows, dim); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(h_dataset.data_handle(), + sizeof(T) * dim, + d_data.data_handle(), + sizeof(T) * stride, + sizeof(T) * dim, + nrows, + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + raft::resource::sync_stream(res); to_file(dataset_base_file, h_dataset); } else { RAFT_LOG_DEBUG("dynamic_cast to strided_dataset failed"); @@ -91,6 +98,7 @@ void serialize_dataset(raft::resources const& res, try { auto h_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); raft::copy(res, h_dataset.view(), dataset); + raft::resource::sync_stream(res); to_file(dataset_base_file, h_dataset); } catch (std::bad_alloc& e) { RAFT_LOG_INFO("Failed to serialize dataset"); From 3806ab36715fac194104d828bb6ab85393caf9a1 Mon Sep 17 00:00:00 2001 From: Tarang Jain Date: Tue, 31 Mar 2026 20:06:52 -0700 Subject: [PATCH 2/2] raft::copy_matrix --- .../neighbors/detail/vamana/vamana_serialize.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh index e39a9e7ae4..68ce334cd2 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh @@ -16,6 +16,7 @@ #include #include #include +#include #include "../dataset_serialize.hpp" @@ -70,14 +71,13 @@ void serialize_dataset(raft::resources const& res, auto stride = strided_dataset->stride(); auto d_data = strided_dataset->view(); auto h_dataset = raft::make_host_matrix(nrows, dim); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(h_dataset.data_handle(), - sizeof(T) * dim, - d_data.data_handle(), - sizeof(T) * stride, - sizeof(T) * dim, - nrows, - cudaMemcpyDefault, - raft::resource::get_cuda_stream(res))); + raft::copy_matrix(h_dataset.data_handle(), + dim, + d_data.data_handle(), + stride, + dim, + nrows, + raft::resource::get_cuda_stream(res)); raft::resource::sync_stream(res); to_file(dataset_base_file, h_dataset); } else {