Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,52 @@ void find_k(raft::resources const& handle,
float tol = 1e-3);
} // namespace helpers

/**
* @}
*/

/**
* @defgroup predict_host K-Means Predict (host data)
* @{
*/

/**
* @brief Predict cluster labels for host data using batched processing.
*
* Streams data from host to GPU in batches, assigns each sample to its nearest
* centroid, and writes labels back to host memory.
* The batch size is controlled by params.streaming_batch_size.
*
* @param[in] handle The raft handle.
* @param[in] params Parameters for KMeans model.
* @param[in] X Input samples on HOST memory. [dim = n_samples x n_features]
* @param[in] sample_weight Optional weights for each observation (on host).
* @param[in] centroids Cluster centers on device. [dim = n_clusters x n_features]
* @param[out] labels Predicted cluster labels on HOST memory. [dim = n_samples]
* @param[in] normalize_weight Whether to normalize sample weights.
* @param[out] inertia Sum of squared distances to nearest centroid.
*/
void predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const float, int64_t> X,
std::optional<raft::host_vector_view<const float, int64_t>> sample_weight,
raft::device_matrix_view<const float, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<float> inertia);

/**
* @brief Predict cluster labels for host data using batched processing (double).
*/
void predict(raft::resources const& handle,
const cuvs::cluster::kmeans::params& params,
raft::host_matrix_view<const double, int64_t> X,
std::optional<raft::host_vector_view<const double, int64_t>> sample_weight,
raft::device_matrix_view<const double, int64_t> centroids,
raft::host_vector_view<int64_t, int64_t> labels,
bool normalize_weight,
raft::host_scalar_view<double> inertia);

/**
* @}
*/
Expand Down
Loading
Loading