Skip to content

[FEA] Add Batching to KMeans#1886

Merged
rapids-bot[bot] merged 170 commits intorapidsai:release/26.04from
tarang-jain:batched-kmeans
Mar 31, 2026
Merged

[FEA] Add Batching to KMeans#1886
rapids-bot[bot] merged 170 commits intorapidsai:release/26.04from
tarang-jain:batched-kmeans

Conversation

@tarang-jain
Copy link
Copy Markdown
Contributor

@tarang-jain tarang-jain commented Mar 6, 2026

Merge after #1880

This PR adds support for streaming out of core (dataset on host) kmeans clustering. The idea is simple:

Batched accumulation of centroid updates: Data is processed in batches and batch-wise means and cluster counts are accumulated until all the batches i.e., the full dataset pass has completed.
This PR just brings a batch-size parameter to load and compute cluster assignments and (weighted) centroid adjustments on batches of the dataset. The final centroid 'updates' i.e. a single kmeans iteration only completes when all these accumulated sums are averaged once the whole dataset pass has completed.

@tarang-jain
Copy link
Copy Markdown
Contributor Author

tarang-jain commented Mar 28, 2026

Taking into consideration these comments: #1886 (comment) and #1886 (comment) I have removed the host side predict function and put them into #1962 with DO NOT MERGE status to track those changes.


assert np.allclose(
inertia_regular, inertia_batched, rtol=1e-3, atol=1e-3
), f"max diff: {np.max(np.abs(inertia_regular - inertia_batched))}"
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this inertia check to the python side to evaluate that the inertia calculation is correct in the batched fit function.

raft::host_vector_view<IdxT, IdxT> labels,
bool normalize_weight,
raft::host_scalar_view<T> inertia)
{
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have still left this function here in the detail namespace because I think there is some utility to it. Its not being used anywhere and tests and public API have been moved to #1962

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a discussion with @lowener this function has ben removed. It lives in #1962 atm

Copy link
Copy Markdown
Contributor

@lowener lowener left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, only a minor remark

Copy link
Copy Markdown
Contributor

@jinsolp jinsolp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work @tarang-jain LGTM!

@tarang-jain
Copy link
Copy Markdown
Contributor Author

/merge

@tarang-jain tarang-jain removed the request for review from benfred March 31, 2026 23:39
@rapids-bot rapids-bot bot merged commit c02afc9 into rapidsai:release/26.04 Mar 31, 2026
80 checks passed
@github-project-automation github-project-automation bot moved this from In Progress to Done in Unstructured Data Processing Mar 31, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cpp feature request New feature or request non-breaking Introduces a non-breaking change

Projects

Status: Done

Development

Successfully merging this pull request may close these issues.

8 participants