From c172b3f82a2f564dabeed479df300c24318aed5b Mon Sep 17 00:00:00 2001 From: Anupam <54245698+aamijar@users.noreply.github.com> Date: Wed, 18 Mar 2026 16:56:17 -0700 Subject: [PATCH] Move PCA and TSVD from cuml to raft (#2952) Required for https://github.com/rapidsai/cuvs/issues/1207 and https://github.com/rapidsai/cuml/pull/7802. This PR moves `pca.cuh`, `tsvd.cuh`, and gtests into raft. Resolves https://github.com/rapidsai/raft/issues/2978 Authors: - Anupam (https://github.com/aamijar) Approvers: - Jinsol Park (https://github.com/jinsolp) - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/raft/pull/2952 --- cpp/include/raft/linalg/detail/pca.cuh | 321 +++++++++++++++ cpp/include/raft/linalg/detail/tsvd.cuh | 521 ++++++++++++++++++++++++ cpp/include/raft/linalg/pca.cuh | 191 +++++++++ cpp/include/raft/linalg/pca_types.hpp | 36 ++ cpp/include/raft/linalg/tsvd.cuh | 169 ++++++++ cpp/tests/CMakeLists.txt | 2 + cpp/tests/linalg/pca.cu | 335 +++++++++++++++ cpp/tests/linalg/tsvd.cu | 228 +++++++++++ docs/source/cpp_api/linalg_solver.rst | 24 ++ 9 files changed, 1827 insertions(+) create mode 100644 cpp/include/raft/linalg/detail/pca.cuh create mode 100644 cpp/include/raft/linalg/detail/tsvd.cuh create mode 100644 cpp/include/raft/linalg/pca.cuh create mode 100644 cpp/include/raft/linalg/pca_types.hpp create mode 100644 cpp/include/raft/linalg/tsvd.cuh create mode 100644 cpp/tests/linalg/pca.cu create mode 100644 cpp/tests/linalg/tsvd.cu diff --git a/cpp/include/raft/linalg/detail/pca.cuh b/cpp/include/raft/linalg/detail/pca.cuh new file mode 100644 index 0000000000..d3d34c5549 --- /dev/null +++ b/cpp/include/raft/linalg/detail/pca.cuh @@ -0,0 +1,321 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::linalg::detail { + +template +void trunc_comp_exp_vars(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view in, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_scalar_view noise_vars, + std::size_t n_rows) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_cols = in.extent(0); + auto n_components = components.extent(0); + + auto len = static_cast(n_cols * n_cols); + rmm::device_uvector components_all(len, stream); + rmm::device_uvector explained_var_all(static_cast(n_cols), stream); + rmm::device_uvector explained_var_ratio_all(static_cast(n_cols), stream); + + detail::cal_eig( + handle, + prms, + in, + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_vector_view(explained_var_all.data(), n_cols)); + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_matrix_view( + components.data_handle(), n_components, n_cols)); + raft::matrix::ratio(handle, + raft::make_device_matrix_view( + explained_var_all.data(), n_cols, idx_t(1)), + raft::make_device_matrix_view( + explained_var_ratio_all.data(), n_cols, idx_t(1))); + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + explained_var_all.data(), n_cols, idx_t(1)), + raft::make_device_matrix_view( + explained_var.data_handle(), n_components, idx_t(1))); + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + explained_var_ratio_all.data(), n_cols, idx_t(1)), + raft::make_device_matrix_view( + explained_var_ratio.data_handle(), n_components, 
idx_t(1))); + + if (static_cast(n_components) < static_cast(n_cols) && + static_cast(n_components) < n_rows) { + raft::stats::mean(noise_vars.data_handle(), + explained_var_all.data() + static_cast(n_components), + std::size_t{1}, + static_cast(n_cols - n_components), + false, + stream); + } else { + raft::matrix::fill( + handle, + raft::make_device_vector_view(noise_vars.data_handle(), idx_t(1)), + math_t{0}); + } +} + +/** + * @brief perform fit operation for PCA. + * @param[in] handle: raft::resources + * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). + * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[out] explained_var: explained variances. Size n_components. + * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. + * @param[out] singular_vals: singular values. Size n_components. + * @param[out] mu: mean of all features. Size n_cols. + * @param[out] noise_vars: noise variance scalar. 
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void pca_fit(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false) +{ + auto stream = resource::get_cuda_stream(handle); + auto cublas_handle = raft::resource::get_cublas_handle(handle); + + auto n_rows = input.extent(0); + auto n_cols = input.extent(1); + + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than two"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + ASSERT(n_components <= n_cols, "n_components cannot exceed n_cols"); + + raft::stats::mean(mu.data_handle(), input.data_handle(), n_cols, n_rows, false, stream); + + auto len = static_cast(n_cols * n_cols); + rmm::device_uvector cov(len, stream); + + raft::stats::cov( + handle, cov.data(), input.data_handle(), mu.data_handle(), n_cols, n_rows, true, true, stream); + + detail::trunc_comp_exp_vars( + handle, + prms, + raft::make_device_matrix_view(cov.data(), n_cols, n_cols), + components, + explained_var, + explained_var_ratio, + noise_vars, + static_cast(n_rows)); + + math_t scalar = (n_rows - 1); + raft::matrix::weighted_sqrt(handle, + raft::make_device_matrix_view( + explained_var.data_handle(), idx_t(1), n_components), + raft::make_device_matrix_view( + singular_vals.data_handle(), idx_t(1), n_components), + raft::make_host_scalar_view(&scalar), + true); + + raft::stats::meanAdd( + input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); + + 
detail::sign_flip_components(handle, input, components, true, flip_signs_based_on_U); +} + +/** + * @brief performs transform operation for PCA. Transforms the data to eigenspace. + * @param[in] handle: raft::resources + * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) + * @param[inout] input: the data to transform. Size n_rows x n_cols (col-major). + * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[in] singular_vals: singular values. Size n_components. + * @param[in] mu: mean of features. Size n_cols. + * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). + */ +template +void pca_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_rows = input.extent(0); + auto n_cols = input.extent(1); + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than two"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + + auto components_len = static_cast(n_cols * n_components); + rmm::device_uvector components_copy{components_len, stream}; + raft::copy(components_copy.data(), components.data_handle(), components_len, stream); + + if (prms.whiten) { + math_t scalar = math_t(sqrt(n_rows - 1)); + raft::linalg::scalarMultiply( + components_copy.data(), components_copy.data(), scalar, components_len, stream); + raft::linalg::binary_div_skip_zero( + handle, + raft::make_device_matrix_view( + components_copy.data(), n_cols, n_components), + raft::make_device_vector_view(singular_vals.data_handle(), + n_components)); + } + + 
raft::stats::meanCenter( + input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); + detail::tsvd_transform(handle, + prms, + input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + trans_input); + raft::stats::meanAdd( + input.data_handle(), input.data_handle(), mu.data_handle(), n_cols, n_rows, stream); +} + +/** + * @brief performs inverse transform operation for PCA. + * @param[in] handle: raft::resources + * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) + * @param[in] trans_input: the transformed data. Size n_rows x n_components (col-major). + * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[in] singular_vals: singular values. Size n_components. + * @param[in] mu: mean of features. Size n_cols. + * @param[out] output: the reconstructed data. Size n_rows x n_cols (col-major). + */ +template +void pca_inverse_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_rows = output.extent(0); + auto n_cols = output.extent(1); + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than two"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + + auto components_len = static_cast(n_cols * n_components); + rmm::device_uvector components_copy{components_len, stream}; + raft::copy(components_copy.data(), components.data_handle(), components_len, stream); + + if (prms.whiten) { + math_t sqrt_n_samples = sqrt(n_rows - 1); + math_t scalar = n_rows - 1 > 0 ? 
math_t(1 / sqrt_n_samples) : 0; + raft::linalg::scalarMultiply( + components_copy.data(), components_copy.data(), scalar, components_len, stream); + raft::linalg::binary_mult_skip_zero( + handle, + raft::make_device_matrix_view( + components_copy.data(), n_cols, n_components), + raft::make_device_vector_view(singular_vals.data_handle(), + n_components)); + } + + detail::tsvd_inverse_transform(handle, + prms, + trans_input, + raft::make_device_matrix_view( + components_copy.data(), n_components, n_cols), + output); + raft::stats::meanAdd( + output.data_handle(), output.data_handle(), mu.data_handle(), n_cols, n_rows, stream); +} + +/** + * @brief perform fit and transform operations for PCA. + * @param[in] handle: raft::resources + * @param[in] prms: PCA parameters (n_components, algorithm, whiten, etc.) + * @param[inout] input: the data is fitted to PCA. Size n_rows x n_cols (col-major). + * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). + * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[out] explained_var: explained variances. Size n_components. + * @param[out] explained_var_ratio: ratio of explained to total variance. Size n_components. + * @param[out] singular_vals: singular values. Size n_components. + * @param[out] mu: mean of all features. Size n_cols. + * @param[out] noise_vars: noise variance scalar. 
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void pca_fit_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false) +{ + detail::pca_fit(handle, + prms, + input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); + detail::pca_transform(handle, prms, input, components, singular_vals, mu, trans_input); +} + +}; // end namespace raft::linalg::detail diff --git a/cpp/include/raft/linalg/detail/tsvd.cuh b/cpp/include/raft/linalg/detail/tsvd.cuh new file mode 100644 index 0000000000..8c6ba4cbfe --- /dev/null +++ b/cpp/include/raft/linalg/detail/tsvd.cuh @@ -0,0 +1,521 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace raft::linalg::detail { + +template +void cal_comp_exp_vars_svd(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view in, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view explained_vars, + raft::device_vector_view explained_var_ratio) +{ + auto stream = resource::get_cuda_stream(handle); + auto cusolver_handle = raft::resource::get_cusolver_dn_handle(handle); + auto cublas_handle = raft::resource::get_cublas_handle(handle); + + auto n_rows = in.extent(0); + auto n_cols = in.extent(1); + auto n_components = components.extent(0); + + auto diff = n_cols - n_components; + math_t ratio = math_t(diff) / math_t(n_cols); + ASSERT(ratio >= math_t(0.2), + "Number of components should be less than at least 80 percent of the " + "number of features"); + + std::size_t p = static_cast(math_t(0.1) * math_t(n_cols)); + ASSERT(p >= 5, "RSVD should be used where the number of columns are at least 50"); + + auto total_random_vecs = static_cast(n_components) + p; + ASSERT(total_random_vecs < static_cast(n_cols), + "RSVD should be used where the number of columns are at least 50"); + + rmm::device_uvector components_temp(static_cast(n_cols * n_components), + stream); + math_t* left_eigvec = nullptr; + raft::linalg::rsvdFixedRank(handle, + in.data_handle(), + n_rows, + n_cols, + singular_vals.data_handle(), + left_eigvec, + components_temp.data(), + n_components, + p, + true, + false, + true, + false, + (math_t)prms.tol, + prms.n_iterations, + stream); + + raft::linalg::transpose( + handle, components_temp.data(), 
components.data_handle(), n_cols, n_components, stream); + + raft::matrix::weighted_power(handle, + raft::make_device_matrix_view( + singular_vals.data_handle(), idx_t(1), n_components), + raft::make_device_matrix_view( + explained_vars.data_handle(), idx_t(1), n_components), + math_t(1)); + raft::matrix::ratio( + handle, explained_vars.data_handle(), explained_var_ratio.data_handle(), n_components, stream); +} + +template +void cal_eig(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view in, + raft::device_matrix_view components, + raft::device_vector_view explained_var) +{ + auto stream = resource::get_cuda_stream(handle); + auto cusolver_handle = raft::resource::get_cusolver_dn_handle(handle); + + auto n_cols = in.extent(0); + + if (prms.algorithm == solver::COV_EIG_JACOBI) { + raft::linalg::eigJacobi(handle, + in.data_handle(), + n_cols, + n_cols, + components.data_handle(), + explained_var.data_handle(), + stream, + (math_t)prms.tol, + prms.n_iterations); + } else { + raft::linalg::eigDC(handle, + in.data_handle(), + n_cols, + n_cols, + components.data_handle(), + explained_var.data_handle(), + stream); + } + raft::resources handle_stream_zero; + raft::resource::set_cuda_stream(handle_stream_zero, stream); + + raft::matrix::col_reverse(handle_stream_zero, + raft::make_device_matrix_view( + components.data_handle(), n_cols, n_cols)); + raft::linalg::transpose(components.data_handle(), n_cols, stream); + + raft::matrix::row_reverse(handle_stream_zero, + raft::make_device_matrix_view( + explained_var.data_handle(), n_cols, idx_t(1))); +} + +/** + * @brief sign flip for PCA and tSVD. Stabilizes the sign of column major eigenvectors. 
+ * @param handle: raft::resources + * @param input: input data [n_samples x n_features] (col-major) + * @param components: components matrix [n_components x n_features] (col-major) + * @param center whether to mean-center input before computing signs + * @param flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void sign_flip_components(raft::resources const& handle, + raft::device_matrix_view input, + raft::device_matrix_view components, + bool center, + bool flip_signs_based_on_U = false) +{ + auto stream = resource::get_cuda_stream(handle); + auto n_samples = input.extent(0); + auto n_features = input.extent(1); + auto n_components = components.extent(0); + + rmm::device_uvector max_vals(static_cast(n_components), stream); + auto components_view = raft::make_device_matrix_view( + components.data_handle(), n_components, n_features); + auto max_vals_view = raft::make_device_vector_view(max_vals.data(), n_components); + + if (flip_signs_based_on_U) { + if (center) { + rmm::device_uvector col_means(static_cast(n_features), stream); + raft::stats::mean( + col_means.data(), input.data_handle(), n_features, n_samples, stream); + raft::stats::meanCenter( + input.data_handle(), input.data_handle(), col_means.data(), n_features, n_samples, stream); + } + rmm::device_uvector US(static_cast(n_samples * n_components), stream); + raft::linalg::gemm(handle, + input.data_handle(), + n_samples, + n_features, + components.data_handle(), + US.data(), + n_samples, + n_components, + CUBLAS_OP_N, + CUBLAS_OP_T, + math_t(1), + math_t(0), + stream); + raft::linalg::reduce( + max_vals.data(), + US.data(), + n_components, + n_samples, + math_t(0), + stream, + false, + raft::identity_op(), + [] __device__(math_t a, math_t b) { + math_t abs_a = a >= math_t(0) ? a : -a; + math_t abs_b = b >= math_t(0) ? b : -b; + return abs_a >= abs_b ? 
a : b; + }, + raft::identity_op()); + } else { + raft::linalg::reduce( + max_vals.data(), + components.data_handle(), + n_features, + n_components, + math_t(0), + stream, + false, + raft::identity_op(), + [] __device__(math_t a, math_t b) { + math_t abs_a = a >= math_t(0) ? a : -a; + math_t abs_b = b >= math_t(0) ? b : -b; + return abs_a >= abs_b ? a : b; + }, + raft::identity_op()); + } + + raft::linalg::map_offset( + handle, + components_view, + [components_view, max_vals_view, n_components, n_features] __device__(auto idx) { + auto row = idx % n_components; + auto column = idx / n_components; + return (max_vals_view(row) < math_t(0)) ? (-components_view(row, column)) + : components_view(row, column); + }); +} + +/** + * @brief sign flip for PCA and tSVD. Stabilizes the sign of column major eigenvectors. + * @param handle: raft::resources + * @param input: input matrix [n_rows x n_cols] (col-major). Modified in place. + * @param components: components matrix [n_rows x n_cols_comp] (col-major). Modified in place. 
+ */
+template
+void sign_flip(raft::resources const& handle,
+               raft::device_matrix_view input,
+               raft::device_matrix_view components)
+{
+  auto stream      = resource::get_cuda_stream(handle);
+  auto n_rows      = input.extent(0);
+  auto n_cols      = input.extent(1);
+  auto n_cols_comp = components.extent(1);
+
+  auto* input_ptr      = input.data_handle();
+  auto* components_ptr = components.data_handle();
+  auto counting        = thrust::make_counting_iterator(0);
+  auto m               = n_rows;
+
+  // One work item per column of `input`: find the entry with the largest magnitude and, if it is
+  // negative, negate that whole column together with the matching strided entries of `components`.
+  thrust::for_each(
+    rmm::exec_policy(stream), counting, counting + n_cols, [=] __device__(idx_t idx) {
+      auto d_i = idx * m;
+      auto end = d_i + m;
+
+      math_t max = 0.0;
+      // Start at this column's first element. Starting at 0 made a column with no non-zero
+      // entry test input_ptr[0], which belongs to column 0 and may be rewritten concurrently
+      // by the work item that owns that column.
+      idx_t max_index = d_i;
+      for (auto i = d_i; i < end; i++) {
+        math_t val = input_ptr[i];
+        if (val < 0.0) { val = -val; }
+        if (val > max) {
+          max       = val;
+          max_index = i;
+        }
+      }
+
+      if (input_ptr[max_index] < 0.0) {
+        for (auto i = d_i; i < end; i++) {
+          input_ptr[i] = -input_ptr[i];
+        }
+
+        // Entries idx, idx + n_cols, idx + 2 * n_cols, ... of `components` are negated.
+        auto len = n_cols * n_cols_comp;
+        for (auto i = idx; i < len; i = i + n_cols) {
+          components_ptr[i] = -components_ptr[i];
+        }
+      }
+    });
+}
+
+/**
+ * @brief perform fit operation for the tsvd.
+ * @param[in] handle: raft::resources
+ * @param[in] prms: data structure that includes all the parameters from input size to algorithm.
+ * @param[in] input: the data is fitted to tSVD. Size n_rows x n_cols (col-major).
+ * @param[out] components: the principal components. Size n_components x n_cols (col-major).
+ * @param[out] singular_vals: singular values of the data. Size n_components.
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void tsvd_fit(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + bool flip_signs_based_on_U = false) +{ + auto stream = resource::get_cuda_stream(handle); + auto cublas_handle = raft::resource::get_cublas_handle(handle); + + auto n_rows = input.extent(0); + auto n_cols = input.extent(1); + + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than two"); + ASSERT(n_rows > 1, "Parameter n_rows: number of rows cannot be less than two"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + ASSERT(n_components <= n_cols, "n_components cannot exceed n_cols"); + + auto len = static_cast(n_cols * n_cols); + rmm::device_uvector input_cross_mult(len, stream); + + math_t alpha = math_t(1); + math_t beta = math_t(0); + raft::linalg::gemm(handle, + input.data_handle(), + n_rows, + n_cols, + input.data_handle(), + input_cross_mult.data(), + n_cols, + n_cols, + CUBLAS_OP_T, + CUBLAS_OP_N, + alpha, + beta, + stream); + + rmm::device_uvector components_all(len, stream); + rmm::device_uvector explained_var_all(static_cast(n_cols), stream); + + detail::cal_eig(handle, + prms, + raft::make_device_matrix_view( + input_cross_mult.data(), n_cols, n_cols), + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_vector_view(explained_var_all.data(), n_cols)); + + raft::matrix::trunc_zero_origin( + handle, + raft::make_device_matrix_view( + components_all.data(), n_cols, n_cols), + raft::make_device_matrix_view( + components.data_handle(), n_components, n_cols)); + + math_t scalar = math_t(1); + raft::matrix::weighted_sqrt(handle, + raft::make_device_matrix_view( + explained_var_all.data(), idx_t(1), 
n_components), + raft::make_device_matrix_view( + singular_vals.data_handle(), idx_t(1), n_components), + raft::make_host_scalar_view(&scalar)); + + detail::sign_flip_components(handle, + input, + raft::make_device_matrix_view( + components.data_handle(), n_components, n_cols), + false, + flip_signs_based_on_U); +} + +/** + * @brief performs transform operation for the tsvd. Transforms the data to eigenspace. + * @param[in] handle raft::resources + * @param[in] prms: data structure that includes all the parameters from input size to algorithm. + * @param[in] input: the data to transform. Size n_rows x n_cols (col-major). + * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[out] trans_input: transformed output. Size n_rows x n_components (col-major). + */ +template +void tsvd_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_matrix_view trans_input) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_rows = input.extent(0); + auto n_cols = input.extent(1); + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than two"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + + math_t alpha = math_t(1); + math_t beta = math_t(0); + raft::linalg::gemm(handle, + input.data_handle(), + n_rows, + n_cols, + components.data_handle(), + trans_input.data_handle(), + n_rows, + n_components, + CUBLAS_OP_N, + CUBLAS_OP_T, + alpha, + beta, + stream); +} + +/** + * @brief performs inverse transform operation for the tsvd. + * @param[in] handle raft::resources + * @param[in] prms: data structure that includes all the parameters from input size to algorithm. + * @param[in] trans_input: the transformed data. 
Size n_rows x n_components (col-major). + * @param[in] components: principal components. Size n_components x n_cols (col-major). + * @param[out] output: reconstructed output. Size n_rows x n_cols (col-major). + */ +template +void tsvd_inverse_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_matrix_view output) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_rows = output.extent(0); + auto n_cols = output.extent(1); + auto n_components = components.extent(0); + + ASSERT(n_cols > 1, "Parameter n_cols: number of columns cannot be less than one"); + ASSERT(n_rows > 0, "Parameter n_rows: number of rows cannot be less than one"); + ASSERT(n_components > 0, "Parameter n_components: number of components cannot be less than one"); + + math_t alpha = math_t(1); + math_t beta = math_t(0); + + raft::linalg::gemm(handle, + trans_input.data_handle(), + n_rows, + n_components, + components.data_handle(), + output.data_handle(), + n_rows, + n_cols, + CUBLAS_OP_N, + CUBLAS_OP_N, + alpha, + beta, + stream); +} + +/** + * @brief performs fit and transform operations for the tsvd. + * @param[in] handle: raft::resources + * @param[in] prms: data structure that includes all the parameters from input size to algorithm. + * @param[in] input: the data is fitted to tSVD. Size n_rows x n_cols (col-major). + * @param[out] trans_input: the transformed data. Size n_rows x n_components (col-major). + * @param[out] components: the principal components. Size n_components x n_cols (col-major). + * @param[out] explained_var: explained variances. Size n_components. + * @param[out] explained_var_ratio: ratio of explained variance to total. Size n_components. + * @param[out] singular_vals: singular values of the data. Size n_components. 
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void tsvd_fit_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + bool flip_signs_based_on_U = false) +{ + auto stream = resource::get_cuda_stream(handle); + + auto n_rows = input.extent(0); + auto n_cols = input.extent(1); + auto n_components = components.extent(0); + + detail::tsvd_fit(handle, prms, input, components, singular_vals, flip_signs_based_on_U); + detail::tsvd_transform(handle, prms, input, components, trans_input); + + rmm::device_uvector mu_trans(static_cast(n_components), stream); + raft::stats::mean( + mu_trans.data(), trans_input.data_handle(), n_components, n_rows, false, stream); + raft::stats::vars(explained_var.data_handle(), + trans_input.data_handle(), + mu_trans.data(), + n_components, + n_rows, + false, + stream); + + rmm::device_uvector mu(static_cast(n_cols), stream); + rmm::device_uvector vars(static_cast(n_cols), stream); + + raft::stats::mean(mu.data(), input.data_handle(), n_cols, n_rows, false, stream); + raft::stats::vars( + vars.data(), input.data_handle(), mu.data(), n_cols, n_rows, false, stream); + + rmm::device_scalar total_vars(stream); + raft::stats::sum( + total_vars.data(), vars.data(), std::size_t(1), static_cast(n_cols), stream); + + math_t total_vars_h; + raft::update_host(&total_vars_h, total_vars.data(), 1, stream); + raft::resource::sync_stream(handle, stream); + math_t scalar = math_t(1) / total_vars_h; + + raft::linalg::scalarMultiply( + explained_var_ratio.data_handle(), explained_var.data_handle(), scalar, n_components, stream); +} + +}; // end namespace raft::linalg::detail diff --git a/cpp/include/raft/linalg/pca.cuh 
b/cpp/include/raft/linalg/pca.cuh new file mode 100644 index 0000000000..218d12cf77 --- /dev/null +++ b/cpp/include/raft/linalg/pca.cuh @@ -0,0 +1,191 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "detail/pca.cuh" + +#include + +namespace raft::linalg { + +/** + * @defgroup pca PCA operations + * @{ + */ + +/** + * @brief perform fit operation for PCA. Generates eigenvectors, explained vars, singular vals, etc. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle: raft::resources + * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * temporarily during computation. + * @param[out] components the principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size + * n_components. + * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size + * n_components. + * @param[out] singular_vals singular values of the data. Size n_components. + * @param[out] mu mean of all the features (all the columns in the data). Size n_cols. + * @param[out] noise_vars variance of the noise. Scalar. 
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void pca_fit(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false) +{ + detail::pca_fit(handle, + prms, + input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); +} + +/** + * @brief perform fit and transform operations for PCA. Generates transformed data, + * eigenvectors, explained vars, singular vals, etc. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms PCA parameters (n_components, algorithm, whiten, etc.) + * @param[inout] input the data is fitted to PCA. Size n_rows x n_cols (col-major). Modified + * temporarily during computation. + * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[out] components the principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size + * n_components. + * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size + * n_components. + * @param[out] singular_vals singular values of the data. Size n_components. + * @param[out] mu mean of all the features (all the columns in the data). Size n_cols. + * @param[out] noise_vars variance of the noise. Scalar. 
+ * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void pca_fit_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_scalar_view noise_vars, + bool flip_signs_based_on_U = false) +{ + detail::pca_fit_transform(handle, + prms, + input, + trans_input, + components, + explained_var, + explained_var_ratio, + singular_vals, + mu, + noise_vars, + flip_signs_based_on_U); +} + +/** + * @brief performs inverse transform operation for PCA. Transforms the transformed data back to + * original data. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms PCA parameters (algorithm, tol, n_iterations, copy, whiten) + * @param[in] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[in] components the principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[in] singular_vals singular values of the data. Size n_components. + * @param[in] mu mean of features (every column). Size n_cols. + * @param[out] output the reconstructed data. Size n_rows x n_cols (col-major). + */ +template +void pca_inverse_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view output) +{ + detail::pca_inverse_transform(handle, prms, trans_input, components, singular_vals, mu, output); +} + +/** + * @brief performs transform operation for PCA. Transforms the data to eigenspace.
+ * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms PCA parameters (algorithm, tol, n_iterations, copy, whiten) + * @param[inout] input the data to be transformed. Size n_rows x n_cols (col-major). Modified + * temporarily during computation (mean-centered then restored). + * @param[in] components principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[in] singular_vals singular values of the data. Size n_components. + * @param[in] mu mean value of the input data. Size n_cols. + * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). + */ +template +void pca_transform(raft::resources const& handle, + const paramsPCA& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + raft::device_vector_view mu, + raft::device_matrix_view trans_input) +{ + detail::pca_transform(handle, prms, input, components, singular_vals, mu, trans_input); +} + +/** + * @brief Compute truncated components, explained variances, explained variance ratios, + * and noise variance from a covariance matrix. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[inout] in covariance matrix [n_cols x n_cols] (col-major). Overwritten.
+ * @param[out] components truncated eigenvectors [n_components x n_cols] (col-major) + * @param[out] explained_var explained variances [n_components] + * @param[out] explained_var_ratio explained variance ratios [n_components] + * @param[out] noise_vars noise variance scalar + * @param[in] n_rows number of rows in the original data (needed for noise variance computation) + */ +template +void trunc_comp_exp_vars(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view in, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_scalar_view noise_vars, + std::size_t n_rows) +{ + detail::trunc_comp_exp_vars( + handle, prms, in, components, explained_var, explained_var_ratio, noise_vars, n_rows); +} + +/** @} */ // end group pca + +}; // end namespace raft::linalg diff --git a/cpp/include/raft/linalg/pca_types.hpp b/cpp/include/raft/linalg/pca_types.hpp new file mode 100644 index 0000000000..cca3b09155 --- /dev/null +++ b/cpp/include/raft/linalg/pca_types.hpp @@ -0,0 +1,36 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace raft::linalg { + +/** + * @brief Solver algorithm for PCA/TSVD eigen decomposition. + * + * - COV_EIG_DQ: covariance + divide-and-conquer eigen decomposition for symmetric matrices + * - COV_EIG_JACOBI: covariance + Jacobi eigen decomposition for symmetric matrices + */ +enum class solver : int { + COV_EIG_DQ, + COV_EIG_JACOBI, +}; + +/** @brief Parameters for TSVD (and base for PCA). */ +struct paramsTSVD { + float tol = 0.0; + uint64_t n_iterations = 15; + solver algorithm = solver::COV_EIG_DQ; +}; + +/** @brief Parameters for PCA (extends TSVD with whitening / copy controls).
*/ +struct paramsPCA : paramsTSVD { + bool copy = true; + bool whiten = false; +}; + +}; // end namespace raft::linalg diff --git a/cpp/include/raft/linalg/tsvd.cuh b/cpp/include/raft/linalg/tsvd.cuh new file mode 100644 index 0000000000..8f946d5b39 --- /dev/null +++ b/cpp/include/raft/linalg/tsvd.cuh @@ -0,0 +1,169 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "detail/tsvd.cuh" + +#include + +namespace raft::linalg { + +/** + * @defgroup tsvd Truncated SVD operations + * @{ + */ + +/** + * @brief perform fit operation for tSVD. Generates eigenvectors, singular vals, etc. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[inout] input the data is fitted to tSVD. Size n_rows x n_cols (col-major). + * @param[out] components the principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[out] singular_vals singular values of the data. Size n_components. + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void tsvd_fit(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_vector_view singular_vals, + bool flip_signs_based_on_U = false) +{ + detail::tsvd_fit(handle, prms, input, components, singular_vals, flip_signs_based_on_U); +} + +/** + * @brief performs fit and transform operations for tSVD. Generates transformed data, + * eigenvectors, explained vars, singular vals, etc.
+ * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[inout] input the data is fitted to tSVD. Size n_rows x n_cols (col-major). + * @param[out] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[out] components the principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[out] explained_var explained variances (eigenvalues) of the principal components. Size + * n_components. + * @param[out] explained_var_ratio the ratio of the explained variance and total variance. Size + * n_components. + * @param[out] singular_vals singular values of the data. Size n_components. + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void tsvd_fit_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_vector_view explained_var, + raft::device_vector_view explained_var_ratio, + raft::device_vector_view singular_vals, + bool flip_signs_based_on_U = false) +{ + detail::tsvd_fit_transform(handle, + prms, + input, + trans_input, + components, + explained_var, + explained_var_ratio, + singular_vals, + flip_signs_based_on_U); +} + +/** + * @brief performs transform operation for tSVD. Transforms the data to eigenspace. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[in] input the data to be transformed. Size n_rows x n_cols (col-major).
+ * @param[in] components principal components of the input data. Size n_components x n_cols + * (col-major). + * @param[out] trans_input output that is transformed version of input. Size n_rows x n_components + * (col-major). + */ +template +void tsvd_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view input, + raft::device_matrix_view components, + raft::device_matrix_view trans_input) +{ + detail::tsvd_transform(handle, prms, input, components, trans_input); +} + +/** + * @brief performs inverse transform operation for tSVD. Transforms the transformed data back to + * original data. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[in] trans_input the transformed data. Size n_rows x n_components (col-major). + * @param[in] components the principal components. Size n_components x n_cols + * (col-major). + * @param[out] output the reconstructed data. Size n_rows x n_cols (col-major). + */ +template +void tsvd_inverse_transform(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view trans_input, + raft::device_matrix_view components, + raft::device_matrix_view output) +{ + detail::tsvd_inverse_transform(handle, prms, trans_input, components, output); +} + +/** + * @brief Eigendecomposition helper for tSVD/PCA. Computes eigenvectors and eigenvalues + * of a symmetric matrix using either divide-and-conquer or Jacobi method. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] prms tSVD parameters (controls algorithm, tolerance, iterations) + * @param[inout] in symmetric input matrix [n_cols x n_cols] (col-major). Overwritten.
+ * @param[out] components eigenvectors [n_cols x n_cols] (col-major) + * @param[out] explained_var eigenvalues [n_cols] + */ +template +void cal_eig(raft::resources const& handle, + const paramsTSVD& prms, + raft::device_matrix_view in, + raft::device_matrix_view components, + raft::device_vector_view explained_var) +{ + detail::cal_eig(handle, prms, in, components, explained_var); +} + +/** + * @brief Sign flip for PCA and tSVD. Stabilizes the sign of column-major eigenvectors. + * @tparam math_t data-type upon which the math operation will be performed + * @tparam idx_t integer type used for indexing + * @param[in] handle raft::resources + * @param[in] input input data matrix [n_samples x n_features] (col-major) + * @param[inout] components components matrix [n_components x n_features] (col-major) + * @param[in] center whether to mean-center input before computing signs + * @param[in] flip_signs_based_on_U whether to determine signs by U (true) or V.T (false) + */ +template +void sign_flip_components(raft::resources const& handle, + raft::device_matrix_view input, + raft::device_matrix_view components, + bool center, + bool flip_signs_based_on_U = false) +{ + detail::sign_flip_components(handle, input, components, center, flip_signs_based_on_U); +} + +/** @} */ // end group tsvd + +}; // end namespace raft::linalg diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 5d88a47267..47ac6fc286 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -163,6 +163,7 @@ if(BUILD_TESTS) linalg/multiply.cu linalg/norm.cu linalg/normalize.cu + linalg/pca.cu linalg/power.cu linalg/randomized_svd.cu linalg/reduce.cu @@ -175,6 +176,7 @@ if(BUILD_TESTS) linalg/svd.cu linalg/ternary_op.cu linalg/transpose.cu + linalg/tsvd.cu linalg/unary_op.cu GPUS 1 diff --git a/cpp/tests/linalg/pca.cu b/cpp/tests/linalg/pca.cu new file mode 100644 index 0000000000..a1d80953e2 --- /dev/null +++ b/cpp/tests/linalg/pca.cu @@ -0,0 +1,335 @@ +/* + * 
SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::linalg { + +template +struct PcaInputs { + T tolerance; + int len; + int n_row; + int n_col; + int len2; + int n_row2; + int n_col2; + unsigned long long int seed; + int algo; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const PcaInputs& dims) +{ + return os; +} + +template +class PcaTest : public ::testing::TestWithParam> { + public: + PcaTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + explained_vars(params.n_col, stream), + explained_vars_ref(params.n_col, stream), + components(params.n_col * params.n_col, stream), + components_ref(params.n_col * params.n_col, stream), + trans_data(params.len, stream), + trans_data_ref(params.len, stream), + data(params.len, stream), + data_back(params.len, stream), + data2(params.len2, stream), + data2_back(params.len2, stream) + { + basicTest(); + advancedTest(); + } + + protected: + void basicTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = params.len; + + std::vector data_h = {1.0, 2.0, 5.0, 4.0, 2.0, 1.0}; + data_h.resize(len); + raft::update_device(data.data(), data_h.data(), len, stream); + + std::vector trans_data_ref_h = {-2.3231, -0.3517, 2.6748, 0.3979, -0.6571, 0.2592}; + trans_data_ref_h.resize(len); + raft::update_device(trans_data_ref.data(), trans_data_ref_h.data(), len, stream); + + int n_components = params.n_col; + int len_comp = params.n_col * params.n_col; + rmm::device_uvector explained_var_ratio(params.n_col, stream); + rmm::device_uvector singular_vals(params.n_col, stream); + rmm::device_uvector mean(params.n_col, stream); + rmm::device_uvector noise_vars(1, stream); + + std::vector components_ref_h = {0.8163, 0.5776, -0.5776, 0.8163}; + 
components_ref_h.resize(len_comp); + std::vector explained_vars_ref_h = {6.338, 0.3287}; + explained_vars_ref_h.resize(params.n_col); + + raft::update_device(components_ref.data(), components_ref_h.data(), len_comp, stream); + raft::update_device( + explained_vars_ref.data(), explained_vars_ref_h.data(), params.n_col, stream); + + std::size_t n_rows = params.n_row; + std::size_t n_cols = params.n_col; + + paramsPCA prms; + prms.whiten = false; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else + prms.algorithm = solver::COV_EIG_JACOBI; + + auto input_view = + raft::make_device_matrix_view(data.data(), n_rows, n_cols); + auto components_view = raft::make_device_matrix_view( + components.data(), n_components, n_cols); + auto explained_var_view = + raft::make_device_vector_view(explained_vars.data(), n_components); + auto explained_var_ratio_view = + raft::make_device_vector_view(explained_var_ratio.data(), n_components); + auto singular_vals_view = + raft::make_device_vector_view(singular_vals.data(), n_components); + auto mu_view = raft::make_device_vector_view(mean.data(), n_cols); + auto noise_vars_view = raft::make_device_scalar_view(noise_vars.data()); + + pca_fit(handle, + prms, + input_view, + components_view, + explained_var_view, + explained_var_ratio_view, + singular_vals_view, + mu_view, + noise_vars_view); + + auto trans_data_view = raft::make_device_matrix_view( + trans_data.data(), n_rows, n_components); + + pca_transform( + handle, prms, input_view, components_view, singular_vals_view, mu_view, trans_data_view); + + auto data_back_view = raft::make_device_matrix_view( + data_back.data(), n_rows, n_cols); + + pca_inverse_transform( + handle, prms, trans_data_view, components_view, singular_vals_view, mu_view, data_back_view); + } + + void advancedTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = params.len2; + + std::size_t n_rows = params.n_row2; + std::size_t n_cols = params.n_col2; + std::size_t 
n_components = params.n_col2; + + paramsPCA prms; + prms.whiten = false; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else if (params.algo == 1) + prms.algorithm = solver::COV_EIG_JACOBI; + + r.uniform(data2.data(), len, T(-1.0), T(1.0), stream); + rmm::device_uvector data2_trans(n_rows * n_components, stream); + + int len_comp = params.n_col2 * n_components; + rmm::device_uvector components2(len_comp, stream); + rmm::device_uvector explained_vars2(n_components, stream); + rmm::device_uvector explained_var_ratio2(n_components, stream); + rmm::device_uvector singular_vals2(n_components, stream); + rmm::device_uvector mean2(n_cols, stream); + rmm::device_uvector noise_vars2(1, stream); + + auto input_view = + raft::make_device_matrix_view(data2.data(), n_rows, n_cols); + auto trans_view = raft::make_device_matrix_view( + data2_trans.data(), n_rows, n_components); + auto comp_view = raft::make_device_matrix_view( + components2.data(), n_components, n_cols); + auto ev_view = + raft::make_device_vector_view(explained_vars2.data(), n_components); + auto evr_view = + raft::make_device_vector_view(explained_var_ratio2.data(), n_components); + auto sv_view = + raft::make_device_vector_view(singular_vals2.data(), n_components); + auto mu_view = raft::make_device_vector_view(mean2.data(), n_cols); + auto noise_view = raft::make_device_scalar_view(noise_vars2.data()); + + pca_fit_transform(handle, + prms, + input_view, + trans_view, + comp_view, + ev_view, + evr_view, + sv_view, + mu_view, + noise_view); + + auto data2_back_view = raft::make_device_matrix_view( + data2_back.data(), n_rows, n_cols); + + pca_inverse_transform(handle, prms, trans_view, comp_view, sv_view, mu_view, data2_back_view); + } + + protected: + raft::device_resources handle; + cudaStream_t stream; + + PcaInputs params; + + rmm::device_uvector explained_vars, explained_vars_ref, components, components_ref, trans_data, + trans_data_ref, data, data_back, data2, data2_back; +}; + +const 
std::vector> inputsf2 = { + {0.01f, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, + {0.01f, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; + +const std::vector> inputsd2 = { + {0.01, 3 * 2, 3, 2, 1024 * 128, 1024, 128, 1234ULL, 0}, + {0.01, 3 * 2, 3, 2, 256 * 32, 256, 32, 1234ULL, 1}}; + +typedef PcaTest PcaTestValF; +TEST_P(PcaTestValF, Result) +{ + ASSERT_TRUE(devArrMatch(explained_vars.data(), + explained_vars_ref.data(), + params.n_col, + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestValD; +TEST_P(PcaTestValD, Result) +{ + ASSERT_TRUE(devArrMatch(explained_vars.data(), + explained_vars_ref.data(), + params.n_col, + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestLeftVecF; +TEST_P(PcaTestLeftVecF, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestLeftVecD; +TEST_P(PcaTestLeftVecD, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestTransDataF; +TEST_P(PcaTestTransDataF, Result) +{ + ASSERT_TRUE(devArrMatch(trans_data.data(), + trans_data_ref.data(), + (params.n_row * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestTransDataD; +TEST_P(PcaTestTransDataD, Result) +{ + ASSERT_TRUE(devArrMatch(trans_data.data(), + trans_data_ref.data(), + (params.n_row * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecSmallF; +TEST_P(PcaTestDataVecSmallF, Result) +{ + ASSERT_TRUE(devArrMatch(data.data(), + data_back.data(), + (params.n_col * params.n_col), + 
raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecSmallD; +TEST_P(PcaTestDataVecSmallD, Result) +{ + ASSERT_TRUE(devArrMatch(data.data(), + data_back.data(), + (params.n_col * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecF; +TEST_P(PcaTestDataVecF, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef PcaTest PcaTestDataVecD; +TEST_P(PcaTestDataVecD, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestValD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestLeftVecD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecSmallD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestTransDataD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(PcaTests, PcaTestDataVecD, ::testing::ValuesIn(inputsd2)); + +} // end namespace raft::linalg diff --git a/cpp/tests/linalg/tsvd.cu b/cpp/tests/linalg/tsvd.cu new file mode 100644 index 0000000000..9fca1423bf --- /dev/null +++ b/cpp/tests/linalg/tsvd.cu @@ -0,0 +1,228 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2018-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::linalg { + +template +struct TsvdInputs { + T tolerance; + int n_row; + int n_col; + int n_row2; + int n_col2; + float redundancy; + unsigned long long int seed; + int algo; +}; + +template +::std::ostream& operator<<(::std::ostream& os, const TsvdInputs& dims) +{ + return os; +} + +template +class TsvdTest : public ::testing::TestWithParam> { + public: + TsvdTest() + : params(::testing::TestWithParam>::GetParam()), + stream(resource::get_cuda_stream(handle)), + components(0, stream), + components_ref(0, stream), + data2(0, stream), + data2_back(0, stream) + { + basicTest(); + advancedTest(); + } + + protected: + void basicTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = params.n_row * params.n_col; + + rmm::device_uvector data(len, stream); + + std::vector data_h = {1.0, 2.0, 4.0, 2.0, 4.0, 5.0, 5.0, 4.0, 2.0, 1.0, 6.0, 4.0}; + data_h.resize(len); + raft::update_device(data.data(), data_h.data(), len, stream); + + int len_comp = params.n_col * params.n_col; + components.resize(len_comp, stream); + rmm::device_uvector singular_vals(params.n_col, stream); + + std::vector components_ref_h = { + 0.3951, 0.1532, 0.9058, 0.7111, -0.6752, -0.1959, 0.5816, 0.7215, -0.3757}; + components_ref_h.resize(len_comp); + + components_ref.resize(len_comp, stream); + raft::update_device(components_ref.data(), components_ref_h.data(), len_comp, stream); + + std::size_t n_rows = params.n_row; + std::size_t n_cols = params.n_col; + std::size_t n_components = params.n_col; + + paramsTSVD prms; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else + prms.algorithm = solver::COV_EIG_JACOBI; + + auto input_view = + raft::make_device_matrix_view(data.data(), n_rows, n_cols); + auto components_view = raft::make_device_matrix_view( + components.data(), n_components, 
n_cols); + auto singular_vals_view = + raft::make_device_vector_view(singular_vals.data(), n_components); + + tsvd_fit(handle, prms, input_view, components_view, singular_vals_view); + } + + void advancedTest() + { + raft::random::Rng r(params.seed, raft::random::GenPC); + int len = params.n_row2 * params.n_col2; + + std::size_t n_rows = params.n_row2; + std::size_t n_cols = params.n_col2; + std::size_t n_components = params.n_col2; + + paramsTSVD prms; + if (params.algo == 0) + prms.algorithm = solver::COV_EIG_DQ; + else if (params.algo == 1) + prms.algorithm = solver::COV_EIG_JACOBI; + else + n_components = params.n_col2 - 15; + + data2.resize(len, stream); + int redundant_cols = int(params.redundancy * params.n_col2); + int redundant_len = params.n_row2 * redundant_cols; + + int informative_cols = params.n_col2 - redundant_cols; + int informative_len = params.n_row2 * informative_cols; + + r.uniform(data2.data(), informative_len, T(-1.0), T(1.0), stream); + RAFT_CUDA_TRY(cudaMemcpyAsync(data2.data() + informative_len, + data2.data(), + redundant_len * sizeof(T), + cudaMemcpyDeviceToDevice, + stream)); + rmm::device_uvector data2_trans(n_rows * n_components, stream); + + int len_comp = params.n_col2 * n_components; + rmm::device_uvector components2(len_comp, stream); + rmm::device_uvector explained_vars2(n_components, stream); + rmm::device_uvector explained_var_ratio2(n_components, stream); + rmm::device_uvector singular_vals2(n_components, stream); + + auto input_view = + raft::make_device_matrix_view(data2.data(), n_rows, n_cols); + auto trans_view = raft::make_device_matrix_view( + data2_trans.data(), n_rows, n_components); + auto comp_view = raft::make_device_matrix_view( + components2.data(), n_components, n_cols); + auto ev_view = + raft::make_device_vector_view(explained_vars2.data(), n_components); + auto evr_view = + raft::make_device_vector_view(explained_var_ratio2.data(), n_components); + auto sv_view = + 
raft::make_device_vector_view(singular_vals2.data(), n_components); + + tsvd_fit_transform(handle, prms, input_view, trans_view, comp_view, ev_view, evr_view, sv_view); + + data2_back.resize(len, stream); + + auto trans_in_view = raft::make_device_matrix_view( + data2_trans.data(), n_rows, n_components); + auto comp_in_view = raft::make_device_matrix_view( + components2.data(), n_components, n_cols); + auto output_view = raft::make_device_matrix_view( + data2_back.data(), n_rows, n_cols); + + tsvd_inverse_transform(handle, prms, trans_in_view, comp_in_view, output_view); + } + + protected: + raft::device_resources handle; + cudaStream_t stream; + + TsvdInputs params; + rmm::device_uvector components, components_ref, data2, data2_back; +}; + +const std::vector> inputsf2 = {{0.01f, 4, 3, 1024, 128, 0.25f, 1234ULL, 0}, + {0.01f, 4, 3, 1024, 128, 0.25f, 1234ULL, 1}, + {0.04f, 4, 3, 512, 64, 0.25f, 1234ULL, 2}, + {0.04f, 4, 3, 512, 64, 0.25f, 1234ULL, 2}}; + +const std::vector> inputsd2 = {{0.01, 4, 3, 1024, 128, 0.25f, 1234ULL, 0}, + {0.01, 4, 3, 1024, 128, 0.25f, 1234ULL, 1}, + {0.05, 4, 3, 512, 64, 0.25f, 1234ULL, 2}, + {0.05, 4, 3, 512, 64, 0.25f, 1234ULL, 2}}; + +typedef TsvdTest TsvdTestLeftVecF; +TEST_P(TsvdTestLeftVecF, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef TsvdTest TsvdTestLeftVecD; +TEST_P(TsvdTestLeftVecD, Result) +{ + ASSERT_TRUE(devArrMatch(components.data(), + components_ref.data(), + (params.n_col * params.n_col), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef TsvdTest TsvdTestDataVecF; +TEST_P(TsvdTestDataVecF, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +typedef TsvdTest TsvdTestDataVecD; 
+TEST_P(TsvdTestDataVecD, Result) +{ + ASSERT_TRUE(devArrMatch(data2.data(), + data2_back.data(), + (params.n_col2 * params.n_col2), + raft::CompareApprox(params.tolerance), + resource::get_cuda_stream(handle))); +} + +INSTANTIATE_TEST_CASE_P(TsvdTests, TsvdTestLeftVecF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(TsvdTests, TsvdTestLeftVecD, ::testing::ValuesIn(inputsd2)); + +INSTANTIATE_TEST_CASE_P(TsvdTests, TsvdTestDataVecF, ::testing::ValuesIn(inputsf2)); + +INSTANTIATE_TEST_CASE_P(TsvdTests, TsvdTestDataVecD, ::testing::ValuesIn(inputsd2)); + +} // end namespace raft::linalg diff --git a/docs/source/cpp_api/linalg_solver.rst b/docs/source/cpp_api/linalg_solver.rst index 1a811e072a..7d81aa17e0 100644 --- a/docs/source/cpp_api/linalg_solver.rst +++ b/docs/source/cpp_api/linalg_solver.rst @@ -64,3 +64,27 @@ namespace *raft::linalg* :project: RAFT :members: :content-only: + +PCA +--- + +``#include <raft/linalg/pca.cuh>`` + +namespace *raft::linalg* + +.. doxygengroup:: pca + :project: RAFT + :members: + :content-only: + +Truncated SVD +------------- + +``#include <raft/linalg/tsvd.cuh>`` + +namespace *raft::linalg* + +.. doxygengroup:: tsvd + :project: RAFT + :members: + :content-only: