diff --git a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh index 887c9eb448..68ce334cd2 100644 --- a/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh +++ b/cpp/src/neighbors/detail/vamana/vamana_serialize.cuh @@ -16,6 +16,7 @@ #include #include #include +#include #include "../dataset_serialize.hpp" @@ -65,13 +66,19 @@ void serialize_dataset(raft::resources const& res, const auto* strided_dataset = dynamic_cast*>(dataset); if (strided_dataset) { - auto h_dataset = - raft::make_host_matrix(strided_dataset->n_rows(), strided_dataset->dim()); - raft::copy(res, - raft::make_host_vector_view(h_dataset.data_handle(), - strided_dataset->n_rows() * strided_dataset->dim()), - raft::make_device_vector_view(strided_dataset->view().data_handle(), - strided_dataset->n_rows() * strided_dataset->dim())); + auto nrows = strided_dataset->n_rows(); + auto dim = strided_dataset->dim(); + auto stride = strided_dataset->stride(); + auto d_data = strided_dataset->view(); + auto h_dataset = raft::make_host_matrix(nrows, dim); + raft::copy_matrix(h_dataset.data_handle(), + dim, + d_data.data_handle(), + stride, + dim, + nrows, + raft::resource::get_cuda_stream(res)); + raft::resource::sync_stream(res); to_file(dataset_base_file, h_dataset); } else { RAFT_LOG_DEBUG("dynamic_cast to strided_dataset failed"); @@ -91,6 +98,7 @@ void serialize_dataset(raft::resources const& res, try { auto h_dataset = raft::make_host_matrix(dataset.extent(0), dataset.extent(1)); raft::copy(res, h_dataset.view(), dataset); + raft::resource::sync_stream(res); to_file(dataset_base_file, h_dataset); } catch (std::bad_alloc& e) { RAFT_LOG_INFO("Failed to serialize dataset");