133 changes: 129 additions & 4 deletions python/rapidsmpf/rapidsmpf/integrations/core.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
"""Shuffler integration with external libraries."""

@@ -26,6 +26,7 @@
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
from rapidsmpf.shuffler import Shuffler
from rapidsmpf.statistics import Statistics
from rapidsmpf.utils.string import parse_bytes_threshold

if TYPE_CHECKING:
from collections.abc import Callable, Sequence
@@ -670,6 +671,127 @@ def spill_func(
return ctx.spill_collection.spill(amount, stream=DEFAULT_STREAM, device_mr=mr)


def setup_rmm_pool(
option_prefix: str,
options: Options,
worker: Any = None,
) -> rmm.mr.DeviceMemoryResource:
"""
Set up an RMM memory pool based on configuration options.

This function mirrors the logic in dask-cuda's RMMSetup plugin to ensure
consistent behavior when rapidsmpf manages RMM instead of Dask or another
orchestrator.

Parameters
----------
option_prefix
Prefix for config-option names (e.g., "dask_").
options
Configuration options.
worker
Optional worker reference (for logging filenames).

Returns
-------
The configured RMM memory resource, or the current device resource
if no RMM options are specified.

Notes
-----
If no RMM configuration options are provided, this function returns
the current device resource without modification.
"""
# Get total device memory for parsing fractional pool sizes
total_device_memory = rmm.mr.available_device_memory()[1]

# Parse RMM configuration options using parse_bytes_threshold.
# Values in (0, 1] are fractions of device memory, values > 1 are byte counts.
# RMM pools require 256-byte alignment.
initial_pool_size = parse_bytes_threshold(
options.get_or_default(
f"{option_prefix}rmm_pool_size", default_value=Optional(None)
).value,
total_device_memory,
alignment=256,
)
maximum_pool_size = parse_bytes_threshold(
options.get_or_default(
f"{option_prefix}rmm_maximum_pool_size", default_value=Optional(None)
).value,
total_device_memory,
alignment=256,
)
managed_memory = options.get_or_default(
f"{option_prefix}rmm_managed_memory", default_value=False
)
async_alloc = options.get_or_default(
f"{option_prefix}rmm_async", default_value=False
)
release_threshold = parse_bytes_threshold(
options.get_or_default(
f"{option_prefix}rmm_release_threshold", default_value=Optional(None)
).value,
total_device_memory,
alignment=256,
)
track_allocations = options.get_or_default(
f"{option_prefix}rmm_track_allocations", default_value=False
)

# If no RMM options specified, return current resource unchanged
if (
initial_pool_size is None
and not managed_memory
and not async_alloc
and not track_allocations
):
return rmm.mr.get_current_device_resource()

# Validation (same as dask-cuda)
if async_alloc and managed_memory:
raise ValueError("`rmm_managed_memory` is incompatible with `rmm_async`.")
if not async_alloc and release_threshold is not None:
raise ValueError("`rmm_release_threshold` requires `rmm_async`.")

# Setup based on mode (mirrors dask-cuda's RMMSetup plugin)
if async_alloc:
# Async allocation path using CudaAsyncMemoryResource
mr = rmm.mr.CudaAsyncMemoryResource(
initial_pool_size=initial_pool_size,
release_threshold=release_threshold,
)
if maximum_pool_size is not None:
mr = rmm.mr.LimitingResourceAdaptor(mr, allocation_limit=maximum_pool_size)
rmm.mr.set_current_device_resource(mr)
elif initial_pool_size is not None or managed_memory:
# Pool/managed allocation path
# Choose the upstream memory resource
if managed_memory:
upstream_mr = rmm.mr.ManagedMemoryResource()
else:
upstream_mr = rmm.mr.CudaMemoryResource()

# Create pool if initial_pool_size is specified
if initial_pool_size is not None:
mr = rmm.mr.PoolMemoryResource(
upstream_mr=upstream_mr,
initial_pool_size=initial_pool_size,
maximum_pool_size=maximum_pool_size,
)
else:
# Only managed memory, no pool
mr = upstream_mr
rmm.mr.set_current_device_resource(mr)

# Optionally wrap with tracking adaptor
if track_allocations:
mr = rmm.mr.get_current_device_resource()
rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))

return rmm.mr.get_current_device_resource()


def rmpf_worker_setup(
worker: Any,
option_prefix: str,
@@ -698,10 +820,13 @@ def rmpf_worker_setup(

Warnings
--------
This function creates a new RMM memory pool, and
sets it as the current device resource.
This function may create a new RMM memory pool (if configured),
and sets the RMM resource adaptor as the current device resource.
"""
# Insert RMM resource adaptor on top of the current RMM resource stack.
# Set up RMM pool if configured (otherwise uses existing resource)
setup_rmm_pool(option_prefix, options, worker)
Member:
There's one thing that concerns me here: how do we handle the case where Dask-CUDA has already set up RMM? In Dask-CUDA, specifically in the benchmarks, we recently introduced a safeguard in https://github.com/rapidsai/dask-cuda/blob/669fbc76a1c29357e572646fb3f7f5cacc69f935/dask_cuda/benchmarks/utils.py#L549-L556 that checks whether a cluster has already set up RMM, which happens if the cluster was set up externally. However, in the Dask integration in RapidsMPF, my understanding is that LocalCUDACluster will run the internal RMMSetup and then we'll run it again here, no? Setting it up twice may have downsides, particularly if memory has already been registered, for example for use with CUDA IPC (e.g., via UCX). Ideally, we would prevent setting RMM up more than once at all costs.

Member Author:
Yeah, this is a good question.

If this PR were to be merged right now, I don't think the behavior of the cudf-polars pdsh benchmarks would change. We would still be passing in the rmm_* arguments to LocalCUDACluster, and we would not be passing in the rmm_* Options to bootstrap_dask_cluster. Therefore, this call to setup_rmm_pool would be a no-op.

In a follow-up cudf-polars PR, we will need to move the rmm_* arguments from the LocalCUDACluster call to the bootstrap_dask_cluster call.

With that said, we still don't have a bullet-proof plan for detecting when RMM has already been configured on a rapidsmpf worker. We can only check for the RMMSetup plugin when we are running on top of Dask and didn't use setup_rmm_pool.

Member:
> If this PR were to be merged right now, I don't think the behavior of the cudf-polars pdsh benchmarks would change. We would still be passing in the rmm_* arguments to LocalCUDACluster, and we would not be passing in the rmm_* Options to bootstrap_dask_cluster. Therefore, this call to setup_rmm_pool would be a no-op.
>
> In a follow-up cudf-polars PR, we will need to move the rmm_* arguments from the LocalCUDACluster call to the bootstrap_dask_cluster call.

I agree, this is probably fine for the single-node cluster. However, when setting up a multi-node cluster you need to provide an already set-up cluster, in which case rmpf_worker_setup will run and set up RMM a second time, no?

> With that said, we still don't have a bullet-proof plan for detecting when RMM has already been configured on a rapidsmpf worker. We can only check for the RMMSetup plugin when we are running on top of Dask and didn't use setup_rmm_pool.

I don't think we have a bullet-proof way to prevent it, and maybe never will, but perhaps we should implement a check for RMMSetup similar to the one I linked in Dask-CUDA (a rough sketch is below). I think when we run multi-node setup we will end up calling setup_rmm_pool on top of the already executed RMMSetup on the cluster, but correct me if I'm still overlooking something.
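
For reference, a rough sketch of what such a guard could look like. It is not part of this PR; it assumes RMMSetup is importable from dask_cuda.plugins and that the worker exposes its registered plugins via the standard Distributed worker.plugins mapping:

def rmm_already_configured(worker) -> bool:
    # Return True if dask-cuda's RMMSetup plugin has already configured RMM
    # on this worker, in which case setup_rmm_pool could bail out early.
    try:
        from dask_cuda.plugins import RMMSetup  # assumed import path
    except ImportError:
        return False
    # Distributed workers keep registered plugins in a name -> plugin mapping.
    plugins = getattr(worker, "plugins", {}) or {}
    return any(isinstance(plugin, RMMSetup) for plugin in plugins.values())

If this returned True, setup_rmm_pool could simply return rmm.mr.get_current_device_resource() instead of reconfiguring the device resource.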

Member Author:
> when setting up a multi-node cluster you need to provide an already set-up cluster, in which case rmpf_worker_setup will run and set up RMM a second time, no?

When you set up a multi-node cluster, you can pass in the --rmm-* options when you create your workers, or you can set up the RMM pool within rmpf_worker_setup. Either way, rmpf_worker_setup (and therefore setup_rmm_pool) will run on each worker when you call bootstrap_dask_cluster on the client.

The important detail is that setup_rmm_pool will not actually do anything if the Options argument is empty (or doesn't contain any rmm-related options). By default, all of these arguments will be None. The user (or utils.py script) needs to manually add the rmm-related options to the bootstrap_dask_cluster Options argument.

> I think when we run multi-node setup we will end up calling setup_rmm_pool on top of the already executed RMMSetup on the cluster, but correct me if I'm still overlooking something.

Yes. However, we will not actually do anything in the setup_rmm_pool call unless we have updated the bootstrap_dask_cluster Options argument to contain rmm options (see the sketch below for what that would look like).
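
For illustration, a minimal sketch of that single-source configuration: the rmm options are passed only through the bootstrap_dask_cluster Options, so dask-cuda does not build a pool and setup_rmm_pool does. The option names follow the tests in this PR; the values and the bootstrap_dask_cluster import path are assumptions:

from dask_cuda import LocalCUDACluster
from distributed import Client

from rapidsmpf.config import Options
from rapidsmpf.integrations.dask import bootstrap_dask_cluster

# No rmm_* arguments here, so dask-cuda's RMMSetup leaves RMM untouched.
with LocalCUDACluster(rmm_pool_size=None, rmm_async=False) as cluster:
    with Client(cluster) as client:
        # setup_rmm_pool acts only because these Options contain rmm-related keys.
        bootstrap_dask_cluster(
            client,
            options=Options(
                {
                    "dask_rmm_async": "true",
                    "dask_rmm_pool_size": "0.8",  # fraction of device memory
                }
            ),
        )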

Member:
But that trusts the user NOT to specify --rmm-* to both dask cuda worker and the client simultaneously, right? So while this is functional, it could still be dangerous (the double-configuration case is sketched below). I don't want to hold off on this PR much longer; I just want to ensure we're aware we may be compromising on one end or another.
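
To make the risk concrete, a sketch of the double-configuration case being warned about: both dask-cuda and rapidsmpf are asked to build a pool, so RMM ends up being configured twice on each worker. Same assumed imports and option names as the sketch above; the sizes are placeholders:

from dask_cuda import LocalCUDACluster
from distributed import Client

from rapidsmpf.config import Options
from rapidsmpf.integrations.dask import bootstrap_dask_cluster

# dask-cuda's RMMSetup configures a pool when each worker starts...
with LocalCUDACluster(rmm_pool_size="8GiB") as cluster:
    with Client(cluster) as client:
        # ...and setup_rmm_pool then configures RMM again on top of it,
        # because these Options also contain rmm-related keys.
        bootstrap_dask_cluster(
            client,
            options=Options({"dask_rmm_pool_size": "0.8"}),
        )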


# Insert RMM resource adaptor on top of the (now configured) RMM resource stack.
mr = RmmResourceAdaptor(
upstream_mr=rmm.mr.get_current_device_resource(),
fallback_mr=(
133 changes: 132 additions & 1 deletion python/rapidsmpf/rapidsmpf/tests/test_dask.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

@@ -9,6 +9,8 @@
import dask.dataframe as dd
import pytest

import rmm.mr

import rapidsmpf.integrations.single
from rapidsmpf.communicator import COMMUNICATORS
from rapidsmpf.config import Options
@@ -17,6 +19,7 @@
dask_cudf_join,
dask_cudf_shuffle,
)
from rapidsmpf.integrations.core import setup_rmm_pool
from rapidsmpf.integrations.dask.core import get_worker_context
from rapidsmpf.integrations.dask.shuffler import (
clear_shuffle_statistics,
@@ -79,6 +82,134 @@ def get_rank(dask_worker: Worker) -> int:
assert set(result.values()) == set(range(len(cluster.workers)))


def assert_workers_contain_mr_type(
client: Client, resource_type: type[rmm.mr.MemoryResource]
) -> None:
"""Assert that all workers contain the given resource type."""

def _has_rmm_resource_type() -> bool:
# The given resource type must be in the resource chain.
mr = rmm.mr.get_current_device_resource()
while mr is not None:
if isinstance(mr, resource_type):
return True
mr = getattr(mr, "upstream_mr", None)
return False

result = client.run(_has_rmm_resource_type)
for worker_addr, has_type in result.items():
assert has_type, (
f"Worker {worker_addr} does not have {resource_type.__name__} in chain."
)


@gen_test(timeout=30)
async def test_dask_rmm_setup_async() -> None:
"""Test that rapidsmpf can set up RMM async pool during bootstrap."""
with (
LocalCUDACluster(
scheduler_port=0,
device_memory_limit=1,
# Don't let dask-cuda set up RMM - rapidsmpf will do it
rmm_pool_size=None,
rmm_async=False,
) as cluster,
Client(cluster) as client,
):
bootstrap_dask_cluster(
client,
options=Options(
{
"dask_spill_device": "0.1",
"dask_rmm_async": "true",
"dask_rmm_pool_size": "0.1", # 10% of device memory
}
),
)
assert_workers_contain_mr_type(client, rmm.mr.CudaAsyncMemoryResource)


@gen_test(timeout=30)
async def test_dask_rmm_setup_pool() -> None:
"""Test that rapidsmpf can set up RMM pool (non-async) during bootstrap."""
with (
LocalCUDACluster(
scheduler_port=0,
device_memory_limit=1,
# Don't let dask-cuda set up RMM - rapidsmpf will do it
rmm_pool_size=None,
rmm_async=False,
) as cluster,
Client(cluster) as client,
):
bootstrap_dask_cluster(
client,
options=Options(
{
"dask_spill_device": "0.1",
"dask_rmm_async": "false",
"dask_rmm_pool_size": "0.1", # 10% of device memory
}
),
)

assert_workers_contain_mr_type(client, rmm.mr.PoolMemoryResource)


@gen_test(timeout=30)
async def test_dask_rmm_setup_track_allocations() -> None:
"""Test that rapidsmpf can enable RMM tracking during bootstrap."""
with (
LocalCUDACluster(
scheduler_port=0,
device_memory_limit=1,
rmm_pool_size=None,
rmm_async=False,
) as cluster,
Client(cluster) as client,
):
bootstrap_dask_cluster(
client,
options=Options(
{
"dask_spill_device": "0.1",
"dask_rmm_track_allocations": "true",
}
),
)

assert_workers_contain_mr_type(client, rmm.mr.TrackingResourceAdaptor)


def test_rmm_setup_validation_errors() -> None:
"""Test that invalid RMM option combinations raise errors."""

# rmm_managed_memory is incompatible with rmm_async
with pytest.raises(ValueError, match="incompatible"):
setup_rmm_pool(
"test_",
Options(
{
"test_rmm_async": "true",
"test_rmm_managed_memory": "true",
}
),
)

# rmm_release_threshold requires rmm_async
with pytest.raises(ValueError, match="requires"):
setup_rmm_pool(
"test_",
Options(
{
"test_rmm_async": "false",
"test_rmm_pool_size": "0.1", # 10% of device memory
"test_rmm_release_threshold": "0.05", # 5% of device memory
}
),
)


@pytest.mark.parametrize("partition_count", [None, 3])
@pytest.mark.parametrize("sort", [True, False])
def test_dask_cudf_integration(
46 changes: 44 additions & 2 deletions python/rapidsmpf/rapidsmpf/tests/test_utils.py
@@ -1,10 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import pytest

from rapidsmpf.utils.string import format_bytes, parse_boolean, parse_bytes
from rapidsmpf.utils.string import (
format_bytes,
parse_boolean,
parse_bytes,
parse_bytes_threshold,
)


def test_format_bytes() -> None:
@@ -41,3 +46,40 @@ def test_parse_boolean_false(input_str: str) -> None:
def test_parse_boolean_invalid(invalid_str: str) -> None:
with pytest.raises(ValueError, match="Cannot parse boolean"):
parse_boolean(invalid_str)


# Use a fixed total for deterministic tests
_TOTAL = 16 * 1024 * 1024 * 1024 # 16 GiB


@pytest.mark.parametrize(
"value,expected",
[
# None and zero return None
(None, None),
(0, None),
(0.0, None),
# Fractions (0 < value <= 1) are % of total
(0.5, 8 * 1024 * 1024 * 1024), # 50% of 16 GiB = 8 GiB
(1.0, _TOTAL), # 100%
("0.5", 8 * 1024 * 1024 * 1024), # String fraction
# Values > 1 are byte counts
(1024 * 1024 * 1024, 1024 * 1024 * 1024), # 1 GiB
("1000000000", 1000000000), # String byte count
# Byte strings (e.g., "12 GB", "128 MiB")
("1 GiB", 1024 * 1024 * 1024),
("12 GB", 12 * 1000 * 1000 * 1000),
],
)
def test_parse_bytes_threshold(value: str | float | None, expected: int | None) -> None:
assert parse_bytes_threshold(value, _TOTAL) == expected


def test_parse_bytes_threshold_alignment() -> None:
# Results should be aligned to the specified alignment
assert (
parse_bytes_threshold(1000, 1000, alignment=256) == 768
) # floor(1000 / 256) * 256
assert (
parse_bytes_threshold(0.5, 1000, alignment=100) == 500
) # 500 is divisible by 100